Unverified Commit 601c4c34 authored by Ruilong Li(李瑞龙)'s avatar Ruilong Li(李瑞龙) Committed by GitHub
Browse files

Add new Utils: pack & unpack data; cdf sampling; query grid (#57)

* new utils

* proper test for pack

* add test for cdf and query occ

* add deprecated warning

* bump version

* fix list return to tuple
parent 40075646
nerfacc.pack\_data
==================
.. currentmodule:: nerfacc
.. autofunction:: pack_data
\ No newline at end of file
nerfacc.ray\_resampling
=======================
.. currentmodule:: nerfacc
.. autofunction:: ray_resampling
\ No newline at end of file
nerfacc.unpack\_data
====================
.. currentmodule:: nerfacc
.. autofunction:: unpack_data
\ No newline at end of file
nerfacc.unpack\_info
====================
.. currentmodule:: nerfacc
.. autofunction:: unpack_info
\ No newline at end of file
nerfacc.unpack\_to\_ray\_indices
================================
.. currentmodule:: nerfacc
.. autofunction:: unpack_to_ray_indices
\ No newline at end of file
...@@ -8,11 +8,14 @@ Utils ...@@ -8,11 +8,14 @@ Utils
:toctree: generated/ :toctree: generated/
ray_aabb_intersect ray_aabb_intersect
unpack_to_ray_indices unpack_info
accumulate_along_rays accumulate_along_rays
render_weight_from_density render_weight_from_density
render_weight_from_alpha render_weight_from_alpha
render_visibility render_visibility
ray_resampling
pack_data
unpack_data
\ No newline at end of file
""" """
Copyright (c) 2022 Ruilong Li, UC Berkeley. Copyright (c) 2022 Ruilong Li, UC Berkeley.
""" """
import warnings
from .cdf import ray_resampling
from .contraction import ContractionType, contract, contract_inv from .contraction import ContractionType, contract, contract_inv
from .grid import Grid, OccupancyGrid from .grid import Grid, OccupancyGrid, query_grid
from .intersection import ray_aabb_intersect from .intersection import ray_aabb_intersect
from .pack import unpack_to_ray_indices from .pack import pack_data, unpack_data, unpack_info
from .ray_marching import ray_marching from .ray_marching import ray_marching
from .version import __version__ from .version import __version__
from .vol_rendering import ( from .vol_rendering import (
...@@ -16,19 +18,34 @@ from .vol_rendering import ( ...@@ -16,19 +18,34 @@ from .vol_rendering import (
rendering, rendering,
) )
# Deprecated alias kept for backward compatibility.
def unpack_to_ray_indices(*args, **kwargs):
    """Deprecated. Forwards every call to :func:`unpack_info`.

    Emits a ``DeprecationWarning`` pointing callers at the new name.
    """
    _msg = (
        "`unpack_to_ray_indices` will be deprecated. "
        "Please use `unpack_info` instead."
    )
    warnings.warn(_msg, DeprecationWarning, stacklevel=2)
    return unpack_info(*args, **kwargs)
__all__ = [ __all__ = [
"__version__",
"Grid", "Grid",
"OccupancyGrid", "OccupancyGrid",
"query_grid",
"ContractionType", "ContractionType",
"contract", "contract",
"contract_inv", "contract_inv",
"ray_aabb_intersect", "ray_aabb_intersect",
"ray_marching", "ray_marching",
"unpack_to_ray_indices",
"accumulate_along_rays", "accumulate_along_rays",
"render_visibility", "render_visibility",
"render_weight_from_alpha", "render_weight_from_alpha",
"render_weight_from_density", "render_weight_from_density",
"rendering", "rendering",
"__version__", "pack_data",
"unpack_data",
"unpack_info",
"ray_resampling",
] ]
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Tuple
from torch import Tensor
import nerfacc.cuda as _C
def ray_resampling(
    packed_info: Tensor,
    t_starts: Tensor,
    t_ends: Tensor,
    weights: Tensor,
    n_samples: int,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Resample rays by inverting the CDF of the per-sample weights.

    Args:
        packed_info (Tensor): Stores information on which samples belong to
            the same ray. See :func:`nerfacc.ray_marching` for details.
            Tensor with shape (n_rays, 2).
        t_starts: Where the frustum-shape sample starts along a ray. Tensor
            with shape (n_samples, 1).
        t_ends: Where the frustum-shape sample ends along a ray. Tensor with
            shape (n_samples, 1).
        weights: Volumetric rendering weights for those samples. Tensor with
            shape (n_samples,).
        n_samples (int): Number of samples per ray to resample.

    Returns:
        Resampled packed info (n_rays, 2), t_starts (n_samples, 1), and
        t_ends (n_samples, 1).
    """
    # The CUDA op requires contiguous memory for all tensor inputs.
    inputs = (
        packed_info.contiguous(),
        t_starts.contiguous(),
        t_ends.contiguous(),
        weights.contiguous(),
        n_samples,
    )
    out_packed_info, out_t_starts, out_t_ends = _C.ray_resampling(*inputs)
    return out_packed_info, out_t_starts, out_t_ends
...@@ -19,13 +19,17 @@ ContractionTypeGetter = _make_lazy_cuda_func("ContractionType") ...@@ -19,13 +19,17 @@ ContractionTypeGetter = _make_lazy_cuda_func("ContractionType")
contract = _make_lazy_cuda_func("contract") contract = _make_lazy_cuda_func("contract")
contract_inv = _make_lazy_cuda_func("contract_inv") contract_inv = _make_lazy_cuda_func("contract_inv")
query_occ = _make_lazy_cuda_func("query_occ") grid_query = _make_lazy_cuda_func("grid_query")
ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect") ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
ray_marching = _make_lazy_cuda_func("ray_marching") ray_marching = _make_lazy_cuda_func("ray_marching")
unpack_to_ray_indices = _make_lazy_cuda_func("unpack_to_ray_indices") ray_resampling = _make_lazy_cuda_func("ray_resampling")
rendering_forward = _make_lazy_cuda_func("rendering_forward") rendering_forward = _make_lazy_cuda_func("rendering_forward")
rendering_backward = _make_lazy_cuda_func("rendering_backward") rendering_backward = _make_lazy_cuda_func("rendering_backward")
rendering_alphas_forward = _make_lazy_cuda_func("rendering_alphas_forward") rendering_alphas_forward = _make_lazy_cuda_func("rendering_alphas_forward")
rendering_alphas_backward = _make_lazy_cuda_func("rendering_alphas_backward") rendering_alphas_backward = _make_lazy_cuda_func("rendering_alphas_backward")
unpack_data = _make_lazy_cuda_func("unpack_data")
unpack_info = _make_lazy_cuda_func("unpack_info")
unpack_info_to_mask = _make_lazy_cuda_func("unpack_info_to_mask")
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
// One thread per ray: draw `resample_steps` new intervals along the ray by
// inverting the CDF of the (padded, normalized) per-sample weights.
template <typename scalar_t>
__global__ void cdf_resampling_kernel(
    const uint32_t n_rays,
    const int *packed_info,   // input ray & point indices.
    const scalar_t *starts,   // input start t
    const scalar_t *ends,     // input end t
    const scalar_t *weights,  // transmittance weights
    const int *resample_packed_info,
    scalar_t *resample_starts,
    scalar_t *resample_ends)
{
    CUDA_GET_THREAD_ID(i, n_rays);

    // locate
    const int base = packed_info[i * 2 + 0];  // point idx start.
    const int steps = packed_info[i * 2 + 1]; // point idx shift.
    const int resample_base = resample_packed_info[i * 2 + 0];
    const int resample_steps = resample_packed_info[i * 2 + 1];
    if (steps == 0)
        return;

    // Shift pointers to this ray's slice of the packed arrays.
    starts += base;
    ends += base;
    weights += base;
    resample_starts += resample_base;
    resample_ends += resample_base;

    // normalize weights **per ray**
    scalar_t weights_sum = 0.0f;
    for (int j = 0; j < steps; j++)
        weights_sum += weights[j];
    // Pad the total mass up to at least 1e-5, spread uniformly over the
    // samples, so the CDF below stays well-defined for near-zero weights.
    scalar_t padding = fmaxf(1e-5f - weights_sum, 0.0f);
    scalar_t padding_step = padding / steps;
    weights_sum += padding;

    // num_bins CDF query points yield resample_steps intervals: consecutive
    // query points become the start/end of one resampled interval.
    int num_bins = resample_steps + 1;
    scalar_t cdf_step_size = (1.0f - 1.0 / num_bins) / resample_steps;

    int idx = 0, j = 0;
    scalar_t cdf_prev = 0.0f, cdf_next = (weights[idx] + padding_step) / weights_sum;
    scalar_t cdf_u = 1.0 / (2 * num_bins); // first query point, centered in the first bin
    while (j < num_bins)
    {
        if (cdf_u < cdf_next)
        {
            // printf("cdf_u: %f, cdf_next: %f\n", cdf_u, cdf_next);
            // resample in this interval: invert the piecewise-linear CDF.
            scalar_t scaling = (ends[idx] - starts[idx]) / (cdf_next - cdf_prev);
            scalar_t t = (cdf_u - cdf_prev) * scaling + starts[idx];
            // Each t both opens interval j and closes interval j-1.
            if (j < num_bins - 1)
                resample_starts[j] = t;
            if (j > 0)
                resample_ends[j - 1] = t;
            // going further to next resample
            cdf_u += cdf_step_size;
            j += 1;
        }
        else
        {
            // going to next interval
            // NOTE(review): idx is not bounds-checked against steps; relies
            // on the final query point falling below the last cdf_next —
            // verify for extreme weight distributions.
            idx += 1;
            cdf_prev = cdf_next;
            cdf_next += (weights[idx] + padding_step) / weights_sum;
        }
    }
    // All query points should have been consumed; anything else indicates a
    // CDF bookkeeping error (debug aid only).
    if (j != num_bins)
    {
        printf("Error: %d %d %f\n", j, num_bins, weights_sum);
    }
    return;
}
// template <typename scalar_t>
// __global__ void cdf_resampling_kernel(
// const uint32_t n_rays,
// const int *packed_info, // input ray & point indices.
// const scalar_t *starts, // input start t
// const scalar_t *ends, // input end t
// const scalar_t *weights, // transmittance weights
// const int *resample_packed_info,
// scalar_t *resample_starts,
// scalar_t *resample_ends)
// {
// CUDA_GET_THREAD_ID(i, n_rays);
// // locate
// const int base = packed_info[i * 2 + 0]; // point idx start.
// const int steps = packed_info[i * 2 + 1]; // point idx shift.
// const int resample_base = resample_packed_info[i * 2 + 0];
// const int resample_steps = resample_packed_info[i * 2 + 1];
// if (steps == 0)
// return;
// starts += base;
// ends += base;
// weights += base;
// resample_starts += resample_base;
// resample_ends += resample_base;
// scalar_t cdf_step_size = 1.0f / resample_steps;
// // normalize weights **per ray**
// scalar_t weights_sum = 0.0f;
// for (int j = 0; j < steps; j++)
// weights_sum += weights[j];
// scalar_t padding = fmaxf(1e-5f - weights_sum, 0.0f);
// scalar_t padding_step = padding / steps;
// weights_sum += padding;
// int idx = 0, j = 0;
// scalar_t cdf_prev = 0.0f, cdf_next = (weights[idx] + padding_step) / weights_sum;
// scalar_t cdf_u = 0.5f * cdf_step_size;
// while (cdf_u < 1.0f)
// {
// if (cdf_u < cdf_next)
// {
// // resample in this interval
// scalar_t scaling = (ends[idx] - starts[idx]) / (cdf_next - cdf_prev);
// scalar_t resample_mid = (cdf_u - cdf_prev) * scaling + starts[idx];
// scalar_t resample_half_size = cdf_step_size * scaling * 0.5;
// resample_starts[j] = fmaxf(resample_mid - resample_half_size, starts[idx]);
// resample_ends[j] = fminf(resample_mid + resample_half_size, ends[idx]);
// // going further to next resample
// cdf_u += cdf_step_size;
// j += 1;
// }
// else
// {
// // go to next interval
// idx += 1;
// if (idx == steps)
// break;
// cdf_prev = cdf_next;
// cdf_next += (weights[idx] + padding_step) / weights_sum;
// }
// }
// if (j != resample_steps)
// {
// printf("Error: %d %d %f\n", j, resample_steps, weights_sum);
// }
// return;
// }
// Host entry point for CDF-based ray resampling. Builds the output
// packed_info (every non-empty ray gets exactly `steps` samples, empty rays
// get none) and launches cdf_resampling_kernel to fill the new intervals.
std::vector<torch::Tensor> ray_resampling(
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
    torch::Tensor weights,
    const int steps)
{
    DEVICE_GUARD(packed_info);

    CHECK_INPUT(packed_info);
    CHECK_INPUT(starts);
    CHECK_INPUT(ends);
    CHECK_INPUT(weights);
    // Logical && (short-circuit), not bitwise &, for combined shape checks.
    TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
    TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
    TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
    TORCH_CHECK(weights.ndimension() == 1);

    const uint32_t n_rays = packed_info.size(0);

    const int threads = 256;
    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);

    // Rays with zero input samples keep zero output samples; all other rays
    // are resampled to exactly `steps` samples.
    torch::Tensor num_steps = torch::split(packed_info, 1, 1)[1];
    torch::Tensor resample_num_steps = (num_steps > 0).to(num_steps.options()) * steps;
    torch::Tensor resample_cum_steps = resample_num_steps.cumsum(0, torch::kInt32);
    // Per-ray (start, count): exclusive prefix sum next to the counts.
    torch::Tensor resample_packed_info = torch::cat(
        {resample_cum_steps - resample_num_steps, resample_num_steps}, 1);

    int total_steps = resample_cum_steps[resample_cum_steps.size(0) - 1].item<int>();
    torch::Tensor resample_starts = torch::zeros({total_steps, 1}, starts.options());
    torch::Tensor resample_ends = torch::zeros({total_steps, 1}, ends.options());

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        weights.scalar_type(),
        "ray_resampling",
        ([&]
         { cdf_resampling_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
               n_rays,
               // inputs
               packed_info.data_ptr<int>(),
               starts.data_ptr<scalar_t>(),
               ends.data_ptr<scalar_t>(),
               weights.data_ptr<scalar_t>(),
               resample_packed_info.data_ptr<int>(),
               // outputs
               resample_starts.data_ptr<scalar_t>(),
               resample_ends.data_ptr<scalar_t>()); }));

    return {resample_packed_info, resample_starts, resample_ends};
}
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "include/helpers_cuda.h" #include "include/helpers_cuda.h"
__global__ void ray_indices_kernel( __global__ void unpack_info_kernel(
// input // input
const int n_rays, const int n_rays,
const int *packed_info, const int *packed_info,
...@@ -27,7 +27,61 @@ __global__ void ray_indices_kernel( ...@@ -27,7 +27,61 @@ __global__ void ray_indices_kernel(
} }
} }
torch::Tensor unpack_to_ray_indices(const torch::Tensor packed_info) __global__ void unpack_info_to_mask_kernel(
// input
const int n_rays,
const int *packed_info,
const int n_samples,
// output
bool *masks) // [n_rays, n_samples]
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
masks += i * n_samples;
for (int j = 0; j < steps; ++j)
{
masks[j] = true;
}
}
// One thread per ray: copy this ray's packed samples into its row of the
// dense (n_rays, n_sampler_per_ray, data_dim) output. Slots beyond `steps`
// are left untouched (the host side allocates the output with zeros).
template <typename scalar_t>
__global__ void unpack_data_kernel(
    const uint32_t n_rays,
    const int *packed_info, // input ray & point indices.
    const int data_dim,
    const scalar_t *data,
    const int n_sampler_per_ray,
    scalar_t *unpacked_data) // (n_rays, n_sampler_per_ray, data_dim)
{
    CUDA_GET_THREAD_ID(i, n_rays);

    // locate
    const int base = packed_info[i * 2 + 0];  // point idx start.
    const int steps = packed_info[i * 2 + 1]; // point idx shift.
    if (steps == 0)
        return;

    // Advance to this ray's slice of the packed input and its output row.
    data += base * data_dim;
    unpacked_data += i * n_sampler_per_ray * data_dim;

    // NOTE(review): assumes steps <= n_sampler_per_ray for every ray;
    // otherwise this writes past the ray's output row — verify at call site.
    for (int j = 0; j < steps; j++)
    {
        for (int k = 0; k < data_dim; k++)
        {
            unpacked_data[j * data_dim + k] = data[j * data_dim + k];
        }
    }
    return;
}
torch::Tensor unpack_info(const torch::Tensor packed_info)
{ {
DEVICE_GUARD(packed_info); DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info); CHECK_INPUT(packed_info);
...@@ -40,9 +94,71 @@ torch::Tensor unpack_to_ray_indices(const torch::Tensor packed_info) ...@@ -40,9 +94,71 @@ torch::Tensor unpack_to_ray_indices(const torch::Tensor packed_info)
torch::Tensor ray_indices = torch::zeros( torch::Tensor ray_indices = torch::zeros(
{n_samples}, packed_info.options().dtype(torch::kInt32)); {n_samples}, packed_info.options().dtype(torch::kInt32));
ray_indices_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( unpack_info_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays, n_rays,
packed_info.data_ptr<int>(), packed_info.data_ptr<int>(),
ray_indices.data_ptr<int>()); ray_indices.data_ptr<int>());
return ray_indices; return ray_indices;
} }
// Expand packed_info (n_rays, 2) into a dense boolean mask of shape
// (n_rays, n_samples): the first `steps` slots of each ray's row are true.
torch::Tensor unpack_info_to_mask(
    const torch::Tensor packed_info, const int n_samples)
{
    DEVICE_GUARD(packed_info);
    CHECK_INPUT(packed_info);

    const int n_rays = packed_info.size(0);

    const int threads = 256;
    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);

    // Start all-false; the kernel only flips the kept slots to true.
    torch::Tensor masks = torch::zeros(
        {n_rays, n_samples}, packed_info.options().dtype(torch::kBool));
    unpack_info_to_mask_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        n_rays,
        packed_info.data_ptr<int>(),
        n_samples,
        masks.data_ptr<bool>());
    return masks;
}
// Scatter packed per-sample data (all_samples, D) into a dense per-ray
// tensor (n_rays, n_samples_per_ray, D); unfilled slots stay zero.
torch::Tensor unpack_data(
    torch::Tensor packed_info,
    torch::Tensor data,
    int n_samples_per_ray)
{
    DEVICE_GUARD(packed_info);

    CHECK_INPUT(packed_info);
    CHECK_INPUT(data);
    // Logical && (short-circuit), not bitwise &, for the combined shape check.
    TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
    TORCH_CHECK(data.ndimension() == 2);

    const int n_rays = packed_info.size(0);
    const int data_dim = data.size(1);

    const int threads = 256;
    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);

    // Zero-initialized so rays with fewer than n_samples_per_ray samples are
    // implicitly padded with zeros.
    torch::Tensor unpacked_data = torch::zeros(
        {n_rays, n_samples_per_ray, data_dim}, data.options());

    AT_DISPATCH_ALL_TYPES(
        data.scalar_type(),
        "unpack_data",
        ([&]
         { unpack_data_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
               n_rays,
               // inputs
               packed_info.data_ptr<int>(),
               data_dim,
               data.data_ptr<scalar_t>(),
               n_samples_per_ray,
               // outputs
               unpacked_data.data_ptr<scalar_t>()); }));
    return unpacked_data;
}
...@@ -44,14 +44,17 @@ std::vector<torch::Tensor> ray_marching( ...@@ -44,14 +44,17 @@ std::vector<torch::Tensor> ray_marching(
const float step_size, const float step_size,
const float cone_angle); const float cone_angle);
torch::Tensor unpack_to_ray_indices( torch::Tensor unpack_info(
const torch::Tensor packed_info); const torch::Tensor packed_info);
torch::Tensor query_occ( torch::Tensor unpack_info_to_mask(
const torch::Tensor packed_info, const int n_samples);
torch::Tensor grid_query(
const torch::Tensor samples, const torch::Tensor samples,
// occupancy grid & contraction // occupancy grid & contraction
const torch::Tensor roi, const torch::Tensor roi,
const torch::Tensor grid_binary, const torch::Tensor grid_value,
const ContractionType type); const ContractionType type);
torch::Tensor contract( torch::Tensor contract(
...@@ -81,6 +84,18 @@ std::vector<torch::Tensor> rendering_alphas_forward( ...@@ -81,6 +84,18 @@ std::vector<torch::Tensor> rendering_alphas_forward(
float alpha_thre, float alpha_thre,
bool compression); bool compression);
std::vector<torch::Tensor> ray_resampling(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor weights,
const int steps);
torch::Tensor unpack_data(
torch::Tensor packed_info,
torch::Tensor data,
int n_samples_per_ray);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{ {
// contraction // contraction
...@@ -92,16 +107,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) ...@@ -92,16 +107,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("contract_inv", &contract_inv); m.def("contract_inv", &contract_inv);
// grid // grid
m.def("query_occ", &query_occ); m.def("grid_query", &grid_query);
// marching // marching
m.def("ray_aabb_intersect", &ray_aabb_intersect); m.def("ray_aabb_intersect", &ray_aabb_intersect);
m.def("ray_marching", &ray_marching); m.def("ray_marching", &ray_marching);
m.def("unpack_to_ray_indices", &unpack_to_ray_indices); m.def("ray_resampling", &ray_resampling);
// rendering // rendering
m.def("rendering_forward", &rendering_forward); m.def("rendering_forward", &rendering_forward);
m.def("rendering_backward", &rendering_backward); m.def("rendering_backward", &rendering_backward);
m.def("rendering_alphas_forward", &rendering_alphas_forward); m.def("rendering_alphas_forward", &rendering_alphas_forward);
m.def("rendering_alphas_backward", &rendering_alphas_backward); m.def("rendering_alphas_backward", &rendering_alphas_backward);
// pack & unpack
m.def("unpack_data", &unpack_data);
m.def("unpack_info", &unpack_info);
m.def("unpack_info_to_mask", &unpack_info_to_mask);
} }
\ No newline at end of file
...@@ -24,11 +24,12 @@ inline __device__ __host__ int grid_idx_at( ...@@ -24,11 +24,12 @@ inline __device__ __host__ int grid_idx_at(
return idx; return idx;
} }
inline __device__ __host__ bool grid_occupied_at( template <typename scalar_t>
inline __device__ __host__ scalar_t grid_occupied_at(
const float3 xyz, const float3 xyz,
const float3 roi_min, const float3 roi_max, const float3 roi_min, const float3 roi_max,
ContractionType type, ContractionType type,
const int3 grid_res, const bool *grid_binary) const int3 grid_res, const scalar_t *grid_value)
{ {
if (type == ContractionType::AABB && if (type == ContractionType::AABB &&
(xyz.x < roi_min.x || xyz.x > roi_max.x || (xyz.x < roi_min.x || xyz.x > roi_max.x ||
...@@ -40,7 +41,7 @@ inline __device__ __host__ bool grid_occupied_at( ...@@ -40,7 +41,7 @@ inline __device__ __host__ bool grid_occupied_at(
float3 xyz_unit = apply_contraction( float3 xyz_unit = apply_contraction(
xyz, roi_min, roi_max, type); xyz, roi_min, roi_max, type);
int idx = grid_idx_at(xyz_unit, grid_res); int idx = grid_idx_at(xyz_unit, grid_res);
return grid_binary[idx]; return grid_value[idx];
} }
// dda like step // dda like step
...@@ -283,6 +284,7 @@ std::vector<torch::Tensor> ray_marching( ...@@ -283,6 +284,7 @@ std::vector<torch::Tensor> ray_marching(
// Query the occupancy grid // Query the occupancy grid
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
template <typename scalar_t>
__global__ void query_occ_kernel( __global__ void query_occ_kernel(
// rays info // rays info
const uint32_t n_samples, const uint32_t n_samples,
...@@ -290,10 +292,10 @@ __global__ void query_occ_kernel( ...@@ -290,10 +292,10 @@ __global__ void query_occ_kernel(
// occupancy grid & contraction // occupancy grid & contraction
const float *roi, const float *roi,
const int3 grid_res, const int3 grid_res,
const bool *grid_binary, // shape (reso_x, reso_y, reso_z) const scalar_t *grid_value, // shape (reso_x, reso_y, reso_z)
const ContractionType type, const ContractionType type,
// outputs // outputs
bool *occs) scalar_t *occs)
{ {
CUDA_GET_THREAD_ID(i, n_samples); CUDA_GET_THREAD_ID(i, n_samples);
...@@ -305,15 +307,15 @@ __global__ void query_occ_kernel( ...@@ -305,15 +307,15 @@ __global__ void query_occ_kernel(
const float3 roi_max = make_float3(roi[3], roi[4], roi[5]); const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
const float3 xyz = make_float3(samples[0], samples[1], samples[2]); const float3 xyz = make_float3(samples[0], samples[1], samples[2]);
*occs = grid_occupied_at(xyz, roi_min, roi_max, type, grid_res, grid_binary); *occs = grid_occupied_at(xyz, roi_min, roi_max, type, grid_res, grid_value);
return; return;
} }
torch::Tensor query_occ( torch::Tensor grid_query(
const torch::Tensor samples, const torch::Tensor samples,
// occupancy grid & contraction // occupancy grid & contraction
const torch::Tensor roi, const torch::Tensor roi,
const torch::Tensor grid_binary, const torch::Tensor grid_value,
const ContractionType type) const ContractionType type)
{ {
DEVICE_GUARD(samples); DEVICE_GUARD(samples);
...@@ -321,23 +323,28 @@ torch::Tensor query_occ( ...@@ -321,23 +323,28 @@ torch::Tensor query_occ(
const int n_samples = samples.size(0); const int n_samples = samples.size(0);
const int3 grid_res = make_int3( const int3 grid_res = make_int3(
grid_binary.size(0), grid_binary.size(1), grid_binary.size(2)); grid_value.size(0), grid_value.size(1), grid_value.size(2));
const int threads = 256; const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads); const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads);
torch::Tensor occs = torch::zeros( torch::Tensor occs = torch::zeros({n_samples}, grid_value.options());
{n_samples}, samples.options().dtype(torch::kBool));
query_occ_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Bool,
occs.scalar_type(),
"grid_query",
([&]
{ query_occ_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_samples, n_samples,
samples.data_ptr<float>(), samples.data_ptr<float>(),
// grid // grid
roi.data_ptr<float>(), roi.data_ptr<float>(),
grid_res, grid_res,
grid_binary.data_ptr<bool>(), grid_value.data_ptr<scalar_t>(),
type, type,
// outputs // outputs
occs.data_ptr<bool>()); occs.data_ptr<scalar_t>()); }));
return occs; return occs;
} }
...@@ -7,12 +7,46 @@ from typing import Callable, List, Union ...@@ -7,12 +7,46 @@ from typing import Callable, List, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import nerfacc.cuda as _C
from .contraction import ContractionType, contract_inv from .contraction import ContractionType, contract_inv
# TODO: add this to the dependency # TODO: check torch.scatter_reduce_
# from torch_scatter import scatter_max # from torch_scatter import scatter_max
@torch.no_grad()
def query_grid(
    samples: torch.Tensor,
    grid_roi: torch.Tensor,
    grid_values: torch.Tensor,
    grid_type: ContractionType,
):
    """Query values from a 3D grid at the given coordinates.

    Args:
        samples: Coordinates to query, shaped (n_samples, 3).
        grid_roi: Region of interest of the grid, shaped (6,). Usually it
            should be acquired from the grid itself via ``grid.roi_aabb``.
        grid_values: A 3D tensor of grid values of shape (resx, resy, resz).
        grid_type: Contraction type of the grid. Usually it should be
            acquired from the grid itself via ``grid.contraction_type``.

    Returns:
        (n_samples,) values queried from the grid at those samples.
    """
    # Validate shapes before handing off to the CUDA op.
    assert samples.dim() == 2 and samples.size(-1) == 3
    assert grid_roi.dim() == 1 and grid_roi.size(0) == 6
    assert grid_values.dim() == 3
    assert isinstance(grid_type, ContractionType)
    query_fn = _C.grid_query
    return query_fn(
        samples.contiguous(),
        grid_roi.contiguous(),
        grid_values.contiguous(),
        grid_type.to_cpp_version(),
    )
class Grid(nn.Module): class Grid(nn.Module):
"""An abstract Grid class. """An abstract Grid class.
...@@ -242,6 +276,23 @@ class OccupancyGrid(Grid): ...@@ -242,6 +276,23 @@ class OccupancyGrid(Grid):
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
) )
@torch.no_grad()
def query_occ(self, samples: torch.Tensor) -> torch.Tensor:
    """Query the occupancy field at the given samples.

    Args:
        samples: Samples in the world coordinates. (n_samples, 3)

    Returns:
        Occupancy values at the given samples. (n_samples,)
    """
    # Reshape the flat occupancy values back into the 3D grid layout
    # expected by query_grid, then delegate the lookup to it.
    grid_values = self.occs.reshape(self.resolution.tolist())
    return query_grid(
        samples,
        self._roi_aabb,
        grid_values,
        self.contraction_type,
    )
def _meshgrid3d( def _meshgrid3d(
res: torch.Tensor, device: Union[torch.device, str] = "cpu" res: torch.Tensor, device: Union[torch.device, str] = "cpu"
......
""" """
Copyright (c) 2022 Ruilong Li, UC Berkeley. Copyright (c) 2022 Ruilong Li, UC Berkeley.
""" """
from typing import Optional, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -8,8 +9,42 @@ from torch import Tensor ...@@ -8,8 +9,42 @@ from torch import Tensor
import nerfacc.cuda as _C import nerfacc.cuda as _C
def pack_data(data: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
    """Pack per-ray data (n_rays, n_samples, D) to (all_samples, D) based on mask.

    Args:
        data: Tensor with shape (n_rays, n_samples, D).
        mask: Boolean tensor with shape (n_rays, n_samples).

    Returns:
        Tuple of Tensors including packed data (all_samples, D), \
        and packed_info (n_rays, 2) which stores the start index of the sample,
        and the number of samples kept for each ray. \

    Examples:

    .. code-block:: python

        data = torch.rand((10, 3, 4), device="cuda:0")
        mask = torch.rand((10, 3), device="cuda:0") > 0.5
        packed_data, packed_info = pack_data(data, mask)
        print(packed_data.shape, packed_info.shape)

    """
    assert data.dim() == 3, "data must be with shape of (n_rays, n_samples, D)."
    assert (
        mask.shape == data.shape[:2]
    ), "mask must be with shape of (n_rays, n_samples)."
    assert mask.dtype == torch.bool, "mask must be a boolean tensor."
    # Boolean indexing selects kept samples in ray-major order, flattening
    # (n_rays, n_samples, D) into (all_samples, D).
    packed_data = data[mask]
    # Per-ray kept-sample counts; the exclusive prefix sum gives each ray's
    # start offset into packed_data.
    num_steps = mask.long().sum(dim=-1)
    cum_steps = num_steps.cumsum(dim=0, dtype=torch.long)
    packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1)
    return packed_data, packed_info
@torch.no_grad() @torch.no_grad()
def unpack_to_ray_indices(packed_info: Tensor) -> Tensor: def unpack_info(packed_info: Tensor) -> Tensor:
"""Unpack `packed_info` to `ray_indices`. Useful for converting per ray data to per sample data. """Unpack `packed_info` to `ray_indices`. Useful for converting per ray data to per sample data.
Note: Note:
...@@ -36,13 +71,85 @@ def unpack_to_ray_indices(packed_info: Tensor) -> Tensor: ...@@ -36,13 +71,85 @@ def unpack_to_ray_indices(packed_info: Tensor) -> Tensor:
# torch.Size([128, 2]) torch.Size([115200, 1]) torch.Size([115200, 1]) # torch.Size([128, 2]) torch.Size([115200, 1]) torch.Size([115200, 1])
print(packed_info.shape, t_starts.shape, t_ends.shape) print(packed_info.shape, t_starts.shape, t_ends.shape)
# Unpack per-ray info to per-sample info. # Unpack per-ray info to per-sample info.
ray_indices = unpack_to_ray_indices(packed_info) ray_indices = unpack_info(packed_info)
# torch.Size([115200]) torch.int64 # torch.Size([115200]) torch.int64
print(ray_indices.shape, ray_indices.dtype) print(ray_indices.shape, ray_indices.dtype)
""" """
assert (
packed_info.dim() == 2 and packed_info.shape[-1] == 2
), "packed_info must be a 2D tensor with shape (n_rays, 2)."
if packed_info.is_cuda: if packed_info.is_cuda:
ray_indices = _C.unpack_to_ray_indices(packed_info.contiguous().int()) ray_indices = _C.unpack_info(packed_info.contiguous().int())
else: else:
raise NotImplementedError("Only support cuda inputs.") raise NotImplementedError("Only support cuda inputs.")
return ray_indices.long() return ray_indices.long()
def unpack_data(
    packed_info: Tensor,
    data: Tensor,
    n_samples: Optional[int] = None,
) -> Tensor:
    """Expand packed data (all_samples, D) into per-ray data (n_rays, n_samples, D).

    Args:
        packed_info (Tensor): Stores information on which samples belong to
            the same ray. See :func:`nerfacc.ray_marching` for details.
            Tensor with shape (n_rays, 2).
        data: Packed data to unpack. Tensor with shape (n_samples, D).
        n_samples (int): Optional number of samples per ray. If not provided,
            it will be inferred from the packed_info.

    Returns:
        Unpacked data (n_rays, n_samples, D).

    Examples:

    .. code-block:: python

        rays_o = torch.rand((128, 3), device="cuda:0")
        rays_d = torch.randn((128, 3), device="cuda:0")
        rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)

        # Ray marching with aabb.
        scene_aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], device="cuda:0")
        packed_info, t_starts, t_ends = ray_marching(
            rays_o, rays_d, scene_aabb=scene_aabb, render_step_size=1e-2
        )
        print(t_starts.shape)  # torch.Size([all_samples, 1])
        t_starts = unpack_data(packed_info, t_starts, n_samples=1024)
        print(t_starts.shape)  # torch.Size([128, 1024, 1])

    """
    assert (
        packed_info.dim() == 2 and packed_info.shape[-1] == 2
    ), "packed_info must be a 2D tensor with shape (n_rays, 2)."
    assert (
        data.dim() == 2
    ), "data must be a 2D tensor with shape (n_samples, D)."
    # Default the output width to the longest ray so nothing is truncated.
    row_len = packed_info[:, 1].max().item() if n_samples is None else n_samples
    return _UnpackData.apply(packed_info, data, row_len)
class _UnpackData(torch.autograd.Function):
    """Unpack packed data (all_samples, D) to per-ray data (n_rays, n_samples, D)."""

    @staticmethod
    def forward(ctx, packed_info: Tensor, data: Tensor, n_samples: int):
        # shape of the data should be (all_samples, D)
        # The CUDA op expects int32 packed_info and contiguous memory.
        packed_info = packed_info.contiguous().int()
        data = data.contiguous()
        # Only stash backward state when `data` actually requires grad.
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(packed_info)
            ctx.n_samples = n_samples
        return _C.unpack_data(packed_info, data, n_samples)

    @staticmethod
    def backward(ctx, grad: Tensor):
        # shape of the grad should be (n_rays, n_samples, D)
        packed_info = ctx.saved_tensors[0]
        n_samples = ctx.n_samples
        # Re-pack the dense gradient: select exactly the slots filled in
        # forward, yielding an (all_samples, D) gradient for `data`.
        mask = _C.unpack_info_to_mask(packed_info, n_samples)
        packed_grad = grad[mask].contiguous()
        # No gradients flow to packed_info or n_samples.
        return None, packed_grad, None
...@@ -7,7 +7,7 @@ import nerfacc.cuda as _C ...@@ -7,7 +7,7 @@ import nerfacc.cuda as _C
from .contraction import ContractionType from .contraction import ContractionType
from .grid import Grid from .grid import Grid
from .intersection import ray_aabb_intersect from .intersection import ray_aabb_intersect
from .pack import unpack_to_ray_indices from .pack import unpack_info
from .vol_rendering import render_visibility from .vol_rendering import render_visibility
...@@ -87,7 +87,7 @@ def ray_marching( ...@@ -87,7 +87,7 @@ def ray_marching(
.. code-block:: python .. code-block:: python
import torch import torch
from nerfacc import OccupancyGrid, ray_marching, unpack_to_ray_indices from nerfacc import OccupancyGrid, ray_marching, unpack_info
device = "cuda:0" device = "cuda:0"
batch_size = 128 batch_size = 128
...@@ -121,7 +121,7 @@ def ray_marching( ...@@ -121,7 +121,7 @@ def ray_marching(
) )
# Convert t_starts and t_ends to sample locations. # Convert t_starts and t_ends to sample locations.
ray_indices = unpack_to_ray_indices(packed_info) ray_indices = unpack_info(packed_info)
t_mid = (t_starts + t_ends) / 2.0 t_mid = (t_starts + t_ends) / 2.0
sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices] sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
...@@ -186,7 +186,7 @@ def ray_marching( ...@@ -186,7 +186,7 @@ def ray_marching(
# skip invisible space # skip invisible space
if sigma_fn is not None: if sigma_fn is not None:
# Query sigma without gradients # Query sigma without gradients
ray_indices = unpack_to_ray_indices(packed_info) ray_indices = unpack_info(packed_info)
sigmas = sigma_fn(t_starts, t_ends, ray_indices) sigmas = sigma_fn(t_starts, t_ends, ray_indices)
assert ( assert (
sigmas.shape == t_starts.shape sigmas.shape == t_starts.shape
......
...@@ -9,7 +9,7 @@ from torch import Tensor ...@@ -9,7 +9,7 @@ from torch import Tensor
import nerfacc.cuda as _C import nerfacc.cuda as _C
from .pack import unpack_to_ray_indices from .pack import unpack_info
def rendering( def rendering(
...@@ -77,7 +77,7 @@ def rendering( ...@@ -77,7 +77,7 @@ def rendering(
""" """
n_rays = packed_info.shape[0] n_rays = packed_info.shape[0]
ray_indices = unpack_to_ray_indices(packed_info) ray_indices = unpack_info(packed_info)
# Query sigma and color with gradients # Query sigma and color with gradients
rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices) rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices)
...@@ -129,7 +129,7 @@ def accumulate_along_rays( ...@@ -129,7 +129,7 @@ def accumulate_along_rays(
weights: Volumetric rendering weights for those samples. Tensor with shape \ weights: Volumetric rendering weights for those samples. Tensor with shape \
(n_samples,). (n_samples,).
ray_indices: Ray index of each sample. IntTensor with shape (n_samples). \ ray_indices: Ray index of each sample. IntTensor with shape (n_samples). \
It can be obtained from `unpack_to_ray_indices(packed_info)`. It can be obtained from `unpack_info(packed_info)`.
values: The values to be accmulated. Tensor with shape (n_samples, D). If \ values: The values to be accmulated. Tensor with shape (n_samples, D). If \
None, the accumulated values are just weights. Default is None. None, the accumulated values are just weights. Default is None.
n_rays: Total number of rays. This will decide the shape of the ouputs. If \ n_rays: Total number of rays. This will decide the shape of the ouputs. If \
......
...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" ...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "nerfacc" name = "nerfacc"
version = "0.1.6" version = "0.1.7"
description = "A General NeRF Acceleration Toolbox." description = "A General NeRF Acceleration Toolbox."
readme = "README.md" readme = "README.md"
authors = [{name = "Ruilong", email = "ruilongli94@gmail.com"}] authors = [{name = "Ruilong", email = "ruilongli94@gmail.com"}]
......
...@@ -9,7 +9,7 @@ device = "cuda:0" ...@@ -9,7 +9,7 @@ device = "cuda:0"
@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device") @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
def occ_eval_fn(x: torch.Tensor) -> torch.Tensor: def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
"""Pesudo occupancy function: (N, 3) -> (N, 1).""" """Pesudo occupancy function: (N, 3) -> (N, 1)."""
return torch.rand_like(x[:, :1]) return ((x - 0.5).norm(dim=-1, keepdim=True) < 0.5).float()
@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device") @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
...@@ -21,5 +21,16 @@ def test_occ_grid(): ...@@ -21,5 +21,16 @@ def test_occ_grid():
assert occ_grid.binary.shape == (128, 128, 128) assert occ_grid.binary.shape == (128, 128, 128)
@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
def test_query_grid():
roi_aabb = [0, 0, 0, 1, 1, 1]
occ_grid = OccupancyGrid(roi_aabb=roi_aabb, resolution=128).to(device)
occ_grid.every_n_step(0, occ_eval_fn, occ_thre=0.1)
samples = torch.rand((100, 3), device=device)
occs = occ_grid.query_occ(samples)
assert occs.shape == (100,)
if __name__ == "__main__": if __name__ == "__main__":
test_occ_grid() test_occ_grid()
test_query_grid()
import pytest import pytest
import torch import torch
from nerfacc import unpack_to_ray_indices from nerfacc import pack_data, unpack_data, unpack_info
device = "cuda:0" device = "cuda:0"
batch_size = 32 batch_size = 32
eps = 1e-6 eps = 1e-6
@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
def test_pack_data():
n_rays = 2
n_samples = 3
data = torch.rand((n_rays, n_samples, 2), device=device, requires_grad=True)
mask = torch.rand((n_rays, n_samples), device=device) > 0.5
packed_data, packed_info = pack_data(data, mask)
unpacked_data = unpack_data(packed_info, packed_data, n_samples)
unpacked_data.sum().backward()
assert (data.grad[mask] == 1).all()
assert torch.allclose(
unpacked_data.sum(dim=1), (data * mask[..., None]).sum(dim=1)
)
@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device") @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
def test_unpack_info(): def test_unpack_info():
packed_info = torch.tensor( packed_info = torch.tensor(
...@@ -16,9 +31,10 @@ def test_unpack_info(): ...@@ -16,9 +31,10 @@ def test_unpack_info():
ray_indices_tgt = torch.tensor( ray_indices_tgt = torch.tensor(
[0, 2, 2, 2, 2], dtype=torch.int64, device=device [0, 2, 2, 2, 2], dtype=torch.int64, device=device
) )
ray_indices = unpack_to_ray_indices(packed_info) ray_indices = unpack_info(packed_info)
assert torch.allclose(ray_indices, ray_indices_tgt) assert torch.allclose(ray_indices, ray_indices_tgt)
if __name__ == "__main__": if __name__ == "__main__":
test_pack_data()
test_unpack_info() test_unpack_info()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment