Unverified commit b53f8a4e, authored by Ruilong Li (李瑞龙), committed by GitHub

Support multi-res occ grid & prop net (#176)

* multres grid

* prop

* benchmark with prop and occ

* benchmark blender with weight_decay

* docs

* bump version
parent 82fd69c7
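
For reviewers, here is a minimal usage sketch of the multi-resolution occupancy grid introduced in this diff (a hypothetical example; names are taken from the changes below, and each level doubles the extent of the previous one):

import torch
from nerfacc import ContractionType, OccupancyGrid

# `levels` is the new argument; the binary buffer gains a leading level dim.
grid = OccupancyGrid(
    roi_aabb=[0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
    resolution=128,
    contraction_type=ContractionType.AABB,  # required when levels > 1
    levels=4,
).to("cuda:0")
assert grid.binary.shape == (4, 128, 128, 128)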
@@ -126,7 +126,8 @@ torch::Tensor unpack_info_to_mask(
 torch::Tensor unpack_data(
     torch::Tensor packed_info,
     torch::Tensor data,
-    int n_samples_per_ray)
+    int n_samples_per_ray,
+    float pad_value)
 {
     DEVICE_GUARD(packed_info);
@@ -143,8 +144,8 @@ torch::Tensor unpack_data(
     const int threads = 256;
     const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
-    torch::Tensor unpacked_data = torch::zeros(
-        {n_rays, n_samples_per_ray, data_dim}, data.options());
+    torch::Tensor unpacked_data = torch::full(
+        {n_rays, n_samples_per_ray, data_dim}, pad_value, data.options());
     AT_DISPATCH_ALL_TYPES(
         data.scalar_type(),
...
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include <ATen/NumericUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <c10/util/MaybeOwned.h>
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include "include/helpers_cuda.h"
namespace F = torch::nn::functional;
template <typename scalar_t>
inline __device__ __host__ scalar_t ceil_div(scalar_t a, scalar_t b)
{
return (a + b - 1) / b;
}
// Taken from:
// https://github.com/pytorch/pytorch/blob/8f1c3c68d3aba5c8898bfb3144988aab6776d549/aten/src/ATen/native/cuda/Bucketization.cu
template<typename input_t>
__device__ int64_t lower_bound(const input_t *data_ss, int64_t start, int64_t end, const input_t val, const int64_t *data_sort) {
// sorter gives relative ordering for ND tensors, so we need to save and add the non-updated start as an offset
// i.e. the second row of a 3x3 tensors starts at element 3 but sorter's second row only contains 0, 1, or 2
const int64_t orig_start = start;
while (start < end) {
const int64_t mid = start + ((end - start) >> 1);
const input_t mid_val = data_sort ? data_ss[orig_start + data_sort[mid]] : data_ss[mid];
if (!(mid_val >= val)) {
start = mid + 1;
}
else {
end = mid;
}
}
return start;
}
template <typename scalar_t>
__device__ int64_t upper_bound(const scalar_t *data_ss, int64_t start, int64_t end, const scalar_t val, const int64_t *data_sort)
{
// sorter gives relative ordering for ND tensors, so we need to save and add the non-updated start as an offset
// i.e. the second row of a 3x3 tensors starts at element 3 but sorter's second row only contains 0, 1, or 2
const int64_t orig_start = start;
while (start < end)
{
const int64_t mid = start + ((end - start) >> 1);
const scalar_t mid_val = data_sort ? data_ss[orig_start + data_sort[mid]] : data_ss[mid];
if (!(mid_val > val))
{
start = mid + 1;
}
else
{
end = mid;
}
}
return start;
}
template <typename scalar_t>
__global__ void pdf_sampling_kernel(
at::PhiloxCudaState philox_args,
const int64_t n_samples_in, // samples per ray for dense [n_rays, n_samples_in] input; unused when info_ts is given
const int64_t *info_ts, // nullptr or [n_rays, 2]
const scalar_t *ts, // [n_rays, n_samples_in] or packed [all_samples_in]
const scalar_t *accum_weights, // [n_rays, n_samples_in] or packed [all_samples_in]
const bool *masks, // [n_rays]
const bool stratified,
const bool single_jitter,
// outputs
const int64_t numel,
const int64_t n_samples_out,
scalar_t *ts_out) // [n_rays, n_samples_out]
{
int64_t n_bins_out = n_samples_out - 1;
// Evenly spaced sample positions in [0, 1] (u_pad is reserved for future
// padding); stratified jitter is applied per sample further below.
scalar_t u_interval = 1.0f / n_bins_out;
scalar_t u_pad = 0.0f;
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < numel; tid += blockDim.x * gridDim.x)
{
int64_t ray_id = tid / n_samples_out;
int64_t sample_id = tid % n_samples_out;
if (masks != nullptr && !masks[ray_id]) {
// This ray is to be skipped.
// Be careful: ts_out must be pre-filled by the caller for skipped rays.
continue;
}
int64_t start_bd, end_bd;
if (info_ts == nullptr)
{
// no packing, the input is already [n_rays, n_samples_in]
start_bd = ray_id * n_samples_in;
end_bd = start_bd + n_samples_in;
}
else
{
// packed input, the input is [all_samples_in]
start_bd = info_ts[ray_id * 2];
end_bd = start_bd + info_ts[ray_id * 2 + 1];
if (start_bd == end_bd) {
// This ray is empty, so there is nothing to sample from.
// Be careful: ts_out must be pre-filled by the caller for empty rays.
continue;
}
}
scalar_t u = u_pad + sample_id * u_interval;
if (stratified)
{
auto seeds = at::cuda::philox::unpack(philox_args);
curandStatePhilox4_32_10_t state;
int64_t rand_seq_id = single_jitter ? ray_id : tid;
curand_init(std::get<0>(seeds), rand_seq_id, std::get<1>(seeds), &state);
float rand = curand_uniform(&state);
u -= rand * u_interval;
u = max(u, static_cast<scalar_t>(0.0f));
}
// searchsorted with "right" option:
// i.e. accum_weights[pos - 1] <= u < accum_weights[pos]
int64_t pos = upper_bound<scalar_t>(accum_weights, start_bd, end_bd, u, nullptr);
int64_t p0 = min(max(pos - 1, start_bd), end_bd - 1);
int64_t p1 = min(max(pos, start_bd), end_bd - 1);
scalar_t start_u = accum_weights[p0];
scalar_t end_u = accum_weights[p1];
scalar_t start_t = ts[p0];
scalar_t end_t = ts[p1];
if (p0 == p1) {
if (p0 == end_bd - 1)
ts_out[tid] = end_t;
else
ts_out[tid] = start_t;
} else if (end_u - start_u < 1e-20f) {
ts_out[tid] = (start_t + end_t) * 0.5f;
} else {
scalar_t scaling = (end_t - start_t) / (end_u - start_u);
scalar_t t = (u - start_u) * scaling + start_t;
ts_out[tid] = t;
}
}
}
torch::Tensor pdf_sampling(
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_samples_in - 1]
int64_t n_samples, // n_samples_out
float padding,
bool stratified,
bool single_jitter,
c10::optional<torch::Tensor> masks_opt) // [n_rays]
{
DEVICE_GUARD(ts);
CHECK_INPUT(ts);
CHECK_INPUT(weights);
TORCH_CHECK(ts.ndimension() == 2);
TORCH_CHECK(weights.ndimension() == 2);
TORCH_CHECK(ts.size(1) == weights.size(1) + 1);
c10::MaybeOwned<torch::Tensor> masks_maybe_owned = at::borrow_from_optional_tensor(masks_opt);
const torch::Tensor& masks = *masks_maybe_owned;
if (padding > 0.f)
{
weights = weights + padding;
}
weights = F::normalize(weights, F::NormalizeFuncOptions().p(1).dim(-1));
torch::Tensor accum_weights = torch::cat({torch::zeros({weights.size(0), 1}, weights.options()),
weights.cumsum(1, weights.scalar_type())},
1);
torch::Tensor ts_out = torch::full({ts.size(0), n_samples}, -1.0f, ts.options());
int64_t numel = ts_out.numel();
int64_t maxThread = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
int64_t maxGrid = 1024;
dim3 block = dim3(min(maxThread, numel));
dim3 grid = dim3(min(maxGrid, ceil_div<int64_t>(numel, block.x)));
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
// For jittering
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_cuda_state(4);
}
AT_DISPATCH_ALL_TYPES(
ts.scalar_type(),
"pdf_sampling",
([&]
{ pdf_sampling_kernel<scalar_t><<<grid, block, 0, stream>>>(
rng_engine_inputs,
ts.size(1), /* n_samples_in */
nullptr, /* info_ts */
ts.data_ptr<scalar_t>(), /* ts */
accum_weights.data_ptr<scalar_t>(), /* accum_weights */
masks.defined() ? masks.data_ptr<bool>() : nullptr, /* masks */
stratified,
single_jitter,
numel, /* numel */
ts_out.size(1), /* n_samples_out */
ts_out.data_ptr<scalar_t>() /* ts_out */
); }));
return ts_out; // [n_rays, n_samples_out]
}
template <typename scalar_t>
__global__ void pdf_readout_kernel(
const int64_t n_rays,
// keys
const int64_t n_samples_in,
const scalar_t *ts, // [n_rays, n_samples_in]
const scalar_t *accum_weights, // [n_rays, n_samples_in]
const bool *masks, // [n_rays]
// query
const int64_t n_samples_out,
const scalar_t *ts_out, // [n_rays, n_samples_out]
const bool *masks_out, // [n_rays]
scalar_t *weights_out) // [n_rays, n_samples_out - 1]
{
int64_t n_bins_out = n_samples_out - 1;
int64_t numel = n_bins_out * n_rays;
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < numel; tid += blockDim.x * gridDim.x)
{
int64_t ray_id = tid / n_bins_out;
if (masks_out != nullptr && !masks_out[ray_id]) {
// We don't care about this query ray.
weights_out[tid] = 0.0f;
continue;
}
if (masks != nullptr && !masks[ray_id]) {
// We don't have the values for the key ray. In this case we consider the key ray
// is all-zero.
weights_out[tid] = 0.0f;
continue;
}
// search range in ts
int64_t start_bd = ray_id * n_samples_in;
int64_t end_bd = start_bd + n_samples_in;
// index in ts_out
int64_t id0 = tid + ray_id;
int64_t id1 = id0 + 1;
// searchsorted with "right" option:
// i.e. accum_weights[pos - 1] <= u < accum_weights[pos]
int64_t pos0 = upper_bound<scalar_t>(ts, start_bd, end_bd, ts_out[id0], nullptr);
pos0 = max(min(pos0 - 1, end_bd-1), start_bd);
// searchsorted with "left" option:
// i.e. accum_weights[pos - 1] < u <= accum_weights[pos]
int64_t pos1 = lower_bound<scalar_t>(ts, start_bd, end_bd, ts_out[id1], nullptr);
pos1 = max(min(pos1, end_bd-1), start_bd);
// outer
scalar_t outer = accum_weights[pos1] - accum_weights[pos0];
weights_out[tid] = outer;
}
}
torch::Tensor pdf_readout(
// keys
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_bins_in]
c10::optional<torch::Tensor> masks_opt, // [n_rays]
// query
torch::Tensor ts_out,
c10::optional<torch::Tensor> masks_out_opt) // [n_rays, n_samples_out]
{
DEVICE_GUARD(ts);
CHECK_INPUT(ts);
CHECK_INPUT(weights);
TORCH_CHECK(ts.ndimension() == 2);
TORCH_CHECK(weights.ndimension() == 2);
int64_t n_rays = ts.size(0);
int64_t n_samples_in = ts.size(1);
int64_t n_samples_out = ts_out.size(1);
int64_t n_bins_out = n_samples_out - 1;
c10::MaybeOwned<torch::Tensor> masks_maybe_owned = at::borrow_from_optional_tensor(masks_opt);
const torch::Tensor& masks = *masks_maybe_owned;
c10::MaybeOwned<torch::Tensor> masks_out_maybe_owned = at::borrow_from_optional_tensor(masks_out_opt);
const torch::Tensor& masks_out = *masks_out_maybe_owned;
// weights = F::normalize(weights, F::NormalizeFuncOptions().p(1).dim(-1));
torch::Tensor accum_weights = torch::cat({torch::zeros({weights.size(0), 1}, weights.options()),
weights.cumsum(1, weights.scalar_type())},
1);
torch::Tensor weights_out = torch::empty({n_rays, n_bins_out}, weights.options());
int64_t maxThread = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
int64_t maxGrid = 1024;
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int64_t numel = weights_out.numel();
dim3 block = dim3(min(maxThread, numel));
dim3 grid = dim3(min(maxGrid, ceil_div<int64_t>(numel, block.x)));
AT_DISPATCH_ALL_TYPES(
weights.scalar_type(),
"pdf_readout",
([&]
{ pdf_readout_kernel<scalar_t><<<grid, block, 0, stream>>>(
n_rays,
n_samples_in,
ts.data_ptr<scalar_t>(), /* ts */
accum_weights.data_ptr<scalar_t>(), /* accum_weights */
masks.defined() ? masks.data_ptr<bool>() : nullptr,
n_samples_out,
ts_out.data_ptr<scalar_t>(), /* ts_out */
masks_out.defined() ? masks_out.data_ptr<bool>() : nullptr,
weights_out.data_ptr<scalar_t>()
); }));
return weights_out; // [n_rays, n_bins_out]
}
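
For reference, a dense, unjittered PyTorch sketch of the inverse-CDF logic that pdf_sampling_kernel implements (illustration only, not the CUDA code path):

import torch

def pdf_sampling_reference(ts, weights, n_samples_out):
    # ts: (n_rays, n_samples_in) edges; weights: (n_rays, n_samples_in - 1) bins.
    pdf = weights / weights.sum(-1, keepdim=True).clamp_min(1e-20)
    cdf = torch.cat([torch.zeros_like(pdf[:, :1]), pdf.cumsum(-1)], dim=-1)
    # u_pad = 0, u_interval = 1 / (n_samples_out - 1), as in the kernel.
    u = torch.linspace(0.0, 1.0, n_samples_out, device=ts.device)
    u = u.expand(ts.shape[0], -1).contiguous()
    # searchsorted with "right": cdf[pos - 1] <= u < cdf[pos].
    pos = torch.searchsorted(cdf.contiguous(), u, right=True)
    p0 = (pos - 1).clamp(0, ts.shape[1] - 1)
    p1 = pos.clamp(0, ts.shape[1] - 1)
    cdf0, cdf1 = cdf.gather(-1, p0), cdf.gather(-1, p1)
    t0, t1 = ts.gather(-1, p0), ts.gather(-1, p1)
    frac = (u - cdf0) / (cdf1 - cdf0).clamp_min(1e-20)
    return t0 + frac * (t1 - t0)  # (n_rays, n_samples_out)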
@@ -51,17 +51,11 @@ torch::Tensor contract_inv(
     const torch::Tensor roi,
     const ContractionType type);
-std::vector<torch::Tensor> ray_resampling(
-    torch::Tensor packed_info,
-    torch::Tensor starts,
-    torch::Tensor ends,
-    torch::Tensor weights,
-    const int steps);
 torch::Tensor unpack_data(
     torch::Tensor packed_info,
     torch::Tensor data,
-    int n_samples_per_ray);
+    int n_samples_per_ray,
+    float pad_value);
 // cub implementations: parallel across samples
 bool is_cub_available() {
@@ -128,6 +122,25 @@ torch::Tensor weight_from_alpha_backward_naive(
     torch::Tensor packed_info,
     torch::Tensor alphas);
+// pdf
+torch::Tensor pdf_sampling(
+    torch::Tensor ts,      // [n_rays, n_samples_in]
+    torch::Tensor weights, // [n_rays, n_samples_in - 1]
+    int64_t n_samples,     // n_samples_out
+    float padding,
+    bool stratified,
+    bool single_jitter,
+    c10::optional<torch::Tensor> masks_opt); // [n_rays]
+torch::Tensor pdf_readout(
+    // keys
+    torch::Tensor ts,      // [n_rays, n_samples_in]
+    torch::Tensor weights, // [n_rays, n_bins_in]
+    c10::optional<torch::Tensor> masks_opt, // [n_rays]
+    // query
+    torch::Tensor ts_out,
+    c10::optional<torch::Tensor> masks_out_opt); // [n_rays, n_samples_out]
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     // contraction
@@ -144,7 +157,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     // marching
     m.def("ray_aabb_intersect", &ray_aabb_intersect);
     m.def("ray_marching", &ray_marching);
-    m.def("ray_resampling", &ray_resampling);
     // rendering
     m.def("is_cub_available", is_cub_available);
@@ -167,4 +179,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("unpack_data", &unpack_data);
     m.def("unpack_info", &unpack_info);
     m.def("unpack_info_to_mask", &unpack_info_to_mask);
+    // pdf
+    m.def("pdf_sampling", &pdf_sampling);
+    m.def("pdf_readout", &pdf_readout);
 }
\ No newline at end of file
@@ -13,6 +13,32 @@ inline __device__ __host__ float calc_dt(
     return clamp(t * cone_angle, dt_min, dt_max);
 }
+inline __device__ __host__ int mip_level(
+    const float3 xyz,
+    const float3 roi_min, const float3 roi_max,
+    const ContractionType type)
+{
+    if (type != ContractionType::AABB)
+    {
+        // mip level should be always zero if not using AABB
+        return 0;
+    }
+    float3 xyz_unit = apply_contraction(
+        xyz, roi_min, roi_max, ContractionType::AABB);
+    float3 scale = fabs(xyz_unit - 0.5);
+    float maxval = fmaxf(fmaxf(scale.x, scale.y), scale.z);
+    // if maxval is almost zero, it will trigger frexpf to output 0
+    // for exponent, which is not what we want.
+    maxval = fmaxf(maxval, 0.1);
+    int exponent;
+    frexpf(maxval, &exponent);
+    int mip = max(0, exponent + 1);
+    return mip;
+}
 inline __device__ __host__ int grid_idx_at(
     const float3 xyz_unit, const int3 grid_res)
 {
@@ -28,43 +54,49 @@ template <typename scalar_t>
 inline __device__ __host__ scalar_t grid_occupied_at(
     const float3 xyz,
     const float3 roi_min, const float3 roi_max,
-    ContractionType type,
-    const int3 grid_res, const scalar_t *grid_value)
+    ContractionType type, int mip,
+    const int grid_nlvl, const int3 grid_res, const scalar_t *grid_value)
 {
-    if (type == ContractionType::AABB &&
-        (xyz.x < roi_min.x || xyz.x > roi_max.x ||
-         xyz.y < roi_min.y || xyz.y > roi_max.y ||
-         xyz.z < roi_min.z || xyz.z > roi_max.z))
+    if (type == ContractionType::AABB && mip >= grid_nlvl)
     {
        return false;
     }
     float3 xyz_unit = apply_contraction(
         xyz, roi_min, roi_max, type);
-    int idx = grid_idx_at(xyz_unit, grid_res);
+    xyz_unit = (xyz_unit - 0.5) * scalbnf(1.0f, -mip) + 0.5;
+    int idx = grid_idx_at(xyz_unit, grid_res) + mip * grid_res.x * grid_res.y * grid_res.z;
     return grid_value[idx];
 }
 // dda like step
 inline __device__ __host__ float distance_to_next_voxel(
-    const float3 xyz, const float3 dir, const float3 inv_dir,
+    const float3 xyz, const float3 dir, const float3 inv_dir, int mip,
     const float3 roi_min, const float3 roi_max, const int3 grid_res)
 {
+    float scaling = scalbnf(1.0f, mip);
+    float3 _roi_mid = (roi_min + roi_max) * 0.5;
+    float3 _roi_rad = (roi_max - roi_min) * 0.5;
+    float3 _roi_min = _roi_mid - _roi_rad * scaling;
+    float3 _roi_max = _roi_mid + _roi_rad * scaling;
     float3 _occ_res = make_float3(grid_res);
-    float3 _xyz = roi_to_unit(xyz, roi_min, roi_max) * _occ_res;
-    float3 txyz = ((floorf(_xyz + 0.5f + 0.5f * sign(dir)) - _xyz) * inv_dir) / _occ_res * (roi_max - roi_min);
+    float3 _xyz = roi_to_unit(xyz, _roi_min, _roi_max) * _occ_res;
+    float3 txyz = ((floorf(_xyz + 0.5f + 0.5f * sign(dir)) - _xyz) * inv_dir) / _occ_res * (_roi_max - _roi_min);
     float t = min(min(txyz.x, txyz.y), txyz.z);
     return fmaxf(t, 0.0f);
 }
 inline __device__ __host__ float advance_to_next_voxel(
     const float t, const float dt_min,
-    const float3 xyz, const float3 dir, const float3 inv_dir,
+    const float3 xyz, const float3 dir, const float3 inv_dir, int mip,
     const float3 roi_min, const float3 roi_max, const int3 grid_res, const float far)
 {
     // Regular stepping (may be slower but matches non-empty space)
     float t_target = t + distance_to_next_voxel(
-        xyz, dir, inv_dir, roi_min, roi_max, grid_res);
+        xyz, dir, inv_dir, mip, roi_min, roi_max, grid_res);
     t_target = min(t_target, far);
     float _t = t;
     do
@@ -87,6 +119,7 @@ __global__ void ray_marching_kernel(
     const float *t_max, // shape (n_rays,)
     // occupancy grid & contraction
     const float *roi,
+    const int grid_nlvl,
     const int3 grid_res,
     const bool *grid_binary, // shape (reso_x, reso_y, reso_z)
     const ContractionType type,
@@ -132,9 +165,20 @@ __global__ void ray_marching_kernel(
     const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
     const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
+    const float grid_cell_scale = fmaxf(fmaxf(
+        (roi_max.x - roi_min.x) / grid_res.x * scalbnf(1.732f, grid_nlvl - 1),
+        (roi_max.y - roi_min.y) / grid_res.y * scalbnf(1.732f, grid_nlvl - 1)),
+        (roi_max.z - roi_min.z) / grid_res.z * scalbnf(1.732f, grid_nlvl - 1));
     // TODO: compute dt_max from occ resolution.
     float dt_min = step_size;
-    float dt_max = 1e10f;
+    float dt_max;
+    if (type == ContractionType::AABB) {
+        // compute dt_max from occ grid resolution.
+        dt_max = grid_cell_scale;
+    } else {
+        dt_max = 1e10f;
+    }
     int j = 0;
     float t0 = near;
@@ -146,7 +190,13 @@ __global__ void ray_marching_kernel(
     {
         // current center
         const float3 xyz = origin + t_mid * dir;
-        if (grid_occupied_at(xyz, roi_min, roi_max, type, grid_res, grid_binary))
+        // current mip level
+        const int mip = mip_level(xyz, roi_min, roi_max, type);
+        if (mip >= grid_nlvl) {
+            // out of grid
+            break;
+        }
+        if (grid_occupied_at(xyz, roi_min, roi_max, type, mip, grid_nlvl, grid_res, grid_binary))
         {
             if (!is_first_round)
             {
@@ -168,7 +218,7 @@ __global__ void ray_marching_kernel(
         case ContractionType::AABB:
             // no contraction
             t_mid = advance_to_next_voxel(
-                t_mid, dt_min, xyz, dir, inv_dir, roi_min, roi_max, grid_res, far);
+                t_mid, dt_min, xyz, dir, inv_dir, mip, roi_min, roi_max, grid_res, far);
             dt = calc_dt(t_mid, cone_angle, dt_min, dt_max);
             t0 = t_mid - dt * 0.5f;
             t1 = t_mid + dt * 0.5f;
@@ -218,11 +268,12 @@ std::vector<torch::Tensor> ray_marching(
     TORCH_CHECK(t_min.ndimension() == 1)
     TORCH_CHECK(t_max.ndimension() == 1)
     TORCH_CHECK(roi.ndimension() == 1 & roi.size(0) == 6)
-    TORCH_CHECK(grid_binary.ndimension() == 3)
+    TORCH_CHECK(grid_binary.ndimension() == 4)
     const int n_rays = rays_o.size(0);
+    const int grid_nlvl = grid_binary.size(0);
     const int3 grid_res = make_int3(
-        grid_binary.size(0), grid_binary.size(1), grid_binary.size(2));
+        grid_binary.size(1), grid_binary.size(2), grid_binary.size(3));
     const int threads = 256;
     const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
@@ -241,6 +292,7 @@ std::vector<torch::Tensor> ray_marching(
             t_max.data_ptr<float>(),
             // occupancy grid & contraction
             roi.data_ptr<float>(),
+            grid_nlvl,
             grid_res,
             grid_binary.data_ptr<bool>(),
             type,
@@ -272,6 +324,7 @@ std::vector<torch::Tensor> ray_marching(
             t_max.data_ptr<float>(),
             // occupancy grid & contraction
             roi.data_ptr<float>(),
+            grid_nlvl,
             grid_res,
             grid_binary.data_ptr<bool>(),
             type,
@@ -299,6 +352,7 @@ __global__ void query_occ_kernel(
     const float *samples, // shape (n_samples, 3)
     // occupancy grid & contraction
     const float *roi,
+    const int grid_nlvl,
     const int3 grid_res,
     const scalar_t *grid_value, // shape (reso_x, reso_y, reso_z)
     const ContractionType type,
@@ -314,8 +368,9 @@ __global__ void query_occ_kernel(
     const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
     const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
     const float3 xyz = make_float3(samples[0], samples[1], samples[2]);
+    const int mip = mip_level(xyz, roi_min, roi_max, type);
-    *occs = grid_occupied_at(xyz, roi_min, roi_max, type, grid_res, grid_value);
+    *occs = grid_occupied_at(xyz, roi_min, roi_max, type, mip, grid_nlvl, grid_res, grid_value);
     return;
 }
@@ -330,8 +385,9 @@ torch::Tensor grid_query(
     CHECK_INPUT(samples);
     const int n_samples = samples.size(0);
+    const int grid_nlvl = grid_value.size(0);
     const int3 grid_res = make_int3(
-        grid_value.size(0), grid_value.size(1), grid_value.size(2));
+        grid_value.size(1), grid_value.size(2), grid_value.size(3));
     const int threads = 256;
     const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads);
@@ -348,6 +404,7 @@ torch::Tensor grid_query(
             samples.data_ptr<float>(),
             // grid
             roi.data_ptr<float>(),
+            grid_nlvl,
             grid_res,
             grid_value.data_ptr<scalar_t>(),
             type,
...
@@ -28,7 +28,7 @@ def query_grid(
         samples: (n_samples, 3) tensor of coordinates.
         grid_roi: (6,) region of interest of the grid. Usually it should be
             acquired from the grid itself using `grid.roi_aabb`.
-        grid_values: A 3D tensor of grid values in the shape of (resx, resy, resz).
+        grid_values: A 4D tensor of grid values in the shape of (nlvl, resx, resy, resz).
         grid_type: Contraction type of the grid. Usually it should be
             acquired from the grid itself using `grid.contraction_type`.
@@ -37,7 +37,7 @@ def query_grid(
     """
     assert samples.dim() == 2 and samples.size(-1) == 3
     assert grid_roi.dim() == 1 and grid_roi.size(0) == 6
-    assert grid_values.dim() == 3
+    assert grid_values.dim() == 4
     assert isinstance(grid_type, ContractionType)
     return _C.grid_query(
         samples.contiguous(),
@@ -58,7 +58,7 @@ class Grid(nn.Module):
     To work with :func:`nerfacc.ray_marching`, three attributes must exist:
     - :attr:`roi_aabb`: The axis-aligned bounding box of the region of interest.
-    - :attr:`binary`: A 3D binarized tensor of shape {resx, resy, resz}, \
+    - :attr:`binary`: A 4D binarized tensor of shape {nlvl, resx, resy, resz}, \
         with torch.bool data type.
     - :attr:`contraction_type`: The contraction type of the grid, indicating how \
         the 3D space is mapped to the grid.
@@ -85,9 +85,9 @@ class Grid(nn.Module):
     @property
     def binary(self) -> torch.Tensor:
-        """A 3D binarized tensor with torch.bool data type.
+        """A 4-dim binarized tensor with torch.bool data type.
-        The tensor is of shape (resx, resy, resz), in which each boolean value
+        The tensor is of shape (nlvl, resx, resy, resz), in which each boolean value
         represents whether the corresponding voxel should be kept or not.
         """
         if hasattr(self, "_binary"):
@@ -120,6 +120,7 @@ class OccupancyGrid(Grid):
             be a cube. Otherwise, a list or a tensor of shape (3,) is expected. Default: 128.
         contraction_type: The contraction type of the grid. See :class:`nerfacc.ContractionType`
             for more details. Default: :attr:`nerfacc.ContractionType.AABB`.
+        levels: The number of levels of the grid. Default: 1.
     """
     NUM_DIM: int = 3
@@ -129,6 +130,7 @@ class OccupancyGrid(Grid):
         roi_aabb: Union[List[int], torch.Tensor],
         resolution: Union[int, List[int], torch.Tensor] = 128,
         contraction_type: ContractionType = ContractionType.AABB,
+        levels: int = 1,
     ) -> None:
         super().__init__()
         if isinstance(resolution, int):
@@ -151,47 +153,59 @@ class OccupancyGrid(Grid):
             [self.NUM_DIM * 2]
         ), f"Invalid shape: {roi_aabb.shape}"
+        if levels > 1:
+            assert (
+                contraction_type == ContractionType.AABB
+            ), "For multi-res occupancy grid, contraction is not supported yet."
         # total number of voxels
-        self.num_cells = int(resolution.prod().item())
+        self.num_cells_per_lvl = int(resolution.prod().item())
+        self.levels = levels
         # required attributes
         self.register_buffer("_roi_aabb", roi_aabb)
         self.register_buffer(
-            "_binary", torch.zeros(resolution.tolist(), dtype=torch.bool)
+            "_binary",
+            torch.zeros([levels] + resolution.tolist(), dtype=torch.bool),
         )
         self._contraction_type = contraction_type
         # helper attributes
         self.register_buffer("resolution", resolution)
-        self.register_buffer("occs", torch.zeros(self.num_cells))
+        self.register_buffer(
+            "occs", torch.zeros(self.levels * self.num_cells_per_lvl)
+        )
         # Grid coords & indices
         grid_coords = _meshgrid3d(resolution).reshape(
-            self.num_cells, self.NUM_DIM
+            self.num_cells_per_lvl, self.NUM_DIM
         )
         self.register_buffer("grid_coords", grid_coords, persistent=False)
-        grid_indices = torch.arange(self.num_cells)
+        grid_indices = torch.arange(self.num_cells_per_lvl)
         self.register_buffer("grid_indices", grid_indices, persistent=False)
     @torch.no_grad()
-    def _get_all_cells(self) -> torch.Tensor:
+    def _get_all_cells(self) -> List[torch.Tensor]:
         """Returns all cells of the grid."""
-        return self.grid_indices
+        return [self.grid_indices] * self.levels
     @torch.no_grad()
-    def _sample_uniform_and_occupied_cells(self, n: int) -> torch.Tensor:
+    def _sample_uniform_and_occupied_cells(self, n: int) -> List[torch.Tensor]:
         """Samples both n uniform and occupied cells."""
-        uniform_indices = torch.randint(
-            self.num_cells, (n,), device=self.device
-        )
-        occupied_indices = torch.nonzero(self._binary.flatten())[:, 0]
-        if n < len(occupied_indices):
-            selector = torch.randint(
-                len(occupied_indices), (n,), device=self.device
-            )
-            occupied_indices = occupied_indices[selector]
-        indices = torch.cat([uniform_indices, occupied_indices], dim=0)
-        return indices
+        lvl_indices = []
+        for lvl in range(self.levels):
+            uniform_indices = torch.randint(
+                self.num_cells_per_lvl, (n,), device=self.device
+            )
+            occupied_indices = torch.nonzero(self._binary[lvl].flatten())[:, 0]
+            if n < len(occupied_indices):
+                selector = torch.randint(
+                    len(occupied_indices), (n,), device=self.device
+                )
+                occupied_indices = occupied_indices[selector]
+            indices = torch.cat([uniform_indices, occupied_indices], dim=0)
+            lvl_indices.append(indices)
+        return lvl_indices
     @torch.no_grad()
     def _update(
@@ -205,35 +219,38 @@ class OccupancyGrid(Grid):
         """Update the occ field in the EMA way."""
         # sample cells
         if step < warmup_steps:
-            indices = self._get_all_cells()
+            lvl_indices = self._get_all_cells()
         else:
-            N = self.num_cells // 4
-            indices = self._sample_uniform_and_occupied_cells(N)
+            N = self.num_cells_per_lvl // 4
+            lvl_indices = self._sample_uniform_and_occupied_cells(N)
-        # infer occupancy: density * step_size
-        grid_coords = self.grid_coords[indices]
-        x = (
-            grid_coords + torch.rand_like(grid_coords, dtype=torch.float32)
-        ) / self.resolution
-        if self._contraction_type == ContractionType.UN_BOUNDED_SPHERE:
-            # only the points inside the sphere are valid
-            mask = (x - 0.5).norm(dim=1) < 0.5
-            x = x[mask]
-            indices = indices[mask]
-        # voxel coordinates [0, 1]^3 -> world
-        x = contract_inv(
-            x,
-            roi=self._roi_aabb,
-            type=self._contraction_type,
-        )
-        occ = occ_eval_fn(x).squeeze(-1)
-        # ema update
-        self.occs[indices] = torch.maximum(self.occs[indices] * ema_decay, occ)
-        # supposed to use scatter max but empirically it is almost the same.
-        # self.occs, _ = scatter_max(
-        #     occ, indices, dim=0, out=self.occs * ema_decay
-        # )
+        for lvl, indices in enumerate(lvl_indices):
+            # infer occupancy: density * step_size
+            grid_coords = self.grid_coords[indices]
+            x = (
+                grid_coords + torch.rand_like(grid_coords, dtype=torch.float32)
+            ) / self.resolution
+            if self._contraction_type == ContractionType.UN_BOUNDED_SPHERE:
+                # only the points inside the sphere are valid
+                mask = (x - 0.5).norm(dim=1) < 0.5
+                x = x[mask]
+                indices = indices[mask]
+            # voxel coordinates [0, 1]^3 -> world
+            x = contract_inv(
+                (x - 0.5) * (2**lvl) + 0.5,
+                roi=self._roi_aabb,
+                type=self._contraction_type,
+            )
+            occ = occ_eval_fn(x).squeeze(-1)
+            # ema update
+            cell_ids = lvl * self.num_cells_per_lvl + indices
+            self.occs[cell_ids] = torch.maximum(
+                self.occs[cell_ids] * ema_decay, occ
+            )
+            # supposed to use scatter max but empirically it is almost the same.
+            # self.occs, _ = scatter_max(
+            #     occ, indices, dim=0, out=self.occs * ema_decay
+            # )
         self._binary = (
             self.occs > torch.clamp(self.occs.mean(), max=occ_thre)
         ).view(self._binary.shape)
...
from torch import Tensor
from .pack import unpack_data
def distortion(
packed_info: Tensor, weights: Tensor, t_starts: Tensor, t_ends: Tensor
) -> Tensor:
"""Distortion loss from Mip-NeRF 360 paper, Equ. 15.
Args:
packed_info: Packed info for the samples. (n_rays, 2)
weights: Weights for the samples. (all_samples,)
t_starts: Per-sample start distance. Tensor with shape (all_samples, 1).
t_ends: Per-sample end distance. Tensor with shape (all_samples, 1).
Returns:
Distortion loss. (n_rays,)
"""
# (all_samples, 1) -> (n_rays, n_samples)
w = unpack_data(packed_info, weights[..., None]).squeeze(-1)
t1 = unpack_data(packed_info, t_starts).squeeze(-1)
t2 = unpack_data(packed_info, t_ends).squeeze(-1)
interval = t2 - t1
tmid = (t1 + t2) / 2
loss_uni = (1 / 3) * (interval * w.pow(2)).sum(-1)
ww = w.unsqueeze(-1) * w.unsqueeze(-2)
mm = (tmid.unsqueeze(-1) - tmid.unsqueeze(-2)).abs()
loss_bi = (ww * mm).sum((-1, -2))
return loss_uni + loss_bi
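
In equation form, with w_i the unpacked weights and t_i the interval edges, the function above computes:

\mathcal{L}_{\mathrm{dist}} =
    \sum_{i,j} w_i w_j \left| \frac{t_i + t_{i+1}}{2} - \frac{t_j + t_{j+1}}{2} \right|
    + \frac{1}{3} \sum_i w_i^{2} \, (t_{i+1} - t_i)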
@@ -125,6 +125,7 @@ def unpack_data(
     packed_info: Tensor,
     data: Tensor,
     n_samples: Optional[int] = None,
+    pad_value: float = 0.0,
 ) -> Tensor:
     """Unpack packed data (all_samples, D) to per-ray data (n_rays, n_samples, D).
@@ -134,6 +135,7 @@ def unpack_data(
         data: Packed data to unpack. Tensor with shape (n_samples, D).
         n_samples (int): Optional Number of samples per ray. If not provided, it \
             will be inferred from the packed_info.
+        pad_value: Value to pad the unpacked data.
     Returns:
         Unpacked data (n_rays, n_samples, D).
@@ -164,21 +166,27 @@ def unpack_data(
     ), "data must be a 2D tensor with shape (n_samples, D)."
     if n_samples is None:
         n_samples = packed_info[:, 1].max().item()
-    return _UnpackData.apply(packed_info, data, n_samples)
+    return _UnpackData.apply(packed_info, data, n_samples, pad_value)
 class _UnpackData(torch.autograd.Function):
     """Unpack packed data (all_samples, D) to per-ray data (n_rays, n_samples, D)."""
     @staticmethod
-    def forward(ctx, packed_info: Tensor, data: Tensor, n_samples: int):
+    def forward(
+        ctx,
+        packed_info: Tensor,
+        data: Tensor,
+        n_samples: int,
+        pad_value: float = 0.0,
+    ) -> Tensor:
         # shape of the data should be (all_samples, D)
         packed_info = packed_info.contiguous()
         data = data.contiguous()
         if ctx.needs_input_grad[1]:
             ctx.save_for_backward(packed_info)
             ctx.n_samples = n_samples
-        return _C.unpack_data(packed_info, data, n_samples)
+        return _C.unpack_data(packed_info, data, n_samples, pad_value)
     @staticmethod
     def backward(ctx, grad: Tensor):
@@ -187,4 +195,4 @@ class _UnpackData(torch.autograd.Function):
         n_samples = ctx.n_samples
         mask = _C.unpack_info_to_mask(packed_info, n_samples)
         packed_grad = grad[mask].contiguous()
-        return None, packed_grad, None
+        return None, packed_grad, None, None
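
A small sketch of the new pad_value argument in action (this assumes pack_info and unpack_data are exported at the package top level):

import torch
from nerfacc import pack_info, unpack_data

ray_indices = torch.tensor([0, 0, 1], device="cuda")  # ray 0: 2 samples, ray 1: 1
packed_info = pack_info(ray_indices, n_rays=2)        # (2, 2)
data = torch.arange(3.0, device="cuda")[:, None]      # (3, 1) packed samples
out = unpack_data(packed_info, data, n_samples=2, pad_value=-1.0)
# out[1] -> [[2.0], [-1.0]]: the short ray is padded with pad_value.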
from typing import Optional
import torch
from torch import Tensor
import nerfacc.cuda as _C
class PDFOuter(torch.autograd.Function):
@staticmethod
def forward(
ctx,
ts: Tensor,
weights: Tensor,
masks: Optional[Tensor],
ts_query: Tensor,
masks_query: Optional[Tensor],
):
assert ts.dim() == weights.dim() == ts_query.dim() == 2
assert ts.shape[0] == weights.shape[0] == ts_query.shape[0]
assert ts.shape[1] == weights.shape[1] + 1
ts = ts.contiguous()
weights = weights.contiguous()
ts_query = ts_query.contiguous()
masks = masks.contiguous() if masks is not None else None
masks_query = (
masks_query.contiguous() if masks_query is not None else None
)
weights_query = _C.pdf_readout(
ts, weights, masks, ts_query, masks_query
)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(ts, masks, ts_query, masks_query)
return weights_query
@staticmethod
def backward(ctx, weights_query_grads: Tensor):
weights_query_grads = weights_query_grads.contiguous()
ts, masks, ts_query, masks_query = ctx.saved_tensors
weights_grads = _C.pdf_readout(
ts_query, weights_query_grads, masks_query, ts, masks
)
return None, weights_grads, None, None, None
pdf_outer = PDFOuter.apply
@torch.no_grad()
def pdf_sampling(
t: torch.Tensor,
weights: torch.Tensor,
n_samples: int,
padding: float = 0.01,
stratified: bool = False,
single_jitter: bool = False,
masks: Optional[torch.Tensor] = None,
):
assert t.shape[0] == weights.shape[0]
assert t.shape[1] == weights.shape[1] + 1
if masks is not None:
assert t.shape[0] == masks.shape[0]
t_new = _C.pdf_sampling(
t.contiguous(),
weights.contiguous(),
n_samples + 1, # be careful here!
padding,
stratified,
single_jitter,
masks.contiguous() if masks is not None else None,
)
return t_new # [n_ray, n_samples+1]
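
Hypothetical usage of the wrapper above; note that n_samples + 1 interval edges come back, matching the "be careful here!" comment:

import torch

t = torch.linspace(0.0, 1.0, 129, device="cuda").expand(1024, -1)  # (1024, 129) edges
weights = torch.rand(1024, 128, device="cuda")                     # (1024, 128) bins
t_new = pdf_sampling(t, weights, n_samples=64, stratified=True)
assert t_new.shape == (1024, 65)  # n_samples + 1 edges per ray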
#!/usr/bin/env python3
#
# File : prop_utils.py
# Author : Hang Gao
# Email : hangg.sv7@gmail.com
# Date : 02/19/2023
#
# Distributed under terms of the MIT license.
from typing import Callable, Literal, Optional, Sequence, Tuple
import torch
import torch.nn.functional as F
from .intersection import ray_aabb_intersect
from .pdf import pdf_outer, pdf_sampling
def sample_from_weighted(
bins: torch.Tensor,
weights: torch.Tensor,
num_samples: int,
stratified: bool = False,
vmin: float = -torch.inf,
vmax: float = torch.inf,
) -> torch.Tensor:
"""
Args:
bins: (..., B + 1).
weights: (..., B).
Returns:
samples: (..., S + 1).
"""
B = weights.shape[-1]
S = num_samples
assert bins.shape[-1] == B + 1
dtype, device = bins.dtype, bins.device
eps = torch.finfo(weights.dtype).eps
# (..., B).
pdf = F.normalize(weights, p=1, dim=-1)
# (..., B + 1).
cdf = torch.cat(
[
torch.zeros_like(pdf[..., :1]),
torch.cumsum(pdf[..., :-1], dim=-1),
torch.ones_like(pdf[..., :1]),
],
dim=-1,
)
# (..., S). Sample positions between [0, 1).
if not stratified:
pad = 1 / (2 * S)
# Get the center of each pdf bins.
u = torch.linspace(pad, 1 - pad - eps, S, dtype=dtype, device=device)
u = u.broadcast_to(bins.shape[:-1] + (S,))
else:
# `u` is in [0, 1) --- it can be zero, but it can never be 1.
u_max = eps + (1 - eps) / S
max_jitter = (1 - u_max) / (S - 1) - eps
# Only perform one jittering per ray (`single_jitter` in the original
# implementation.)
u = (
torch.linspace(0, 1 - u_max, S, dtype=dtype, device=device)
+ torch.rand(
*bins.shape[:-1],
1,
dtype=dtype,
device=device,
)
* max_jitter
)
# (..., S).
ceil = torch.searchsorted(cdf.contiguous(), u.contiguous(), side="right")
floor = ceil - 1
# (..., S * 2).
inds = torch.cat([floor, ceil], dim=-1)
# (..., S).
cdf0, cdf1 = cdf.gather(-1, inds).split(S, dim=-1)
b0, b1 = bins.gather(-1, inds).split(S, dim=-1)
# (..., S). Linear interpolation in 1D.
t = (u - cdf0) / torch.clamp(cdf1 - cdf0, min=eps)
# Sample centers.
centers = b0 + t * (b1 - b0)
samples = (centers[..., 1:] + centers[..., :-1]) / 2
samples = torch.cat(
[
(2 * centers[..., :1] - samples[..., :1]).clamp_min(vmin),
samples,
(2 * centers[..., -1:] - samples[..., -1:]).clamp_max(vmax),
],
dim=-1,
)
return samples
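
For example (shapes only; illustrative values):

import torch

bins = torch.linspace(0.0, 1.0, 129).expand(16, -1)  # (16, B + 1) with B = 128
weights = torch.rand(16, 128)                        # (16, B)
samples = sample_from_weighted(bins, weights, num_samples=64, vmin=0.0, vmax=1.0)
assert samples.shape == (16, 65)  # S + 1 edges per ray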
def render_weight_from_density(
sigmas: torch.Tensor,
t_starts: torch.Tensor,
t_ends: torch.Tensor,
opaque_bkgd: bool = False,
) -> torch.Tensor:
"""
Args:
sigmas: (..., S, 1).
t_starts: (..., S).
t_ends: (..., S).
Return:
weights: (..., S).
"""
# (..., S).
deltas = t_ends - t_starts
# (..., S).
sigma_deltas = sigmas[..., 0] * deltas
if opaque_bkgd:
sigma_deltas = torch.cat(
[
sigma_deltas[..., :-1],
torch.full_like(sigma_deltas[..., -1:], torch.inf),
],
dim=-1,
)
alphas = 1 - torch.exp(-sigma_deltas)
trans = torch.exp(
-(
torch.cat(
[
torch.zeros_like(sigma_deltas[..., :1]),
torch.cumsum(sigma_deltas[..., :-1], dim=-1),
],
dim=-1,
)
)
)
weights = alphas * trans
return weights
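
This is the standard volume-rendering quadrature; with delta_i = t_{i+1} - t_i (and, if opaque_bkgd is set, the last sigma_i delta_i replaced by infinity) the code computes:

\alpha_i = 1 - e^{-\sigma_i \delta_i}, \qquad
T_i = \exp\Big( -\sum_{j < i} \sigma_j \delta_j \Big), \qquad
w_i = T_i \, \alpha_i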
def render_from_weighted(
rgbs: torch.Tensor,
t_vals: torch.Tensor,
weights: torch.Tensor,
render_bkgd: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
rgbs: (..., S, 3).
t_vals: (..., S + 1, 1).
weights: (..., S, 1).
Return:
colors: (..., 3).
opacities: (..., 1).
depths: (..., 1). The naming is a bit confusing since it is actually
the expected marching *distances*.
"""
# Use white instead of black background by default.
render_bkgd = (
render_bkgd
if render_bkgd is not None
else torch.ones(3, dtype=rgbs.dtype, device=rgbs.device)
)
eps = torch.finfo(rgbs.dtype).eps
# (..., 1).
opacities = weights.sum(axis=-2)
# (..., 1).
bkgd_weights = (1 - opacities).clamp_min(0)
# (..., 3).
colors = (weights * rgbs).sum(dim=-2) + bkgd_weights * render_bkgd
# (..., S, 1).
t_mids = (t_vals[..., 1:, :] + t_vals[..., :-1, :]) / 2
depths = (weights * t_mids).sum(dim=-2) / opacities.clamp_min(eps)
return colors, opacities, depths
def transform_stot(
transform_type: Literal["uniform", "lindisp"],
s_vals: torch.Tensor,
t_min: torch.Tensor,
t_max: torch.Tensor,
) -> torch.Tensor:
if transform_type == "uniform":
_contract_fn, _icontract_fn = lambda x: x, lambda x: x
elif transform_type == "lindisp":
_contract_fn, _icontract_fn = lambda x: 1 / x, lambda x: 1 / x
else:
raise ValueError(f"Unknown transform_type: {transform_type}")
s_min, s_max = _contract_fn(t_min), _contract_fn(t_max)
icontract_fn = lambda s: _icontract_fn(s * s_max + (1 - s) * s_min)
return icontract_fn(s_vals)
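
A quick check of the lindisp mapping (values computed by hand from the formulas above):

import torch

t_min, t_max = torch.tensor(0.5), torch.tensor(100.0)
s = torch.tensor([0.0, 0.5, 1.0])
t = transform_stot("lindisp", s, t_min, t_max)
# s = 0 -> 0.5, s = 1 -> 100.0, s = 0.5 -> ~0.995: uniform steps in s
# concentrate samples near the camera, as intended for unbounded scenes.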
def rendering(
# radiance field
rgb_sigma_fn: Callable,
num_samples: int,
# proposals
prop_sigma_fns: Sequence[Callable],
num_samples_per_prop: Sequence[int],
# rays
rays_o: torch.Tensor,
rays_d: torch.Tensor,
t_min: Optional[torch.Tensor] = None,
t_max: Optional[torch.Tensor] = None,
# bounding box of the scene
scene_aabb: Optional[torch.Tensor] = None,
# rendering options
near_plane: Optional[float] = None,
far_plane: Optional[float] = None,
stratified: bool = False,
sampling_type: Literal["uniform", "lindisp"] = "lindisp",
opaque_bkgd: bool = False,
render_bkgd: Optional[torch.Tensor] = None,
# gradient options
proposal_requires_grad: bool = False,
proposal_annealing: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if len(prop_sigma_fns) != len(num_samples_per_prop):
raise ValueError(
"`sigma_fns` and `samples_per_level` must have the same length."
)
if not rays_o.is_cuda:
raise NotImplementedError("Only support cuda inputs.")
if t_min is None or t_max is None:
if scene_aabb is not None:
t_min, t_max = ray_aabb_intersect(rays_o, rays_d, scene_aabb)
else:
t_min = torch.zeros_like(rays_o[..., 0])
t_max = torch.ones_like(rays_o[..., 0]) * 1e10
if near_plane is not None:
t_min = torch.clamp(t_min, min=near_plane)
t_max = torch.clamp(t_max, min=near_plane)
if far_plane is not None:
t_min = torch.clamp(t_min, max=far_plane)
t_max = torch.clamp(t_max, max=far_plane)
s_vals = torch.cat(
[
torch.zeros_like(rays_o[..., :1]),
torch.ones_like(rays_o[..., :1]),
],
dim=-1,
)
weights = torch.ones_like(rays_o[..., :1])
rgbs = t_vals = None
weights_per_level, s_vals_per_level = [], []
for level, (level_fn, level_samples) in enumerate(
zip(
prop_sigma_fns + [rgb_sigma_fn],
num_samples_per_prop + [num_samples],
)
):
is_prop = level < len(prop_sigma_fns)
annealed_weights = torch.pow(weights, proposal_annealing)
# (N, S + 1).
s_vals = sample_from_weighted(
s_vals,
annealed_weights,
level_samples,
stratified=stratified,
vmin=0.0,
vmax=1.0,
).detach()
# s_vals = pdf_sampling(
# s_vals,
# annealed_weights,
# level_samples,
# padding=0.0,
# stratified=stratified,
# ).detach()
t_vals = transform_stot(
sampling_type, s_vals, t_min[..., None], t_max[..., None] # type: ignore
)
if is_prop:
with torch.set_grad_enabled(proposal_requires_grad):
# (N, S, 1).
sigmas = level_fn(t_vals[..., :-1, None], t_vals[..., 1:, None])
else:
# (N, S, *).
rgbs, sigmas = level_fn(
t_vals[..., :-1, None], t_vals[..., 1:, None]
)
# (N, S).
weights = render_weight_from_density(
sigmas,
t_vals[..., :-1],
t_vals[..., 1:],
opaque_bkgd=opaque_bkgd,
)
weights_per_level.append(weights)
s_vals_per_level.append(s_vals)
assert rgbs is not None and t_vals is not None
rgbs, opacities, depths = render_from_weighted(
rgbs, t_vals[..., None], weights[..., None], render_bkgd
)
return (
rgbs,
opacities,
depths,
(weights_per_level, s_vals_per_level),
)
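
A self-contained sketch of how rendering might be driven; the density and color callables below are dummies standing in for real proposal and radiance networks, and compute_prop_loss is defined further down in this file:

import torch

n_rays = 1024
rays_o = torch.rand(n_rays, 3, device="cuda")
rays_d = torch.nn.functional.normalize(
    torch.randn(n_rays, 3, device="cuda"), dim=-1
)

prop_fn = lambda t0, t1: torch.rand_like(t0)  # (N, S, 1) proposal sigmas
def rgb_fn(t0, t1):  # (N, S, 3) rgbs and (N, S, 1) sigmas
    return torch.rand(*t0.shape[:-1], 3, device=t0.device), torch.rand_like(t0)

rgbs, opacities, depths, (weights_per_level, s_vals_per_level) = rendering(
    rgb_sigma_fn=rgb_fn,
    num_samples=48,
    prop_sigma_fns=[prop_fn],
    num_samples_per_prop=[96],
    rays_o=rays_o,
    rays_d=rays_d,
    near_plane=0.2,
    far_plane=100.0,
    sampling_type="lindisp",
    opaque_bkgd=True,
)
prop_loss = compute_prop_loss(s_vals_per_level, weights_per_level)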
def _outer(
t0_starts: torch.Tensor,
t0_ends: torch.Tensor,
t1_starts: torch.Tensor,
t1_ends: torch.Tensor,
y1: torch.Tensor,
) -> torch.Tensor:
"""
Args:
t0_starts: (..., S0).
t0_ends: (..., S0).
t1_starts: (..., S1).
t1_ends: (..., S1).
y1: (..., S1).
"""
cy1 = torch.cat(
[torch.zeros_like(y1[..., :1]), torch.cumsum(y1, dim=-1)], dim=-1
)
idx_lo = (
torch.searchsorted(
t1_starts.contiguous(), t0_starts.contiguous(), side="right"
)
- 1
)
idx_lo = torch.clamp(idx_lo, min=0, max=y1.shape[-1] - 1)
idx_hi = torch.searchsorted(
t1_ends.contiguous(), t0_ends.contiguous(), side="right"
)
idx_hi = torch.clamp(idx_hi, min=0, max=y1.shape[-1] - 1)
cy1_lo = torch.take_along_dim(cy1[..., :-1], idx_lo, dim=-1)
cy1_hi = torch.take_along_dim(cy1[..., 1:], idx_hi, dim=-1)
y0_outer = cy1_hi - cy1_lo
return y0_outer
def _lossfun_outer(
t: torch.Tensor, w: torch.Tensor, t_env: torch.Tensor, w_env: torch.Tensor
):
"""
Args:
t: interval edges, (..., S + 1).
w: weights, (..., S).
t_env: interval edges of the upper-bound envelope histogram, (..., S + 1).
w_env: weights that should upper bound the inner (t,w) histogram, (..., S).
"""
eps = 1e-7 # torch.finfo(t.dtype).eps
w_outer = pdf_outer(t_env, w_env, None, t, None)
# w_outer = _outer(
# t[..., :-1], t[..., 1:], t_env[..., :-1], t_env[..., 1:], w_env
# )
return torch.clip(w - w_outer, min=0) ** 2 / (w + eps)
def compute_prop_loss(
s_vals_per_level: Sequence[torch.Tensor],
weights_per_level: Sequence[torch.Tensor],
) -> torch.Tensor:
c = s_vals_per_level[-1].detach()
w = weights_per_level[-1].detach()
loss = 0.0
for svals, weights in zip(s_vals_per_level[:-1], weights_per_level[:-1]):
loss += torch.mean(_lossfun_outer(c, w, svals, weights))
return loss
def get_proposal_requires_grad_fn(
target: float = 5.0, num_steps: int = 1000
) -> Callable:
schedule = lambda s: min(s / num_steps, 1.0) * target
steps_since_last_grad = 0
def proposal_requires_grad_fn(step: int) -> bool:
nonlocal steps_since_last_grad
target_steps_since_last_grad = schedule(step)
requires_grad = steps_since_last_grad > target_steps_since_last_grad
if requires_grad:
steps_since_last_grad = 0
steps_since_last_grad += 1
return requires_grad
return proposal_requires_grad_fn
def get_proposal_annealing_fn(
slop: float = 10.0, num_steps: int = 1000
) -> Callable:
def proposal_annealing_fn(step: int) -> float:
# https://arxiv.org/pdf/2111.12077.pdf eq. 18
train_frac = max(min(float(step) / num_steps, 1), 0)
bias = lambda x, b: (b * x) / ((b - 1) * x + 1)
anneal = bias(train_frac, slop)
return anneal
return proposal_annealing_fn
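
A sketch of the intended train-loop wiring for these two schedules (the loop bound and the render call are placeholders):

proposal_requires_grad_fn = get_proposal_requires_grad_fn(target=5.0, num_steps=1000)
proposal_annealing_fn = get_proposal_annealing_fn(slop=10.0, num_steps=1000)
max_steps = 20000  # placeholder
for step in range(max_steps):
    requires_grad = proposal_requires_grad_fn(step)
    annealing = proposal_annealing_fn(step)
    # ... rendering(..., proposal_requires_grad=requires_grad,
    #               proposal_annealing=annealing)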
@@ -169,7 +169,7 @@ def ray_marching(
             device=rays_o.device,
         )
         grid_binary = torch.ones(
-            [1, 1, 1], dtype=torch.bool, device=rays_o.device
+            [1, 1, 1, 1], dtype=torch.bool, device=rays_o.device
         )
         contraction_type = ContractionType.AABB.to_cpp_version()
@@ -190,7 +190,9 @@ def ray_marching(
     )
     # skip invisible space
-    if sigma_fn is not None or alpha_fn is not None:
+    if (alpha_thre > 0.0 or early_stop_eps > 0.0) and (
+        sigma_fn is not None or alpha_fn is not None
+    ):
         # Query sigma without gradients
         if sigma_fn is not None:
             sigmas = sigma_fn(t_starts, t_ends, ray_indices)
@@ -204,6 +206,9 @@ def ray_marching(
             alphas.shape == t_starts.shape
         ), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape)
+        if grid is not None:
+            alpha_thre = min(alpha_thre, grid.occs.mean().item())
         # Compute visibility of the samples, and filter out invisible samples
         masks = render_visibility(
             alphas,
...
import math
from typing import Callable, Optional, Tuple, Union, overload
import torch
import nerfacc.cuda as _C
from .cdf import ray_resampling
from .grid import Grid
from .pack import pack_info, unpack_info
from .vol_rendering import (
render_transmittance_from_alpha,
render_weight_from_density,
)
@overload
def sample_along_rays(
rays_o: torch.Tensor, # [n_rays, 3]
rays_d: torch.Tensor, # [n_rays, 3]
t_min: torch.Tensor, # [n_rays,]
t_max: torch.Tensor, # [n_rays,]
step_size: float,
cone_angle: float = 0.0,
grid: Optional[Grid] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Sample along rays with per-ray min max."""
...
@overload
def sample_along_rays(
rays_o: torch.Tensor, # [n_rays, 3]
rays_d: torch.Tensor, # [n_rays, 3]
t_min: float,
t_max: float,
step_size: float,
cone_angle: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Sample along rays with near far plane."""
...
@torch.no_grad()
def sample_along_rays(
rays_o: torch.Tensor, # [n_rays, 3]
rays_d: torch.Tensor, # [n_rays, 3]
t_min: Union[float, torch.Tensor], # [n_rays,]
t_max: Union[float, torch.Tensor], # [n_rays,]
step_size: float,
cone_angle: float = 0.0,
grid: Optional[Grid] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Sample intervals along rays."""
if isinstance(t_min, float) and isinstance(t_max, float):
n_rays = rays_o.shape[0]
device = rays_o.device
num_steps = math.floor((t_max - t_min) / step_size)
t_starts = (
(t_min + torch.arange(0, num_steps, device=device) * step_size)
.expand(n_rays, -1)
.reshape(-1, 1)
)
t_ends = t_starts + step_size
ray_indices = torch.arange(0, n_rays, device=device).repeat_interleave(
num_steps, dim=0
)
else:
if grid is None:
packed_info, ray_indices, t_starts, t_ends = _C.ray_marching(
# rays
t_min.contiguous(),
t_max.contiguous(),
# sampling
step_size,
cone_angle,
)
else:
(
packed_info,
ray_indices,
t_starts,
t_ends,
) = _C.ray_marching_with_grid(
# rays
rays_o.contiguous(),
rays_d.contiguous(),
t_min.contiguous(),
t_max.contiguous(),
# contraction and grid
grid.roi_aabb.contiguous(),
grid.binary.contiguous(),
grid.contraction_type.to_cpp_version(),
# sampling
step_size,
cone_angle,
)
return ray_indices, t_starts, t_ends
@torch.no_grad()
def proposal_sampling_with_filter(
t_starts: torch.Tensor, # [n_samples, 1]
t_ends: torch.Tensor, # [n_samples, 1]
ray_indices: torch.Tensor, # [n_samples,]
n_rays: Optional[int] = None,
# compute density of samples: {t_starts, t_ends, ray_indices} -> density
sigma_fn: Optional[Callable] = None,
# proposal density fns: {t_starts, t_ends, ray_indices} -> density
proposal_sigma_fns: Tuple[Callable, ...] = (),
proposal_n_samples: Tuple[int, ...] = (),
proposal_require_grads: bool = False,
# acceleration options
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Hueristic marching with proposal fns."""
assert len(proposal_sigma_fns) == len(proposal_n_samples), (
"proposal_sigma_fns and proposal_n_samples must have the same length, "
f"but got {len(proposal_sigma_fns)} and {len(proposal_n_samples)}."
)
if n_rays is None:
n_rays = ray_indices.max() + 1
# compute density from proposal fns
proposal_samples = []
for proposal_fn, n_samples in zip(proposal_sigma_fns, proposal_n_samples):
# compute weights for resampling
sigmas = proposal_fn(t_starts, t_ends, ray_indices)
assert (
sigmas.shape == t_starts.shape
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
transmittance = render_transmittance_from_alpha(
alphas, ray_indices=ray_indices, n_rays=n_rays
)
weights = alphas * transmittance
# Compute visibility for filtering
if alpha_thre > 0 or early_stop_eps > 0:
vis = (alphas >= alpha_thre) & (transmittance >= early_stop_eps)
vis = vis.squeeze(-1)
ray_indices, t_starts, t_ends, weights = (
ray_indices[vis],
t_starts[vis],
t_ends[vis],
weights[vis],
)
packed_info = pack_info(ray_indices, n_rays=n_rays)
# Rerun the proposal function **with** gradients on filtered samples.
if proposal_require_grads:
with torch.enable_grad():
sigmas = proposal_fn(t_starts, t_ends, ray_indices)
weights = render_weight_from_density(
t_starts, t_ends, sigmas, ray_indices=ray_indices
)
proposal_samples.append(
(packed_info, t_starts, t_ends, weights)
)
# resampling on filtered samples
packed_info, t_starts, t_ends = ray_resampling(
packed_info, t_starts, t_ends, weights, n_samples=n_samples
)
ray_indices = unpack_info(packed_info, t_starts.shape[0])
# last round filtering with sigma_fn
if (alpha_thre > 0 or early_stop_eps > 0) and (sigma_fn is not None):
sigmas = sigma_fn(t_starts, t_ends, ray_indices)
assert (
sigmas.shape == t_starts.shape
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
transmittance = render_transmittance_from_alpha(
alphas, ray_indices=ray_indices, n_rays=n_rays
)
vis = (alphas >= alpha_thre) & (transmittance >= early_stop_eps)
vis = vis.squeeze(-1)
ray_indices, t_starts, t_ends = (
ray_indices[vis],
t_starts[vis],
t_ends[vis],
)
return ray_indices, t_starts, t_ends, proposal_samples
@@ -2,4 +2,4 @@
 Copyright (c) 2022 Ruilong Li, UC Berkeley.
 """
-__version__ = "0.3.5"
+__version__ = "0.4.0"
@@ -18,7 +18,7 @@ def test_occ_grid():
     occ_grid = OccupancyGrid(roi_aabb=roi_aabb, resolution=128).to(device)
     occ_grid.every_n_step(0, occ_eval_fn, occ_thre=0.1)
     assert occ_grid.roi_aabb.shape == (6,)
-    assert occ_grid.binary.shape == (128, 128, 128)
+    assert occ_grid.binary.shape == (1, 128, 128, 128)
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
...
import pytest
import torch
from nerfacc import pack_info, ray_marching
from nerfacc.losses import distortion
device = "cuda:0"
batch_size = 32
eps = 1e-6
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_distortion():
rays_o = torch.rand((batch_size, 3), device=device)
rays_d = torch.randn((batch_size, 3), device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
ray_indices, t_starts, t_ends = ray_marching(
rays_o,
rays_d,
near_plane=0.1,
far_plane=1.0,
render_step_size=1e-3,
)
packed_info = pack_info(ray_indices, n_rays=batch_size)
weights = torch.rand((t_starts.shape[0],), device=device)
loss = distortion(packed_info, weights, t_starts, t_ends)
assert loss.shape == (batch_size,)
if __name__ == "__main__":
test_distortion()
import pytest
import torch
from nerfacc import pack_info, ray_marching, ray_resampling
device = "cuda:0"
batch_size = 128
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_resampling():
rays_o = torch.rand((batch_size, 3), device=device)
rays_d = torch.randn((batch_size, 3), device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
ray_indices, t_starts, t_ends = ray_marching(
rays_o,
rays_d,
near_plane=0.1,
far_plane=1.0,
render_step_size=1e-3,
)
packed_info = pack_info(ray_indices, n_rays=batch_size)
weights = torch.rand((t_starts.shape[0],), device=device)
packed_info, t_starts, t_ends = ray_resampling(
packed_info, t_starts, t_ends, weights, n_samples=32
)
assert t_starts.shape == t_ends.shape == (batch_size * 32, 1)
if __name__ == "__main__":
test_resampling()