Commit 55acb151 authored by Ruilong Li's avatar Ruilong Li
Browse files

Revert "rename to "proposal_resampling""

This reverts commit 3090d7bd.
parent 3090d7bd
......@@ -2,7 +2,6 @@
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import math
from typing import Callable, List, Union
import torch
......@@ -74,11 +73,8 @@ class NGPradianceField(torch.nn.Module):
use_viewdirs: bool = True,
density_activation: Callable = lambda x: trunc_exp(x - 1),
unbounded: bool = False,
hidden_dim: int = 64,
geo_feat_dim: int = 15,
n_levels: int = 16,
max_res: int = 1024,
base_res: int = 16,
log2_hashmap_size: int = 19,
) -> None:
super().__init__()
......@@ -91,9 +87,7 @@ class NGPradianceField(torch.nn.Module):
self.unbounded = unbounded
self.geo_feat_dim = geo_feat_dim
per_level_scale = math.exp(
(math.log(max_res) - math.log(base_res)) / (n_levels - 1)
)
per_level_scale = 1.4472692012786865
if self.use_viewdirs:
self.direction_encoding = tcnn.Encoding(
......@@ -111,39 +105,25 @@ class NGPradianceField(torch.nn.Module):
},
)
if hidden_dim > 0:
self.mlp_base = tcnn.NetworkWithInputEncoding(
n_input_dims=num_dim,
n_output_dims=1 + self.geo_feat_dim,
encoding_config={
"otype": "HashGrid",
"n_levels": n_levels,
"n_features_per_level": 2,
"log2_hashmap_size": log2_hashmap_size,
"base_resolution": base_res,
"per_level_scale": per_level_scale,
},
network_config={
"otype": "FullyFusedMLP",
"activation": "ReLU",
"output_activation": "None",
"n_neurons": hidden_dim,
"n_hidden_layers": 1,
},
)
else:
self.mlp_base = tcnn.Encoding(
n_input_dims=num_dim,
encoding_config={
"otype": "HashGrid",
"n_levels": 1,
"n_features_per_level": 1,
"log2_hashmap_size": 21,
"base_resolution": 128,
"per_level_scale": 1.0,
},
)
self.mlp_base = tcnn.NetworkWithInputEncoding(
n_input_dims=num_dim,
n_output_dims=1 + self.geo_feat_dim,
encoding_config={
"otype": "HashGrid",
"n_levels": n_levels,
"n_features_per_level": 2,
"log2_hashmap_size": log2_hashmap_size,
"base_resolution": 16,
"per_level_scale": per_level_scale,
},
network_config={
"otype": "FullyFusedMLP",
"activation": "ReLU",
"output_activation": "None",
"n_neurons": 64,
"n_hidden_layers": 1,
},
)
if self.geo_feat_dim > 0:
self.mlp_head = tcnn.Network(
n_input_dims=(
......
......@@ -37,7 +37,7 @@ def ray_resampling(
resampled_t_starts,
resampled_t_ends,
) = _C.ray_resampling(
packed_info.contiguous().int(),
packed_info.contiguous(),
t_starts.contiguous(),
t_ends.contiguous(),
weights.contiguous(),
......
......@@ -23,9 +23,7 @@ grid_query = _make_lazy_cuda_func("grid_query")
ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
ray_marching = _make_lazy_cuda_func("ray_marching")
ray_marching_with_grid = _make_lazy_cuda_func("ray_marching_with_grid")
ray_resampling = _make_lazy_cuda_func("ray_resampling")
ray_pdf_query = _make_lazy_cuda_func("ray_pdf_query")
is_cub_available = _make_lazy_cuda_func("is_cub_available")
transmittance_from_sigma_forward_cub = _make_lazy_cuda_func(
......
......@@ -4,97 +4,13 @@
#include "include/helpers_cuda.h"
// Query the probability mass of a per-ray piecewise-constant distribution
// over a second set of intervals. One thread per ray.
//
// The source distribution of ray i is `steps` sorted intervals
// [starts[k], ends[k]] carrying mass pdfs[k]; gaps between intervals carry no
// mass. For each query interval [resample_starts[j], resample_ends[j]] the
// kernel writes the mass covered by that interval into resample_pdfs[j].
// Both interval lists must be sorted in t within each ray.
template <typename scalar_t>
__global__ void pdf_query_kernel(
    const uint32_t n_rays,
    // query
    const int *packed_info,          // (n_rays, 2): [sample offset, sample count]
    const scalar_t *starts,          // source interval start t
    const scalar_t *ends,            // source interval end t
    const scalar_t *pdfs,            // per-interval probability mass to be queried
    // resample
    const int *resample_packed_info, // (n_rays, 2): [sample offset, sample count]
    const scalar_t *resample_starts, // query interval start t, sorted
    const scalar_t *resample_ends,   // query interval end t, sorted
    // output
    scalar_t *resample_pdfs)         // caller must zero-initialize
{
    CUDA_GET_THREAD_ID(i, n_rays);

    // Locate this ray's slice inside the flattened sample arrays.
    const int base = packed_info[i * 2 + 0];          // point idx start.
    const int steps = packed_info[i * 2 + 1];         // point idx shift.
    const int resample_base = resample_packed_info[i * 2 + 0];
    const int resample_steps = resample_packed_info[i * 2 + 1];
    if (resample_steps == 0) // nothing to query
        return;
    if (steps == 0) // nothing to be queried: outputs stay zero
        return;
    starts += base;
    ends += base;
    pdfs += base;
    resample_starts += resample_base;
    resample_ends += resample_base;
    resample_pdfs += resample_base;

    // Single forward sweep over both sorted lists. t0_id / t1_id track which
    // source interval currently contains the query start (t0) / end (t1);
    // cdf*_start / cdf*_end hold the CDF at that interval's boundaries.
    // Both cursors start in a virtual empty interval [0, starts[0]).
    int t0_id = -1;
    scalar_t t0_start = 0.0f, t0_end = starts[0];
    scalar_t cdf0_start = 0.0f, cdf0_end = 0.0f;
    int t1_id = -1;
    scalar_t t1_start = 0.0f, t1_end = starts[0];
    scalar_t cdf1_start = 0.0f, cdf1_end = 0.0f;
    for (int j = 0; j < resample_steps; ++j)
    {
        // Advance to the source interval containing the query start.
        // FIX: was bitwise `&`; logical `&&` is the intended operator
        // (equivalent result here, but idiomatic and short-circuiting).
        scalar_t t0 = resample_starts[j];
        while (t0 > t0_end && t0_id < steps - 1)
        {
            t0_id++;
            t0_start = starts[t0_id];
            t0_end = ends[t0_id];
            cdf0_start = cdf0_end;
            cdf0_end += pdfs[t0_id];
        }
        if (t0 > t0_end)
        {
            // Query interval begins past the last source interval: no mass.
            resample_pdfs[j] = 0.0f;
            continue;
        }
        // Snap the CDF at t0 to the start of its interval (pct0 == 0); the
        // commented formula would interpolate within the interval instead.
        scalar_t pct0 = 0.0f; // max(t0 - t0_start, 0.0f) / max(t0_end - t0_start, 1e-10f);
        scalar_t resample_cdf_start = cdf0_start + pct0 * (cdf0_end - cdf0_start);

        // Advance to the source interval containing the query end.
        scalar_t t1 = resample_ends[j];
        while (t1 > t1_end && t1_id < steps - 1)
        {
            t1_id++;
            t1_start = starts[t1_id];
            t1_end = ends[t1_id];
            cdf1_start = cdf1_end;
            cdf1_end += pdfs[t1_id];
        }
        if (t1 > t1_end)
        {
            // Query interval ends past the last source interval: take all the
            // remaining mass after the query start.
            resample_pdfs[j] = cdf1_end - resample_cdf_start;
            continue;
        }
        // Snap the CDF at t1 to the end of its interval (pct1 == 1).
        scalar_t pct1 = 1.0f; // max(t1 - t1_start, 0.0f) / max(t1_end - t1_start, 1e-10f);
        scalar_t resample_cdf_end = cdf1_start + pct1 * (cdf1_end - cdf1_start);

        // Probability mass of [t0, t1] under the source distribution.
        resample_pdfs[j] = resample_cdf_end - resample_cdf_start;
    }
    return;
}
template <typename scalar_t>
__global__ void cdf_resampling_kernel(
const uint32_t n_rays,
const int *packed_info, // input ray & point indices.
const scalar_t *starts, // input start t
const scalar_t *ends, // input end t
const scalar_t *w, // transmittance weights
const scalar_t *weights, // transmittance weights
const int *resample_packed_info,
scalar_t *resample_starts,
scalar_t *resample_ends)
......@@ -111,26 +27,25 @@ __global__ void cdf_resampling_kernel(
starts += base;
ends += base;
w += base;
weights += base;
resample_starts += resample_base;
resample_ends += resample_base;
// normalize weights **per ray**
scalar_t w_sum = 0.0f;
scalar_t weights_sum = 0.0f;
for (int j = 0; j < steps; j++)
w_sum += w[j];
// scalar_t padding = fmaxf(1e-10f - weights_sum, 0.0f);
// scalar_t padding_step = padding / steps;
// weights_sum += padding;
weights_sum += weights[j];
scalar_t padding = fmaxf(1e-5f - weights_sum, 0.0f);
scalar_t padding_step = padding / steps;
weights_sum += padding;
int num_endpoints = resample_steps + 1;
scalar_t cdf_pad = 1.0f / (2 * num_endpoints);
scalar_t cdf_step_size = (1.0f - 2 * cdf_pad) / resample_steps;
int num_bins = resample_steps + 1;
scalar_t cdf_step_size = (1.0f - 1.0 / num_bins) / resample_steps;
int idx = 0, j = 0;
scalar_t cdf_prev = 0.0f, cdf_next = w[idx] / w_sum;
scalar_t cdf_u = cdf_pad;
while (j < num_endpoints)
scalar_t cdf_prev = 0.0f, cdf_next = (weights[idx] + padding_step) / weights_sum;
scalar_t cdf_u = 1.0 / (2 * num_bins);
while (j < num_bins)
{
if (cdf_u < cdf_next)
{
......@@ -138,32 +53,26 @@ __global__ void cdf_resampling_kernel(
// resample in this interval
scalar_t scaling = (ends[idx] - starts[idx]) / (cdf_next - cdf_prev);
scalar_t t = (cdf_u - cdf_prev) * scaling + starts[idx];
// if (j == 100) {
// printf(
// "cdf_u: %.10f, cdf_next: %.10f, cdf_prev: %.10f, scaling: %.10f, t: %.10f, starts[idx]: %.10f, ends[idx]: %.10f\n",
// cdf_u, cdf_next, cdf_prev, scaling, t, starts[idx], ends[idx]);
// }
if (j < num_endpoints - 1)
if (j < num_bins - 1)
resample_starts[j] = t;
if (j > 0)
resample_ends[j - 1] = t;
// going further to next resample
// cdf_u += cdf_step_size;
cdf_u += cdf_step_size;
j += 1;
cdf_u = j * cdf_step_size + cdf_pad;
}
else
{
// going to next interval
idx += 1;
cdf_prev = cdf_next;
cdf_next += w[idx] / w_sum;
cdf_next += (weights[idx] + padding_step) / weights_sum;
}
}
// if (j != num_endpoints)
// {
// printf("Error: %d %d %f\n", j, num_endpoints, weights_sum);
// }
if (j != num_bins)
{
printf("Error: %d %d %f\n", j, num_bins, weights_sum);
}
return;
}
......@@ -256,7 +165,7 @@ std::vector<torch::Tensor> ray_resampling(
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(weights.ndimension() == 2 & weights.size(1) == 1);
TORCH_CHECK(weights.ndimension() == 1);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = weights.size(0);
......@@ -265,8 +174,7 @@ std::vector<torch::Tensor> ray_resampling(
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor num_steps = torch::split(packed_info, 1, 1)[1];
// torch::Tensor resample_num_steps = (num_steps > 0).to(num_steps.options()) * steps;
torch::Tensor resample_num_steps = torch::clamp(num_steps, 0, steps);
torch::Tensor resample_num_steps = (num_steps > 0).to(num_steps.options()) * steps;
torch::Tensor resample_cum_steps = resample_num_steps.cumsum(0, torch::kInt32);
torch::Tensor resample_packed_info = torch::cat(
{resample_cum_steps - resample_num_steps, resample_num_steps}, 1);
......@@ -293,52 +201,3 @@ std::vector<torch::Tensor> ray_resampling(
return {resample_packed_info, resample_starts, resample_ends};
}
// Host launcher for pdf_query_kernel: for each ray, evaluates how much
// probability mass of the packed distribution (starts, ends, pdfs) falls
// inside each resampled interval.
//
// Returns: resample_pdfs, shape (n_resamples, 1), same dtype as `pdfs`.
torch::Tensor ray_pdf_query(
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
    torch::Tensor pdfs,
    torch::Tensor resample_packed_info,
    torch::Tensor resample_starts,
    torch::Tensor resample_ends)
{
    DEVICE_GUARD(packed_info);
    CHECK_INPUT(packed_info);
    CHECK_INPUT(starts);
    CHECK_INPUT(ends);
    CHECK_INPUT(pdfs);
    // The resample tensors are read by the kernel via raw pointers as well,
    // so they must also be contiguous CUDA tensors.
    CHECK_INPUT(resample_packed_info);
    CHECK_INPUT(resample_starts);
    CHECK_INPUT(resample_ends);
    // FIX: combine shape conditions with logical && (was bitwise &).
    TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
    TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
    TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
    TORCH_CHECK(pdfs.ndimension() == 2 && pdfs.size(1) == 1);

    const uint32_t n_rays = packed_info.size(0);
    const uint32_t n_resamples = resample_starts.size(0);

    // One thread per ray.
    const int threads = 256;
    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);

    // pdf_query_kernel requires the output to be zero-initialized.
    torch::Tensor resample_pdfs = torch::zeros({n_resamples, 1}, pdfs.options());

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        pdfs.scalar_type(),
        "pdf_query",
        ([&]
         { pdf_query_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
               n_rays,
               // inputs
               packed_info.data_ptr<int>(),
               starts.data_ptr<scalar_t>(),
               ends.data_ptr<scalar_t>(),
               pdfs.data_ptr<scalar_t>(),
               resample_packed_info.data_ptr<int>(),
               resample_starts.data_ptr<scalar_t>(),
               resample_ends.data_ptr<scalar_t>(),
               // outputs
               resample_pdfs.data_ptr<scalar_t>()); }));
    return resample_pdfs;
}
......@@ -13,13 +13,6 @@ std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor aabb);
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor t_min,
const torch::Tensor t_max,
// sampling
const float step_size,
const float cone_angle);
std::vector<torch::Tensor> ray_marching_with_grid(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
......@@ -65,15 +58,6 @@ std::vector<torch::Tensor> ray_resampling(
torch::Tensor weights,
const int steps);
torch::Tensor ray_pdf_query(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor pdfs,
torch::Tensor resample_packed_info,
torch::Tensor resample_starts,
torch::Tensor resample_ends);
torch::Tensor unpack_data(
torch::Tensor packed_info,
torch::Tensor data,
......@@ -160,9 +144,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
// marching
m.def("ray_aabb_intersect", &ray_aabb_intersect);
m.def("ray_marching", &ray_marching);
m.def("ray_marching_with_grid", &ray_marching_with_grid);
m.def("ray_resampling", &ray_resampling);
m.def("ray_pdf_query", &ray_pdf_query);
// rendering
m.def("is_cub_available", is_cub_available);
......
......@@ -76,78 +76,7 @@ inline __device__ __host__ float advance_to_next_voxel(
// Raymarching
// -------------------------------------------------------------------------------
// Uniform ray marching (no occupancy grid). The same kernel serves two
// rounds over an identical launch configuration:
//   round 1 (packed_info == nullptr): only count samples per ray into
//     num_steps, so the host can size the packed output buffers;
//   round 2 (packed_info != nullptr): write t_starts / t_ends / ray_indices
//     into each ray's slice of the packed arrays.
// One thread per ray. In round 1 the three sample-output pointers may be
// null; in round 2 num_steps may be null.
__global__ void ray_marching_kernel(
// rays info
const uint32_t n_rays,
const float *t_min, // shape (n_rays,)
const float *t_max, // shape (n_rays,)
// sampling
const float step_size,
const float cone_angle,
const int *packed_info, // nullptr in round 1; (n_rays, 2) [base, steps] in round 2
// first round outputs
int *num_steps,
// second round outputs
int *ray_indices,
float *t_starts,
float *t_ends)
{
CUDA_GET_THREAD_ID(i, n_rays);
bool is_first_round = (packed_info == nullptr);
// locate: advance each pointer to this thread's ray (and, in round 2, to
// this ray's slice of the packed sample arrays)
t_min += i;
t_max += i;
if (is_first_round)
{
num_steps += i;
}
else
{
int base = packed_info[i * 2 + 0];
int steps = packed_info[i * 2 + 1];
t_starts += base;
t_ends += base;
ray_indices += base;
}
const float near = t_min[0], far = t_max[0];
// Step size bounds passed to calc_dt. NOTE(review): calc_dt presumably
// scales dt with distance when cone_angle > 0 — confirm in its definition.
float dt_min = step_size;
float dt_max = 1e10f;
// March from near to far; a sample [t0, t1] is kept while its midpoint is
// still in front of the far plane.
int j = 0;
float t0 = near;
float dt = calc_dt(t0, cone_angle, dt_min, dt_max);
float t1 = t0 + dt;
float t_mid = (t0 + t1) * 0.5f;
while (t_mid < far)
{
if (!is_first_round)
{
t_starts[j] = t0;
t_ends[j] = t1;
ray_indices[j] = i; // every sample remembers which ray produced it
}
++j;
// march to next sample
t0 = t1;
t1 = t0 + calc_dt(t0, cone_angle, dt_min, dt_max);
t_mid = (t0 + t1) * 0.5f;
}
if (is_first_round)
{
*num_steps = j;
}
return;
}
__global__ void ray_marching_with_grid_kernel(
// rays info
const uint32_t n_rays,
const float *rays_o, // shape (n_rays, 3)
......@@ -260,74 +189,7 @@ __global__ void ray_marching_with_grid_kernel(
return;
}
// Host wrapper: uniformly march all rays between their t_min and t_max.
// Launches ray_marching_kernel twice: round 1 counts samples per ray, the
// host then allocates packed buffers, and round 2 fills them.
// Returns {packed_info (n_rays, 2) int32, ray_indices (total,) int32,
//          t_starts (total, 1) float, t_ends (total, 1) float}.
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor t_min,
const torch::Tensor t_max,
// sampling
const float step_size,
const float cone_angle)
{
DEVICE_GUARD(t_min);
CHECK_INPUT(t_min);
CHECK_INPUT(t_max);
TORCH_CHECK(t_min.ndimension() == 1)
TORCH_CHECK(t_max.ndimension() == 1)
const int n_rays = t_min.size(0);
// One thread per ray.
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// helper counter
torch::Tensor num_steps = torch::empty(
{n_rays}, t_min.options().dtype(torch::kInt32));
// count number of samples per ray (round 1: all sample outputs are null)
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// sampling
step_size,
cone_angle,
nullptr, /* packed_info */
// outputs
num_steps.data_ptr<int>(),
nullptr, /* ray_indices */
nullptr, /* t_starts */
nullptr /* t_ends */);
// packed_info[i] = [offset of ray i's first sample, number of samples]
torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
torch::Tensor packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
// output samples starts and ends
// .item<int>() copies the scalar to the host, synchronizing with the device.
int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
torch::Tensor t_starts = torch::empty({total_steps, 1}, t_min.options());
torch::Tensor t_ends = torch::empty({total_steps, 1}, t_min.options());
torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options());
// round 2: re-march and write the samples into the packed buffers
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// sampling
step_size,
cone_angle,
packed_info.data_ptr<int>(),
// outputs
nullptr, /* num_steps */
ray_indices.data_ptr<int>(),
t_starts.data_ptr<float>(),
t_ends.data_ptr<float>());
return {packed_info, ray_indices, t_starts, t_ends};
}
std::vector<torch::Tensor> ray_marching_with_grid(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
......@@ -368,7 +230,7 @@ std::vector<torch::Tensor> ray_marching_with_grid(
{n_rays}, rays_o.options().dtype(torch::kInt32));
// count number of samples per ray
ray_marching_with_grid_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
......@@ -399,7 +261,7 @@ std::vector<torch::Tensor> ray_marching_with_grid(
torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options());
torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options());
ray_marching_with_grid_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
......
......@@ -2,13 +2,15 @@ from typing import Callable, Optional, Tuple
import torch
import nerfacc.cuda as _C
from .contraction import ContractionType
from .grid import Grid
from .intersection import ray_aabb_intersect
from .sampling import proposal_resampling, sample_along_rays
from .vol_rendering import render_visibility
@torch.no_grad()
# @profile
def ray_marching(
# rays
rays_o: torch.Tensor,
......@@ -22,10 +24,6 @@ def ray_marching(
# sigma/alpha function for skipping invisible space
sigma_fn: Optional[Callable] = None,
alpha_fn: Optional[Callable] = None,
# proposal density fns: {t_starts, t_ends, ray_indices} -> density
proposal_sigma_fns: Tuple[Callable, ...] = [],
proposal_n_samples: Tuple[int, ...] = [],
proposal_require_grads: bool = False,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
# rendering options
......@@ -130,8 +128,6 @@ def ray_marching(
sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
"""
n_rays = rays_o.shape[0]
if not rays_o.is_cuda:
raise NotImplementedError("Only support cuda inputs.")
if alpha_fn is not None and sigma_fn is not None:
......@@ -161,27 +157,65 @@ def ray_marching(
if stratified:
t_min = t_min + torch.rand_like(t_min) * render_step_size
ray_indices, t_starts, t_ends = sample_along_rays(
rays_o=rays_o,
rays_d=rays_d,
t_min=t_min,
t_max=t_max,
step_size=render_step_size,
cone_angle=cone_angle,
grid=grid,
# use grid for skipping if given
if grid is not None:
grid_roi_aabb = grid.roi_aabb
grid_binary = grid.binary
contraction_type = grid.contraction_type.to_cpp_version()
else:
grid_roi_aabb = torch.tensor(
[-1e10, -1e10, -1e10, 1e10, 1e10, 1e10],
dtype=torch.float32,
device=rays_o.device,
)
grid_binary = torch.ones(
[1, 1, 1], dtype=torch.bool, device=rays_o.device
)
contraction_type = ContractionType.AABB.to_cpp_version()
# marching with grid-based skipping
packed_info, ray_indices, t_starts, t_ends = _C.ray_marching(
# rays
rays_o.contiguous(),
rays_d.contiguous(),
t_min.contiguous(),
t_max.contiguous(),
# coontraction and grid
grid_roi_aabb.contiguous(),
grid_binary.contiguous(),
contraction_type,
# sampling
render_step_size,
cone_angle,
)
ray_indices, t_starts, t_ends, proposal_samples = proposal_resampling(
t_starts=t_starts,
t_ends=t_ends,
ray_indices=ray_indices,
n_rays=n_rays,
sigma_fn=sigma_fn,
proposal_sigma_fns=proposal_sigma_fns,
proposal_n_samples=proposal_n_samples,
proposal_require_grads=proposal_require_grads,
early_stop_eps=early_stop_eps,
alpha_thre=alpha_thre,
)
# skip invisible space
if sigma_fn is not None or alpha_fn is not None:
# Query sigma without gradients
if sigma_fn is not None:
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
assert (
sigmas.shape == t_starts.shape
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
elif alpha_fn is not None:
alphas = alpha_fn(t_starts, t_ends, ray_indices.long())
assert (
alphas.shape == t_starts.shape
), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape)
# Compute visibility of the samples, and filter out invisible samples
masks = render_visibility(
alphas,
ray_indices=ray_indices,
early_stop_eps=early_stop_eps,
alpha_thre=alpha_thre,
n_rays=rays_o.shape[0],
)
ray_indices, t_starts, t_ends = (
ray_indices[masks],
t_starts[masks],
t_ends[masks],
)
return ray_indices, t_starts, t_ends, proposal_samples
return ray_indices, t_starts, t_ends
......@@ -99,7 +99,7 @@ def sample_along_rays(
@torch.no_grad()
def proposal_resampling(
def proposal_sampling_with_filter(
t_starts: torch.Tensor, # [n_samples, 1]
t_ends: torch.Tensor, # [n_samples, 1]
ray_indices: torch.Tensor, # [n_samples,]
......
......@@ -23,7 +23,7 @@ def rendering(
rgb_alpha_fn: Optional[Callable] = None,
# rendering options
render_bkgd: Optional[torch.Tensor] = None,
) -> Tuple:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Render the rays through the radience field defined by `rgb_sigma_fn`.
This function is differentiable to the outputs of `rgb_sigma_fn` so it can
......@@ -126,7 +126,7 @@ def rendering(
if render_bkgd is not None:
colors = colors + render_bkgd * (1.0 - opacities)
return colors, opacities, depths, weights
return colors, opacities, depths
def accumulate_along_rays(
......@@ -142,7 +142,7 @@ def accumulate_along_rays(
Args:
weights: Volumetric rendering weights for those samples. Tensor with shape \
(n_samples, 1).
(n_samples,).
ray_indices: Ray index of each sample. IntTensor with shape (n_samples).
values: The values to be accmulated. Tensor with shape (n_samples, D). If \
None, the accumulated values are just weights. Default is None.
......@@ -500,24 +500,21 @@ def render_visibility(
"""
assert (
alphas.dim() == 2 and alphas.shape[-1] == 1
), "alphas should be a 2D tensor with shape (n_samples, 1)."
visibility = alphas >= alpha_thre
if early_stop_eps > 0:
assert (
ray_indices is not None or packed_info is not None
), "Either ray_indices or packed_info should be provided."
if ray_indices is not None and _C.is_cub_available():
transmittance = _RenderingTransmittanceFromAlphaCUB.apply(
ray_indices, alphas
)
else:
if packed_info is None:
packed_info = pack_info(ray_indices, n_rays=n_rays)
transmittance = _RenderingTransmittanceFromAlphaNaive.apply(
packed_info, alphas
)
visibility = visibility & (transmittance >= early_stop_eps)
ray_indices is not None or packed_info is not None
), "Either ray_indices or packed_info should be provided."
if ray_indices is not None and _C.is_cub_available():
transmittance = _RenderingTransmittanceFromAlphaCUB.apply(
ray_indices, alphas
)
else:
if packed_info is None:
packed_info = pack_info(ray_indices, n_rays=n_rays)
transmittance = _RenderingTransmittanceFromAlphaNaive.apply(
packed_info, alphas
)
visibility = transmittance >= early_stop_eps
if alpha_thre > 0:
visibility = visibility & (alphas >= alpha_thre)
visibility = visibility.squeeze(-1)
return visibility
......
......@@ -124,7 +124,7 @@ def test_rendering():
t_starts = torch.rand_like(sigmas)
t_ends = torch.rand_like(sigmas) + 1.0
_, _, _, _ = rendering(
_, _, _ = rendering(
t_starts,
t_ends,
ray_indices=ray_indices,
......
import pytest
import torch
from functorch import vmap
from nerfacc import pack_info, ray_marching, ray_resampling
from nerfacc.cuda import ray_pdf_query
device = "cuda:0"
batch_size = 128
eps = torch.finfo(torch.float32).eps
def _interp(x, xp, fp):
"""One-dimensional linear interpolation for monotonically increasing sample
points.
Returns the one-dimensional piecewise linear interpolant to a function with
given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.
Args:
x: the :math:`x`-coordinates at which to evaluate the interpolated
values.
xp: the :math:`x`-coordinates of the data points, must be increasing.
fp: the :math:`y`-coordinates of the data points, same length as `xp`.
Returns:
the interpolated values, same size as `x`.
"""
xp = xp.contiguous()
x = x.contiguous()
m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
b = fp[:-1] - (m * xp[:-1])
indices = torch.searchsorted(xp, x, right=True) - 1
indices = torch.clamp(indices, 0, len(m) - 1)
return m[indices] * x + b[indices]
def _integrate_weights(w):
"""Compute the cumulative sum of w, assuming all weight vectors sum to 1.
The output's size on the last dimension is one greater than that of the input,
because we're computing the integral corresponding to the endpoints of a step
function, not the integral of the interior/bin values.
Args:
w: Tensor, which will be integrated along the last axis. This is assumed to
sum to 1 along the last axis, and this function will (silently) break if
that is not the case.
Returns:
cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
"""
cw = torch.clamp(torch.cumsum(w[..., :-1], dim=-1), max=1)
shape = cw.shape[:-1] + (1,)
# Ensure that the CDF starts with exactly 0 and ends with exactly 1.
zeros = torch.zeros(shape, device=w.device)
ones = torch.ones(shape, device=w.device)
cw0 = torch.cat([zeros, cw, ones], dim=-1)
return cw0
def _invert_cdf(u, t, w_logits):
    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
    # Normalized bin probabilities from the logits.
    probs = torch.softmax(w_logits, dim=-1)
    # Endpoint CDF (exact 0 and 1 at the extremes).
    cdf = _integrate_weights(probs)
    # Interpolate each batch element through its inverse CDF.
    return vmap(_interp)(u, cdf, t)
def _resampling(t, w_logits, num_samples):
    """Piecewise-constant PDF sampling from a step function.

    Draws deterministic, evenly spaced samples (bin midpoints in CDF space)
    and maps them through the inverse CDF defined by ``(t, w_logits)``.

    Args:
        t: [..., num_bins + 1], bin endpoint coordinates (must be sorted).
        w_logits: [..., num_bins], logits corresponding to bin weights.
        num_samples: int, the number of samples.

    Returns:
        t_samples: [..., num_samples], the sampled t values.
    """
    # Midpoint offset in CDF space; `eps` keeps the last sample below 1.
    half_bin = 1 / (2 * num_samples)
    u = torch.linspace(half_bin, 1.0 - half_bin - eps, num_samples, device=device)
    u = torch.broadcast_to(u, t.shape[:-1] + (num_samples,))
    return _invert_cdf(u, t, w_logits)
# FIX: `torch.cuda.is_available` (no parentheses) is a function object, which
# is always truthy, so `not ...` was always False and the test never skipped
# on CPU-only machines. It must be *called*.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_resampling():
    """CUDA ``ray_resampling`` should match the reference piecewise-constant
    PDF resampler (``_resampling``) on random per-ray step functions."""
    batch_size = 1024
    num_bins = 128
    num_samples = 128
    # Sorted random bin endpoints, shape (batch, num_bins + 1).
    t = torch.randn((batch_size, num_bins + 1), device=device)
    t = torch.sort(t, dim=-1).values
    w_logits = torch.randn((batch_size, num_bins), device=device) * 0.1
    w = torch.softmax(w_logits, dim=-1)
    # Drop bins with non-positive logits: the reference gets -inf logits
    # (zero probability); the CUDA op gets only the surviving bins (it
    # renormalizes weights per ray internally).
    masks = w_logits > 0
    w_logits[~masks] = -torch.inf
    # Reference: num_samples + 1 endpoints sampled from the step function.
    t_samples = _resampling(t, w_logits, num_samples + 1)
    # Pack the surviving bins into the flat layout the CUDA op expects.
    t_starts = t[:, :-1][masks].unsqueeze(-1)
    t_ends = t[:, 1:][masks].unsqueeze(-1)
    w_logits = w_logits[masks].unsqueeze(-1)
    w = w[masks].unsqueeze(-1)
    num_steps = masks.long().sum(dim=-1)
    cum_steps = torch.cumsum(num_steps, dim=0)
    packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1).int()
    _, t_starts, t_ends = ray_resampling(
        packed_info, t_starts, t_ends, w, num_samples
    )
    # Consecutive reference endpoints bound each resampled interval.
    assert torch.allclose(
        t_starts.view(batch_size, num_samples), t_samples[:, :-1], atol=1e-3
    )
    assert torch.allclose(
        t_ends.view(batch_size, num_samples), t_samples[:, 1:], atol=1e-3
    )
def test_pdf_query():
rays_o = torch.rand((1, 3), device=device)
rays_d = torch.randn((1, 3), device=device)
rays_o = torch.rand((batch_size, 3), device=device)
rays_d = torch.randn((batch_size, 3), device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
ray_indices, t_starts, t_ends = ray_marching(
......@@ -139,22 +18,15 @@ def test_pdf_query():
rays_d,
near_plane=0.1,
far_plane=1.0,
render_step_size=0.2,
render_step_size=1e-3,
)
packed_info = pack_info(ray_indices, rays_o.shape[0])
weights = torch.rand((t_starts.shape[0], 1), device=device)
weights_new = ray_pdf_query(
packed_info,
t_starts,
t_ends,
weights,
packed_info,
t_starts + 0.3,
t_ends + 0.3,
packed_info = pack_info(ray_indices, n_rays=batch_size)
weights = torch.rand((t_starts.shape[0],), device=device)
packed_info, t_starts, t_ends = ray_resampling(
packed_info, t_starts, t_ends, weights, n_samples=32
)
assert t_starts.shape == t_ends.shape == (batch_size * 32, 1)
if __name__ == "__main__":
test_resampling()
test_pdf_query()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment