Unverified Commit 8340e19d authored by Ruilong Li (李瑞龙), committed by GitHub

0.5.0: Rewrite all the underlying CUDA. Speedup and Benchmarking. (#182)

* importance_sampling with test

* package importance_sampling

* compute_intervals tested and packaged

* compute_intervals_v2

* bicycle is failing

* fix cut in compute_intervals_v2, test pass for rendering

* hacky way to get opaque_bkgd work

* reorg ING

* PackedRaySegmentsSpec

* chunk_ids -> ray_ids

* binary -> occupied

* test_traverse_grid_basic checked

* fix traverse_grid with step size, checked

* support max_step_size, not verified

* _cuda and cuda; upgrade ray_marching

* inclusive scan

* test_exclusive_sum but seems to have numeric error

* inclusive_sum_backward verified

* exclusive sum backward

* merge fwd and bwd for scan

* inclusive & exclusive prod verified

* support normal scan with torch funcs

* rendering and tests

* a bit clean up

* importance_sampling verified

* stratified for importance_sampling

* importance_sampling in pdf.py

* RaySegmentsSpec in data_specs; fix various bugs

* verified with _proposal_packed.py

* importance sampling support batch input/output. need to verify

* prop script with batch samples

* try to use cumsum instead of cumprod

* searchsorted

* benchmarking prop

* ray_aabb_intersect untested

* update prop benchmark numbers

* minor fixes

* batched ray_aabb_intersect

* ray_aabb_intersect and traverse with grid(s)

* tiny optimize for traverse_grids kernels

* traverse_grids return intervals and samples

* cub not verified

* cleanup

* propnet and occgrid as estimators

* training print iters 10k

* prop is good now

* benchmark in google sheet.

* really cleanup: scan.py and test

* pack.py and test

* rendering and test

* data_specs.py and pdf.py docs

* data_specs.py and pdf.py docs

* init and headers

* grid.py and test for it

* occ grid docs

* generated docs

* example docs for pack and scan function.

* doc fix for volrend.py

* doc fix for pdf.py

* fix doc for rendering function

* docs

* propnet docs

* update scripts

* docs: index.rst

* methodology docs

* docs for examples

* mlp nerf script

* update t-nerf script

* rename dnerf to tnerf

* misc update

* bug fix: pdf_loss with test

* minor fix

* update readme with submodules

* fix format

* update gitignore file

* fix doc failure. teaser png to jpg

* docs in examples/
parent e547490c
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*
* Modified from
* https://github.com/pytorch/pytorch/blob/06a64f7eaa47ce430a3fa61016010075b59b18a7/aten/src/ATen/native/cuda/ScanUtils.cuh
*/
#include "utils_cuda.cuh"
// CUB support for scan-by-key was added in CUB 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif
// https://github.com/pytorch/pytorch/blob/233305a852e1cd7f319b15b5137074c9eac455f6/aten/src/ATen/cuda/cub.cuh#L38-L46
#define CUB_WRAPPER(func, ...) do { \
size_t temp_storage_bytes = 0; \
func(nullptr, temp_storage_bytes, __VA_ARGS__); \
auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \
auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \
func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \
AT_CUDA_CHECK(cudaGetLastError()); \
} while (false)
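// A minimal usage sketch (illustrative only; the iterator variables are assumed to be set up
// elsewhere): the wrapper runs the CUB routine twice, first with a null workspace to query
// `temp_storage_bytes`, then again with a buffer from the PyTorch caching allocator, e.g.
//
//   CUB_WRAPPER(cub::DeviceScan::InclusiveSumByKey,
//               keys_in, values_in, values_out, num_items,
//               cub::Equality(), at::cuda::getCurrentCUDAStream());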
namespace {
namespace device {
/* Perform an inclusive scan for a flattened tensor.
*
* - num_rows is the size of the outer dimensions;
* - {chunk_starts, chunk_cnts} defines the regions of the flattened tensor to be scanned.
*
* Each thread block processes one or more sets of contiguous rows (processing multiple rows
* per thread block is quicker than processing a single row, especially for short rows).
*/
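// Illustrative layout (assuming chunk_starts is the exclusive prefix sum of chunk_cnts):
// with chunk_cnts = {3, 0, 2} the flattened tensor holds 5 values and chunk_starts = {0, 3, 3},
// so row 0 scans elements [0, 3), row 1 is empty and is skipped, and row 2 scans elements [3, 5).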
template<
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__device__ void inclusive_scan_impl(
T* row_buf, DataIteratorT tgt_, DataIteratorT src_,
const uint32_t num_rows,
// const uint32_t row_size,
IdxIteratorT chunk_starts, IdxIteratorT chunk_cnts,
T init, BinaryFunction binary_op,
bool normalize = false){
for (uint32_t block_row = blockIdx.x * blockDim.y;
block_row < num_rows;
block_row += blockDim.y * gridDim.x) {
uint32_t row = block_row + threadIdx.y;
T block_total = init;
if (row >= num_rows) continue;
DataIteratorT row_src = src_ + chunk_starts[row];
DataIteratorT row_tgt = tgt_ + chunk_starts[row];
uint32_t row_size = chunk_cnts[row];
if (row_size == 0) continue;
// Perform scan on one block at a time, keeping track of the total value of
// all blocks processed so far.
for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
// Load data into shared memory (two values per thread).
uint32_t col1 = block_col + threadIdx.x;
uint32_t col2 = block_col + num_threads_x + threadIdx.x;
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = row_src[col1];
} else {
row_buf[threadIdx.x] = init;
}
if (col2 < row_size) {
row_buf[num_threads_x + threadIdx.x] = row_src[col2];
} else {
row_buf[num_threads_x + threadIdx.x] = init;
}
// Add the total value of all previous blocks to the first value of this block.
if (threadIdx.x == 0) {
row_buf[0] = binary_op(row_buf[0], block_total);
}
}
__syncthreads();
// Parallel reduction (up-sweep).
for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < num_rows && threadIdx.x < s) {
uint32_t offset = (2 * threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Down-sweep.
for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < num_rows && threadIdx.x < s - 1) {
uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Write back to output.
if (row < num_rows) {
if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];
if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];
}
block_total = row_buf[2 * num_threads_x - 1];
__syncthreads();
}
// Normalize with the last value: should only be used by scan_sum
if (normalize) {
for (uint32_t block_col = 0; block_col < row_size; block_col += num_threads_x)
{
uint32_t col = block_col + threadIdx.x;
if (col < row_size) {
row_tgt[col] /= fmaxf(block_total, 1e-10f);
}
}
}
}
}
template <
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__global__ void
inclusive_scan_kernel(
DataIteratorT tgt_,
DataIteratorT src_,
const uint32_t num_rows,
IdxIteratorT chunk_starts,
IdxIteratorT chunk_cnts,
T init,
BinaryFunction binary_op,
bool normalize = false) {
__shared__ T sbuf[num_threads_y][2 * num_threads_x];
T* row_buf = sbuf[threadIdx.y];
inclusive_scan_impl<T, num_threads_x, num_threads_y>(
row_buf, tgt_, src_, num_rows, chunk_starts, chunk_cnts, init, binary_op, normalize);
}
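// A minimal host-side launch sketch (illustrative; the block shape and the add functor are
// assumptions, not necessarily what the library uses):
//
//   struct AddOp { __device__ float operator()(float a, float b) const { return a + b; } };
//   dim3 block(16, 32);  // num_threads_x = 16, num_threads_y = 32
//   dim3 grid((num_rows + block.y - 1) / block.y);
//   inclusive_scan_kernel<float, 16, 32><<<grid, block, 0, stream>>>(
//       out_ptr, in_ptr, num_rows, chunk_starts_ptr, chunk_cnts_ptr,
//       0.f, AddOp(), /*normalize=*/false);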
/* Perform an exclusive scan for a flattened tensor.
*
* - num_rows is the size of the outer dimensions;
* - {chunk_starts, chunk_cnts} defines the regions of the flattened tensor to be scanned.
*
* Each thread block processes one or more sets of contiguous rows (processing multiple rows
* per thread block is quicker than processing a single row, especially for short rows).
*/
template<
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__device__ void exclusive_scan_impl(
T* row_buf, DataIteratorT tgt_, DataIteratorT src_,
const uint32_t num_rows,
// const uint32_t row_size,
IdxIteratorT chunk_starts, IdxIteratorT chunk_cnts,
T init, BinaryFunction binary_op,
bool normalize = false){
for (uint32_t block_row = blockIdx.x * blockDim.y;
block_row < num_rows;
block_row += blockDim.y * gridDim.x) {
uint32_t row = block_row + threadIdx.y;
T block_total = init;
if (row >= num_rows) continue;
DataIteratorT row_src = src_ + chunk_starts[row];
DataIteratorT row_tgt = tgt_ + chunk_starts[row];
uint32_t row_size = chunk_cnts[row];
if (row_size == 0) continue;
row_tgt[0] = init;
// Perform scan on one block at a time, keeping track of the total value of
// all blocks processed so far.
for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
// Load data into shared memory (two values per thread).
uint32_t col1 = block_col + threadIdx.x;
uint32_t col2 = block_col + num_threads_x + threadIdx.x;
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = row_src[col1];
} else {
row_buf[threadIdx.x] = init;
}
if (col2 < row_size) {
row_buf[num_threads_x + threadIdx.x] = row_src[col2];
} else {
row_buf[num_threads_x + threadIdx.x] = init;
}
// Add the total value of all previous blocks to the first value of this block.
if (threadIdx.x == 0) {
row_buf[0] = binary_op(row_buf[0], block_total);
}
}
__syncthreads();
// Parallel reduction (up-sweep).
for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < num_rows && threadIdx.x < s) {
uint32_t offset = (2 * threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Down-sweep.
for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < num_rows && threadIdx.x < s - 1) {
uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Write back to output.
if (row < num_rows) {
if (col1 < row_size - 1) row_tgt[col1 + 1] = row_buf[threadIdx.x];
if (col2 < row_size - 1) row_tgt[col2 + 1] = row_buf[num_threads_x + threadIdx.x];
}
block_total = row_buf[2 * num_threads_x - 1];
__syncthreads();
}
// Normalize with the last value: should only be used by scan_sum
if (normalize) {
for (uint32_t block_col = 0; block_col < row_size; block_col += num_threads_x)
{
uint32_t col = block_col + threadIdx.x;
if (col < row_size - 1) {
row_tgt[col + 1] /= fmaxf(block_total, 1e-10f);
}
}
}
}
}
template <
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__global__ void
exclusive_scan_kernel(
DataIteratorT tgt_,
DataIteratorT src_,
const uint32_t num_rows,
IdxIteratorT chunk_starts,
IdxIteratorT chunk_cnts,
T init,
BinaryFunction binary_op,
bool normalize = false) {
__shared__ T sbuf[num_threads_y][2 * num_threads_x];
T* row_buf = sbuf[threadIdx.y];
exclusive_scan_impl<T, num_threads_x, num_threads_y>(
row_buf, tgt_, src_, num_rows, chunk_starts, chunk_cnts, init, binary_op, normalize);
}
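// Relative to the inclusive variant, the exclusive scan writes `init` into the first output
// slot (row_tgt[0] = init) and shifts each partial result one slot to the right, so out[k]
// aggregates in[0..k-1]. Illustrative: a sum scan of {1, 2, 3} with init = 0 yields {1, 3, 6}
// inclusively and {0, 1, 3} exclusively.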
} // namespace device
} // namespace
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
template <typename scalar_t>
inline __host__ __device__ void _swap(scalar_t &a, scalar_t &b)
{
scalar_t c = a;
a = b;
b = c;
}
template <typename scalar_t>
inline __host__ __device__ void _ray_aabb_intersect(
const scalar_t *rays_o,
const scalar_t *rays_d,
const scalar_t *aabb,
scalar_t *near,
scalar_t *far)
{
// aabb is [xmin, ymin, zmin, xmax, ymax, zmax]
scalar_t tmin = (aabb[0] - rays_o[0]) / rays_d[0];
scalar_t tmax = (aabb[3] - rays_o[0]) / rays_d[0];
if (tmin > tmax)
_swap(tmin, tmax);
scalar_t tymin = (aabb[1] - rays_o[1]) / rays_d[1];
scalar_t tymax = (aabb[4] - rays_o[1]) / rays_d[1];
if (tymin > tymax)
_swap(tymin, tymax);
if (tmin > tymax || tymin > tmax)
{
*near = 1e10;
*far = 1e10;
return;
}
if (tymin > tmin)
tmin = tymin;
if (tymax < tmax)
tmax = tymax;
scalar_t tzmin = (aabb[2] - rays_o[2]) / rays_d[2];
scalar_t tzmax = (aabb[5] - rays_o[2]) / rays_d[2];
if (tzmin > tzmax)
_swap(tzmin, tzmax);
if (tmin > tzmax || tzmin > tmax)
{
*near = 1e10;
*far = 1e10;
return;
}
if (tzmin > tmin)
tmin = tzmin;
if (tzmax < tmax)
tmax = tzmax;
*near = tmin;
*far = tmax;
return;
}
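// Worked example (illustrative): against the unit cube aabb = {0, 0, 0, 1, 1, 1}, a ray with
// origin (0.5, 0.5, -1) and direction (0, 0, 1) yields *near = 1 and *far = 2; a ray whose
// slab intervals do not overlap leaves both outputs at 1e10.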
template <typename scalar_t>
__global__ void ray_aabb_intersect_kernel(
const int N,
const scalar_t *rays_o,
const scalar_t *rays_d,
const scalar_t *aabb,
scalar_t *t_min,
scalar_t *t_max)
{
// aabb is [xmin, ymin, zmin, xmax, ymax, zmax]
CUDA_GET_THREAD_ID(thread_id, N);
// locate
rays_o += thread_id * 3;
rays_d += thread_id * 3;
t_min += thread_id;
t_max += thread_id;
_ray_aabb_intersect<scalar_t>(rays_o, rays_d, aabb, t_min, t_max);
scalar_t zero = static_cast<scalar_t>(0.f);
*t_min = *t_min > zero ? *t_min : zero;
return;
}
/**
* @brief Ray AABB Test
*
* @param rays_o Ray origins. Tensor with shape [N, 3].
* @param rays_d Normalized ray directions. Tensor with shape [N, 3].
* @param aabb Scene AABB [xmin, ymin, zmin, xmax, ymax, zmax]. Tensor with shape [6].
* @return std::vector<torch::Tensor>
* Ray AABB intersection {t_min, t_max} with shape [N] respectively. Note the t_min is
* clipped to minimum zero. 1e10 is returned if no intersection.
*/
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o, const torch::Tensor rays_d, const torch::Tensor aabb)
{
DEVICE_GUARD(rays_o);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(aabb);
TORCH_CHECK(rays_o.ndimension() == 2 && rays_o.size(1) == 3);
TORCH_CHECK(rays_d.ndimension() == 2 && rays_d.size(1) == 3);
TORCH_CHECK(aabb.ndimension() == 1 && aabb.size(0) == 6);
const int N = rays_o.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(N, threads);
torch::Tensor t_min = torch::empty({N}, rays_o.options());
torch::Tensor t_max = torch::empty({N}, rays_o.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rays_o.scalar_type(), "ray_aabb_intersect",
([&]
{ ray_aabb_intersect_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
N,
rays_o.data_ptr<scalar_t>(),
rays_d.data_ptr<scalar_t>(),
aabb.data_ptr<scalar_t>(),
t_min.data_ptr<scalar_t>(),
t_max.data_ptr<scalar_t>()); }));
return {t_min, t_max};
}
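// Host-side usage sketch (illustrative; the tensors are assumed to be contiguous CUDA tensors):
//
//   auto outs = ray_aabb_intersect(rays_o, rays_d, aabb);
//   torch::Tensor t_min = outs[0], t_max = outs[1];
//   torch::Tensor hit = t_max < 1e10f;  // 1e10 marks rays that miss the AABB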
// This file contains only Python bindings
#include "include/data_spec.hpp"
#include <torch/extension.h>
bool is_cub_available() {
// FIXME: why return false?
return (bool) CUB_SUPPORTS_SCAN_BY_KEY();
}
// scan
torch::Tensor exclusive_sum_by_key(
torch::Tensor indices,
torch::Tensor inputs,
bool backward);
torch::Tensor inclusive_sum(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
bool normalize,
bool backward);
torch::Tensor exclusive_sum(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
bool normalize,
bool backward);
torch::Tensor inclusive_prod_forward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs);
torch::Tensor inclusive_prod_backward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
torch::Tensor outputs,
torch::Tensor grad_outputs);
torch::Tensor exclusive_prod_forward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs);
torch::Tensor exclusive_prod_backward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
torch::Tensor outputs,
torch::Tensor grad_outputs);
// grid
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
const torch::Tensor aabbs, // [n_aabbs, 6]
const float near_plane,
const float far_plane,
const float miss_value);
std::vector<RaySegmentsSpec> traverse_grids(
// rays
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
// grids
const torch::Tensor binaries, // [n_grids, resx, resy, resz]
const torch::Tensor aabbs, // [n_grids, 6]
// intersections
const torch::Tensor t_mins, // [n_rays, n_grids]
const torch::Tensor t_maxs, // [n_rays, n_grids]
const torch::Tensor hits, // [n_rays, n_grids]
// options
const torch::Tensor near_planes,
const torch::Tensor far_planes,
const float step_size,
const float cone_angle,
const bool compute_intervals,
const bool compute_samples);
// pdf
std::vector<RaySegmentsSpec> importance_sampling(
RaySegmentsSpec ray_segments,
torch::Tensor cdfs,
torch::Tensor n_intervels_per_ray,
bool stratified);
std::vector<RaySegmentsSpec> importance_sampling(
RaySegmentsSpec ray_segments,
torch::Tensor cdfs,
int64_t n_intervels_per_ray,
bool stratified);
std::vector<torch::Tensor> searchsorted(
RaySegmentsSpec query,
RaySegmentsSpec key);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#define _REG_FUNC(funname) m.def(#funname, &funname)
_REG_FUNC(is_cub_available); // TODO: check this function
_REG_FUNC(exclusive_sum_by_key);
_REG_FUNC(inclusive_sum);
_REG_FUNC(exclusive_sum);
_REG_FUNC(inclusive_prod_forward);
_REG_FUNC(inclusive_prod_backward);
_REG_FUNC(exclusive_prod_forward);
_REG_FUNC(exclusive_prod_backward);
_REG_FUNC(ray_aabb_intersect);
_REG_FUNC(traverse_grids);
_REG_FUNC(searchsorted);
#undef _REG_FUNC
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, torch::Tensor, bool>(&importance_sampling));
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, int64_t, bool>(&importance_sampling));
py::class_<MultiScaleGridSpec>(m, "MultiScaleGridSpec")
.def(py::init<>())
.def_readwrite("data", &MultiScaleGridSpec::data)
.def_readwrite("occupied", &MultiScaleGridSpec::occupied)
.def_readwrite("base_aabb", &MultiScaleGridSpec::base_aabb);
py::class_<RaysSpec>(m, "RaysSpec")
.def(py::init<>())
.def_readwrite("origins", &RaysSpec::origins)
.def_readwrite("dirs", &RaysSpec::dirs);
py::class_<RaySegmentsSpec>(m, "RaySegmentsSpec")
.def(py::init<>())
.def_readwrite("vals", &RaySegmentsSpec::vals)
.def_readwrite("is_left", &RaySegmentsSpec::is_left)
.def_readwrite("is_right", &RaySegmentsSpec::is_right)
.def_readwrite("chunk_starts", &RaySegmentsSpec::chunk_starts)
.def_readwrite("chunk_cnts", &RaySegmentsSpec::chunk_cnts)
.def_readwrite("ray_indices", &RaySegmentsSpec::ray_indices);
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
__global__ void unpack_info_kernel(
// input
const int n_rays,
const int *packed_info,
// output
int64_t *ray_indices)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
ray_indices += base;
for (int j = 0; j < steps; ++j)
{
ray_indices[j] = i;
}
}
__global__ void unpack_info_to_mask_kernel(
// input
const int n_rays,
const int *packed_info,
const int n_samples,
// output
bool *masks) // [n_rays, n_samples]
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
masks += i * n_samples;
for (int j = 0; j < steps; ++j)
{
masks[j] = true;
}
}
template <typename scalar_t>
__global__ void unpack_data_kernel(
const uint32_t n_rays,
const int *packed_info, // input ray & point indices.
const int data_dim,
const scalar_t *data,
const int n_sampler_per_ray,
scalar_t *unpacked_data) // (n_rays, n_sampler_per_ray, data_dim)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
data += base * data_dim;
unpacked_data += i * n_sampler_per_ray * data_dim;
for (int j = 0; j < steps; j++)
{
for (int k = 0; k < data_dim; k++)
{
unpacked_data[j * data_dim + k] = data[j * data_dim + k];
}
}
return;
}
torch::Tensor unpack_info(const torch::Tensor packed_info, const int n_samples)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
const int n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
torch::Tensor ray_indices = torch::empty(
{n_samples}, packed_info.options().dtype(torch::kLong));
unpack_info_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
ray_indices.data_ptr<int64_t>());
return ray_indices;
}
torch::Tensor unpack_info_to_mask(
const torch::Tensor packed_info, const int n_samples)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
const int n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor masks = torch::zeros(
{n_rays, n_samples}, packed_info.options().dtype(torch::kBool));
unpack_info_to_mask_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
n_samples,
masks.data_ptr<bool>());
return masks;
}
torch::Tensor unpack_data(
torch::Tensor packed_info,
torch::Tensor data,
int n_samples_per_ray,
float pad_value)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(data);
TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
TORCH_CHECK(data.ndimension() == 2);
const int n_rays = packed_info.size(0);
const int n_samples = data.size(0);
const int data_dim = data.size(1);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor unpacked_data = torch::full(
{n_rays, n_samples_per_ray, data_dim}, pad_value, data.options());
AT_DISPATCH_ALL_TYPES(
data.scalar_type(),
"unpack_data",
([&]
{ unpack_data_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
data_dim,
data.data_ptr<scalar_t>(),
n_samples_per_ray,
// outputs
unpacked_data.data_ptr<scalar_t>()); }));
return unpacked_data;
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include <torch/extension.h>
#include <ATen/NumericUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
@@ -11,15 +9,16 @@
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include "include/helpers_cuda.h"
#include "include/data_spec.hpp"
#include "include/data_spec_packed.cuh"
#include "include/utils_cuda.cuh"
#include "include/utils_grid.cuh"
#include "include/utils_math.cuh"
namespace F = torch::nn::functional;
static constexpr uint32_t MAX_GRID_LEVELS = 8;
template <typename scalar_t>
inline __device__ __host__ scalar_t ceil_div(scalar_t a, scalar_t b)
{
return (a + b - 1) / b;
}
namespace {
namespace device {
// Taken from:
// https://github.com/pytorch/pytorch/blob/8f1c3c68d3aba5c8898bfb3144988aab6776d549/aten/src/ATen/native/cuda/Bucketization.cu
@@ -63,144 +62,251 @@ __device__ int64_t upper_bound(const scalar_t *data_ss, int64_t start, int64_t e
return start;
}
template <typename scalar_t>
__global__ void pdf_sampling_kernel(
at::PhiloxCudaState philox_args,
const int64_t n_samples_in, // n_samples_in or whatever (not used)
const int64_t *info_ts, // nullptr or [n_rays, 2]
const scalar_t *ts, // [n_rays, n_samples_in] or packed [all_samples_in]
const scalar_t *accum_weights, // [n_rays, n_samples_in] or packed [all_samples_in]
const bool *masks, // [n_rays]
const bool stratified,
const bool single_jitter,
// outputs
const int64_t numel,
const int64_t n_samples_out,
scalar_t *ts_out) // [n_rays, n_samples_out]
inline __device__ int32_t binary_search_chunk_id(
const int64_t item_id,
const int32_t n_chunks,
const int64_t *chunk_starts)
{
int64_t n_bins_out = n_samples_out - 1;
scalar_t u_pad, u_interval;
if (stratified) {
u_interval = 1.0f / n_bins_out;
u_pad = 0.0f;
} else {
u_interval = 1.0f / n_bins_out;
u_pad = 0.0f;
int32_t start = 0;
int32_t end = n_chunks;
while (start < end)
{
const int32_t mid = start + ((end - start) >> 1);
const int64_t mid_val = chunk_starts[mid];
if (!(mid_val > item_id)) start = mid + 1;
else end = mid;
}
// = stratified ? 1.0f / n_samples_out : (1.0f - 2 * pad) / (n_samples_out - 1);
return start;
}
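// Illustrative: with chunk_starts = {0, 3, 5} and item_id = 4 the search returns 2 (the first
// entry greater than 4); callers subtract 1, so flattened item 4 is mapped to chunk/ray 1.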
/* kernels for importance_sampling */
__global__ void compute_ray_ids_kernel(
const int64_t n_rays,
const int64_t n_items,
const int64_t *chunk_starts,
// outputs
int64_t *ray_indices)
{
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < numel; tid += blockDim.x * gridDim.x)
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n_items; tid += blockDim.x * gridDim.x)
{
int64_t ray_id = tid / n_samples_out;
int64_t sample_id = tid % n_samples_out;
if (masks != nullptr && !masks[ray_id]) {
// This ray is to be skipped.
// Be careful the ts needs to be initialized properly.
continue;
ray_indices[tid] = binary_search_chunk_id(tid, n_rays, chunk_starts) - 1;
}
}
int64_t start_bd, end_bd;
if (info_ts == nullptr)
{
// no packing, the input is already [n_rays, n_samples_in]
start_bd = ray_id * n_samples_in;
end_bd = start_bd + n_samples_in;
}
else
__global__ void importance_sampling_kernel(
// cdfs
PackedRaySegmentsSpec ray_segments,
const float *cdfs,
// jittering
bool stratified,
at::PhiloxCudaState philox_args,
// outputs
PackedRaySegmentsSpec samples)
{
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < samples.n_edges; tid += blockDim.x * gridDim.x)
{
// packed input, the input is [all_samples_in]
start_bd = info_ts[ray_id * 2];
end_bd = start_bd + info_ts[ray_id * 2 + 1];
if (start_bd == end_bd) {
// This ray is empty, so there is nothing to sample from.
// Be careful the ts needs to be initialized properly.
continue;
int32_t ray_id;
int64_t n_samples, sid;
if (samples.is_batched) {
ray_id = tid / samples.n_edges_per_ray;
n_samples = samples.n_edges_per_ray;
sid = tid - ray_id * samples.n_edges_per_ray;
} else {
ray_id = binary_search_chunk_id(tid, samples.n_rays, samples.chunk_starts) - 1;
samples.ray_indices[tid] = ray_id;
n_samples = samples.chunk_cnts[ray_id];
sid = tid - samples.chunk_starts[ray_id];
}
int64_t base, last;
if (ray_segments.is_batched) {
base = ray_id * ray_segments.n_edges_per_ray;
last = base + ray_segments.n_edges_per_ray - 1;
} else {
base = ray_segments.chunk_starts[ray_id];
last = base + ray_segments.chunk_cnts[ray_id] - 1;
}
scalar_t u = u_pad + sample_id * u_interval;
float u_floor = cdfs[base];
float u_ceil = cdfs[last];
if (stratified)
{
float u_step = (u_ceil - u_floor) / n_samples;
float bias = 0.5f;
if (stratified) {
auto seeds = at::cuda::philox::unpack(philox_args);
curandStatePhilox4_32_10_t state;
int64_t rand_seq_id = single_jitter ? ray_id : tid;
curand_init(std::get<0>(seeds), rand_seq_id, std::get<1>(seeds), &state);
float rand = curand_uniform(&state);
u -= rand * u_interval;
u = max(u, static_cast<scalar_t>(0.0f));
curand_init(std::get<0>(seeds), ray_id, std::get<1>(seeds), &state);
bias = curand_uniform(&state);
}
float u = u_floor + (sid + bias) * u_step;
// searchsorted with "right" option:
// i.e. accum_weights[pos - 1] <= u < accum_weights[pos]
int64_t pos = upper_bound<scalar_t>(accum_weights, start_bd, end_bd, u, nullptr);
// i.e. cdfs[p - 1] <= u < cdfs[p]
int64_t p = upper_bound<float>(cdfs, base, last, u, nullptr);
int64_t p0 = max(min(p - 1, last), base);
int64_t p1 = max(min(p, last), base);
float u_lower = cdfs[p0];
float u_upper = cdfs[p1];
float t_lower = ray_segments.vals[p0];
float t_upper = ray_segments.vals[p1];
float t;
if (u_upper - u_lower < 1e-10f) {
t = (t_lower + t_upper) * 0.5f;
} else {
float scaling = (t_upper - t_lower) / (u_upper - u_lower);
t = (u - u_lower) * scaling + t_lower;
}
samples.vals[tid] = t;
}
}
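// Worked example (illustrative): for a ray whose segment edges are vals = {0, 1, 2, 3} with
// cdfs = {0.0, 0.2, 0.7, 1.0}, a target CDF value u = 0.45 falls between 0.2 and 0.7, so the
// drawn sample is t = 1 + (0.45 - 0.2) / (0.7 - 0.2) * (2 - 1) = 1.5.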
int64_t p0 = min(max(pos - 1, start_bd), end_bd - 1);
int64_t p1 = min(max(pos, start_bd), end_bd - 1);
__global__ void compute_intervels_kernel(
PackedRaySegmentsSpec ray_segments,
PackedRaySegmentsSpec samples,
// outputs
PackedRaySegmentsSpec intervals)
{
// parallelize over samples
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < samples.n_edges; tid += blockDim.x * gridDim.x)
{
int32_t ray_id;
int64_t n_samples, sid;
if (samples.is_batched) {
ray_id = tid / samples.n_edges_per_ray;
n_samples = samples.n_edges_per_ray;
sid = tid - ray_id * samples.n_edges_per_ray;
} else {
ray_id = samples.ray_indices[tid];
n_samples = samples.chunk_cnts[ray_id];
sid = tid - samples.chunk_starts[ray_id];
}
int64_t base, last;
if (ray_segments.is_batched) {
base = ray_id * ray_segments.n_edges_per_ray;
last = base + ray_segments.n_edges_per_ray - 1;
} else {
base = ray_segments.chunk_starts[ray_id];
last = base + ray_segments.chunk_cnts[ray_id] - 1;
}
scalar_t start_u = accum_weights[p0];
scalar_t end_u = accum_weights[p1];
scalar_t start_t = ts[p0];
scalar_t end_t = ts[p1];
int64_t base_out;
if (intervals.is_batched) {
base_out = ray_id * intervals.n_edges_per_ray;
} else {
base_out = intervals.chunk_starts[ray_id];
}
if (p0 == p1) {
if (p0 == end_bd - 1)
ts_out[tid] = end_t;
else
ts_out[tid] = start_t;
} else if (end_u - start_u < 1e-20f) {
ts_out[tid] = (start_t + end_t) * 0.5f;
float t_min = ray_segments.vals[base];
float t_max = ray_segments.vals[last];
if (sid == 0) {
float t = samples.vals[tid];
float t_next = samples.vals[tid + 1]; // FIXME: out of bounds?
float half_width = (t_next - t) * 0.5f;
intervals.vals[base_out] = fmaxf(t - half_width, t_min);
if (!intervals.is_batched) {
intervals.ray_indices[base_out] = ray_id;
intervals.is_left[base_out] = true;
intervals.is_right[base_out] = false;
}
} else {
scalar_t scaling = (end_t - start_t) / (end_u - start_u);
scalar_t t = (u - start_u) * scaling + start_t;
ts_out[tid] = t;
float t = samples.vals[tid];
float t_prev = samples.vals[tid - 1];
float t_edge = (t + t_prev) * 0.5f;
int64_t idx = base_out + sid;
intervals.vals[idx] = t_edge;
if (!intervals.is_batched) {
intervals.ray_indices[idx] = ray_id;
intervals.is_left[idx] = true;
intervals.is_right[idx] = true;
}
if (sid == n_samples - 1) {
float half_width = (t - t_prev) * 0.5f;
intervals.vals[idx + 1] = fminf(t + half_width, t_max);
if (!intervals.is_batched) {
intervals.ray_indices[idx + 1] = ray_id;
intervals.is_left[idx + 1] = false;
intervals.is_right[idx + 1] = true;
}
}
}
}
}
torch::Tensor pdf_sampling(
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_samples_in - 1]
int64_t n_samples, // n_samples_out
float padding,
bool stratified,
bool single_jitter,
c10::optional<torch::Tensor> masks_opt) // [n_rays]
/* kernels for searchsorted */
__global__ void searchsorted_kernel(
PackedRaySegmentsSpec query,
PackedRaySegmentsSpec key,
// outputs
int64_t *ids_left,
int64_t *ids_right)
{
DEVICE_GUARD(ts);
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < query.n_edges; tid += blockDim.x * gridDim.x)
{
int32_t ray_id;
if (query.is_batched) {
ray_id = tid / query.n_edges_per_ray;
} else {
if (query.ray_indices == nullptr) {
ray_id = binary_search_chunk_id(tid, query.n_rays, query.chunk_starts) - 1;
} else {
ray_id = query.ray_indices[tid];
}
}
int64_t base, last;
if (key.is_batched) {
base = ray_id * key.n_edges_per_ray;
last = base + key.n_edges_per_ray - 1;
} else {
base = key.chunk_starts[ray_id];
last = base + key.chunk_cnts[ray_id] - 1;
}
CHECK_INPUT(ts);
CHECK_INPUT(weights);
// searchsorted with "right" option:
// i.e. key.vals[p - 1] <= query.vals[tid] < key.vals[p]
int64_t p = upper_bound<float>(key.vals, base, last, query.vals[tid], nullptr);
if (query.is_batched) {
ids_left[tid] = max(min(p - 1, last), base) - base;
ids_right[tid] = max(min(p, last), base) - base;
} else {
ids_left[tid] = max(min(p - 1, last), base);
ids_right[tid] = max(min(p, last), base);
}
}
}
TORCH_CHECK(ts.ndimension() == 2);
TORCH_CHECK(weights.ndimension() == 2);
TORCH_CHECK(ts.size(1) == weights.size(1) + 1);
c10::MaybeOwned<torch::Tensor> masks_maybe_owned = at::borrow_from_optional_tensor(masks_opt);
const torch::Tensor& masks = *masks_maybe_owned;
} // namespace device
} // namespace
// Return a flattened RaySegmentsSpec because n_intervels_per_ray is defined per ray.
std::vector<RaySegmentsSpec> importance_sampling(
RaySegmentsSpec ray_segments, // [..., n_edges_per_ray] or flattened
torch::Tensor cdfs, // [..., n_edges_per_ray] or flattened
torch::Tensor n_intervels_per_ray, // [...] or flattened
bool stratified)
{
DEVICE_GUARD(cdfs);
ray_segments.check();
CHECK_INPUT(cdfs);
CHECK_INPUT(n_intervels_per_ray);
TORCH_CHECK(cdfs.numel() == ray_segments.vals.numel());
if (padding > 0.f)
{
weights = weights + padding;
}
weights = F::normalize(weights, F::NormalizeFuncOptions().p(1).dim(-1));
torch::Tensor accum_weights = torch::cat({torch::zeros({weights.size(0), 1}, weights.options()),
weights.cumsum(1, weights.scalar_type())},
1);
torch::Tensor ts_out = torch::full({ts.size(0), n_samples}, -1.0f, ts.options());
int64_t numel = ts_out.numel();
int64_t maxThread = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
int64_t maxGrid = 1024;
dim3 block = dim3(min(maxThread, numel));
dim3 grid = dim3(min(maxGrid, ceil_div<int64_t>(numel, block.x)));
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int64_t max_threads = 512; // at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
int64_t max_blocks = 65535;
dim3 threads, blocks;
// For jittering
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -212,142 +318,139 @@ torch::Tensor pdf_sampling(
rng_engine_inputs = gen->philox_cuda_state(4);
}
AT_DISPATCH_ALL_TYPES(
ts.scalar_type(),
"pdf_sampling",
([&]
{ pdf_sampling_kernel<scalar_t><<<grid, block, 0, stream>>>(
rng_engine_inputs,
ts.size(1), /* n_samples_in */
nullptr, /* info_ts */
ts.data_ptr<scalar_t>(), /* ts */
accum_weights.data_ptr<scalar_t>(), /* accum_weights */
masks.defined() ? masks.data_ptr<bool>() : nullptr, /* masks */
// output samples
RaySegmentsSpec samples;
samples.chunk_cnts = n_intervels_per_ray.to(n_intervels_per_ray.options().dtype(torch::kLong));
samples.memalloc_data(false, false); // no boolean masks needed, no zero-init needed.
int64_t n_samples = samples.vals.numel();
// step 1. compute the ray_indices and samples
threads = dim3(min(max_threads, n_samples));
blocks = dim3(min(max_blocks, ceil_div<int64_t>(n_samples, threads.x)));
device::importance_sampling_kernel<<<blocks, threads, 0, stream>>>(
// cdfs
device::PackedRaySegmentsSpec(ray_segments),
cdfs.data_ptr<float>(),
// jittering
stratified,
single_jitter,
numel, /* numel */
ts_out.size(1), /* n_samples_out */
ts_out.data_ptr<scalar_t>() /* ts_out */
); }));
return ts_out; // [n_rays, n_samples_out]
rng_engine_inputs,
// output samples
device::PackedRaySegmentsSpec(samples));
// output ray segments
RaySegmentsSpec intervals;
intervals.chunk_cnts = (
(samples.chunk_cnts + 1) * (samples.chunk_cnts > 0)).to(samples.chunk_cnts.options());
intervals.memalloc_data(true, true); // boolean masks needed, zero-init needed.
// step 2. compute the intervals.
device::compute_intervels_kernel<<<blocks, threads, 0, stream>>>(
// samples
device::PackedRaySegmentsSpec(ray_segments),
device::PackedRaySegmentsSpec(samples),
// output intervals
device::PackedRaySegmentsSpec(intervals));
return {intervals, samples};
}
template <typename scalar_t>
__global__ void pdf_readout_kernel(
const int64_t n_rays,
// keys
const int64_t n_samples_in,
const scalar_t *ts, // [n_rays, n_samples_in]
const scalar_t *accum_weights, // [n_rays, n_samples_in]
const bool *masks, // [n_rays]
// query
const int64_t n_samples_out,
const scalar_t *ts_out, // [n_rays, n_samples_out]
const bool *masks_out, // [n_rays]
scalar_t *weights_out) // [n_rays, n_samples_out - 1]
// Return a batched RaySegmentsSpec because n_intervels_per_ray is the same across rays.
std::vector<RaySegmentsSpec> importance_sampling(
RaySegmentsSpec ray_segments, // [..., n_edges_per_ray] or flattened
torch::Tensor cdfs, // [..., n_edges_per_ray] or flattened
int64_t n_intervels_per_ray,
bool stratified)
{
int64_t n_bins_out = n_samples_out - 1;
int64_t numel = n_bins_out * n_rays;
DEVICE_GUARD(cdfs);
ray_segments.check();
CHECK_INPUT(cdfs);
TORCH_CHECK(cdfs.numel() == ray_segments.vals.numel());
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < numel; tid += blockDim.x * gridDim.x)
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int64_t max_threads = 512; // at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
int64_t max_blocks = 65535;
dim3 threads, blocks;
// For jittering
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState rng_engine_inputs;
{
int64_t ray_id = tid / n_bins_out;
if (masks_out != nullptr && !masks_out[ray_id]) {
// We don't care about this query ray.
weights_out[tid] = 0.0f;
continue;
}
if (masks != nullptr && !masks[ray_id]) {
// We don't have values for the key ray; in that case the key ray is treated as all-zero.
weights_out[tid] = 0.0f;
continue;
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_cuda_state(4);
}
// search range in ts
int64_t start_bd = ray_id * n_samples_in;
int64_t end_bd = start_bd + n_samples_in;
// index in ts_out
int64_t id0 = tid + ray_id;
int64_t id1 = id0 + 1;
// searchsorted with "right" option:
// i.e. accum_weights[pos - 1] <= u < accum_weights[pos]
int64_t pos0 = upper_bound<scalar_t>(ts, start_bd, end_bd, ts_out[id0], nullptr);
pos0 = max(min(pos0 - 1, end_bd-1), start_bd);
// searchsorted with "left" option:
// i.e. accum_weights[pos - 1] < u <= accum_weights[pos]
int64_t pos1 = lower_bound<scalar_t>(ts, start_bd, end_bd, ts_out[id1], nullptr);
pos1 = max(min(pos1, end_bd-1), start_bd);
// outer
scalar_t outer = accum_weights[pos1] - accum_weights[pos0];
weights_out[tid] = outer;
RaySegmentsSpec samples, intervals;
if (ray_segments.vals.ndimension() > 1){ // batched input
auto data_size = ray_segments.vals.sizes().vec();
data_size.back() = n_intervels_per_ray;
samples.vals = torch::empty(data_size, cdfs.options());
data_size.back() = n_intervels_per_ray + 1;
intervals.vals = torch::empty(data_size, cdfs.options());
} else { // flattened input
int64_t n_rays = ray_segments.chunk_cnts.numel();
samples.vals = torch::empty({n_rays, n_intervels_per_ray}, cdfs.options());
intervals.vals = torch::empty({n_rays, n_intervels_per_ray + 1}, cdfs.options());
}
int64_t n_samples = samples.vals.numel();
// step 1. compute the ray_indices and samples
threads = dim3(min(max_threads, n_samples));
blocks = dim3(min(max_blocks, ceil_div<int64_t>(n_samples, threads.x)));
device::importance_sampling_kernel<<<blocks, threads, 0, stream>>>(
// cdfs
device::PackedRaySegmentsSpec(ray_segments),
cdfs.data_ptr<float>(),
// jittering
stratified,
rng_engine_inputs,
// output samples
device::PackedRaySegmentsSpec(samples));
// step 2. compute the intervals.
device::compute_intervels_kernel<<<blocks, threads, 0, stream>>>(
// samples
device::PackedRaySegmentsSpec(ray_segments),
device::PackedRaySegmentsSpec(samples),
// output intervals
device::PackedRaySegmentsSpec(intervals));
return {intervals, samples};
}
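// Illustrative shapes for the batched overload: with ray_segments.vals of shape [128, 65] and
// n_intervels_per_ray = 64, `samples.vals` is [128, 64] and `intervals.vals` is [128, 65];
// interior interval edges are midpoints between adjacent samples, and the two boundary edges
// are clamped to the source segment's t range (see compute_intervels_kernel above).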
torch::Tensor pdf_readout(
// keys
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_bins_in]
c10::optional<torch::Tensor> masks_opt, // [n_rays]
// query
torch::Tensor ts_out,
c10::optional<torch::Tensor> masks_out_opt) // [n_rays, n_samples_out]
// Find two indices {left, right} for each item in query,
// such that: key.vals[left] <= query.vals < key.vals[right]
std::vector<torch::Tensor> searchsorted(
RaySegmentsSpec query,
RaySegmentsSpec key)
{
DEVICE_GUARD(ts);
CHECK_INPUT(ts);
CHECK_INPUT(weights);
DEVICE_GUARD(query.vals);
query.check();
key.check();
TORCH_CHECK(ts.ndimension() == 2);
TORCH_CHECK(weights.ndimension() == 2);
int64_t n_rays = ts.size(0);
int64_t n_samples_in = ts.size(1);
int64_t n_samples_out = ts_out.size(1);
int64_t n_bins_out = n_samples_out - 1;
c10::MaybeOwned<torch::Tensor> masks_maybe_owned = at::borrow_from_optional_tensor(masks_opt);
const torch::Tensor& masks = *masks_maybe_owned;
c10::MaybeOwned<torch::Tensor> masks_out_maybe_owned = at::borrow_from_optional_tensor(masks_out_opt);
const torch::Tensor& masks_out = *masks_out_maybe_owned;
// weights = F::normalize(weights, F::NormalizeFuncOptions().p(1).dim(-1));
torch::Tensor accum_weights = torch::cat({torch::zeros({weights.size(0), 1}, weights.options()),
weights.cumsum(1, weights.scalar_type())},
1);
// outputs
int64_t n_edges = query.vals.numel();
torch::Tensor weights_out = torch::empty({n_rays, n_bins_out}, weights.options());
torch::Tensor ids_left = torch::empty(
query.vals.sizes(), query.vals.options().dtype(torch::kLong));
torch::Tensor ids_right = torch::empty(
query.vals.sizes(), query.vals.options().dtype(torch::kLong));
int64_t maxThread = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
int64_t maxGrid = 1024;
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int64_t numel = weights_out.numel();
dim3 block = dim3(min(maxThread, numel));
dim3 grid = dim3(min(maxGrid, ceil_div<int64_t>(numel, block.x)));
AT_DISPATCH_ALL_TYPES(
weights.scalar_type(),
"pdf_readout",
([&]
{ pdf_readout_kernel<scalar_t><<<grid, block, 0, stream>>>(
n_rays,
n_samples_in,
ts.data_ptr<scalar_t>(), /* ts */
accum_weights.data_ptr<scalar_t>(), /* accum_weights */
masks.defined() ? masks.data_ptr<bool>() : nullptr,
n_samples_out,
ts_out.data_ptr<scalar_t>(), /* ts_out */
masks_out.defined() ? masks_out.data_ptr<bool>() : nullptr,
weights_out.data_ptr<scalar_t>()
); }));
return weights_out; // [n_rays, n_bins_out]
int64_t max_threads = 512; // at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
int64_t max_blocks = 65535;
dim3 threads = dim3(min(max_threads, n_edges));
dim3 blocks = dim3(min(max_blocks, ceil_div<int64_t>(n_edges, threads.x)));
device::searchsorted_kernel<<<blocks, threads, 0, stream>>>(
device::PackedRaySegmentsSpec(query),
device::PackedRaySegmentsSpec(key),
// outputs
ids_left.data_ptr<int64_t>(),
ids_right.data_ptr<int64_t>());
return {ids_left, ids_right};
}
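// Illustrative: for a key ray with vals = {0.1, 0.4, 0.9} and a query value of 0.5, the kernel
// returns left = 1, right = 2 (0.4 <= 0.5 < 0.9). For batched inputs the indices are local to
// each ray; for packed inputs they index into the flattened key.vals.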
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
#include "include/helpers_math.h"
#include "include/helpers_contraction.h"
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor aabb);
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_binary,
const ContractionType type,
// sampling
const float step_size,
const float cone_angle);
torch::Tensor unpack_info(
const torch::Tensor packed_info, const int n_samples);
torch::Tensor unpack_info_to_mask(
const torch::Tensor packed_info, const int n_samples);
torch::Tensor grid_query(
const torch::Tensor samples,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_value,
const ContractionType type);
torch::Tensor contract(
const torch::Tensor samples,
// contraction
const torch::Tensor roi,
const ContractionType type);
torch::Tensor contract_inv(
const torch::Tensor samples,
// contraction
const torch::Tensor roi,
const ContractionType type);
torch::Tensor unpack_data(
torch::Tensor packed_info,
torch::Tensor data,
int n_samples_per_ray,
float pad_value);
// cub implementations: parallel across samples
bool is_cub_available() {
return (bool) CUB_SUPPORTS_SCAN_BY_KEY();
}
torch::Tensor transmittance_from_sigma_forward_cub(
torch::Tensor ray_indices,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor transmittance_from_sigma_backward_cub(
torch::Tensor ray_indices,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
torch::Tensor transmittance_from_alpha_forward_cub(
torch::Tensor ray_indices, torch::Tensor alphas);
torch::Tensor transmittance_from_alpha_backward_cub(
torch::Tensor ray_indices,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
// naive implementations: parallel across rays
torch::Tensor transmittance_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor transmittance_from_sigma_backward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
torch::Tensor transmittance_from_alpha_forward_naive(
torch::Tensor packed_info,
torch::Tensor alphas);
torch::Tensor transmittance_from_alpha_backward_naive(
torch::Tensor packed_info,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
torch::Tensor weight_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor weight_from_sigma_backward_naive(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor weight_from_alpha_forward_naive(
torch::Tensor packed_info,
torch::Tensor alphas);
torch::Tensor weight_from_alpha_backward_naive(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas);
// pdf
torch::Tensor pdf_sampling(
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_samples_in - 1]
int64_t n_samples, // n_samples_out
float padding,
bool stratified,
bool single_jitter,
c10::optional<torch::Tensor> masks_opt); // [n_rays]
torch::Tensor pdf_readout(
// keys
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_bins_in]
c10::optional<torch::Tensor> masks_opt, // [n_rays]
// query
torch::Tensor ts_out,
c10::optional<torch::Tensor> masks_out_opt); // [n_rays, n_samples_out]
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// contraction
py::enum_<ContractionType>(m, "ContractionType")
.value("AABB", ContractionType::AABB)
.value("UN_BOUNDED_TANH", ContractionType::UN_BOUNDED_TANH)
.value("UN_BOUNDED_SPHERE", ContractionType::UN_BOUNDED_SPHERE);
m.def("contract", &contract);
m.def("contract_inv", &contract_inv);
// grid
m.def("grid_query", &grid_query);
// marching
m.def("ray_aabb_intersect", &ray_aabb_intersect);
m.def("ray_marching", &ray_marching);
// rendering
m.def("is_cub_available", is_cub_available);
m.def("transmittance_from_sigma_forward_cub", transmittance_from_sigma_forward_cub);
m.def("transmittance_from_sigma_backward_cub", transmittance_from_sigma_backward_cub);
m.def("transmittance_from_alpha_forward_cub", transmittance_from_alpha_forward_cub);
m.def("transmittance_from_alpha_backward_cub", transmittance_from_alpha_backward_cub);
m.def("transmittance_from_sigma_forward_naive", transmittance_from_sigma_forward_naive);
m.def("transmittance_from_sigma_backward_naive", transmittance_from_sigma_backward_naive);
m.def("transmittance_from_alpha_forward_naive", transmittance_from_alpha_forward_naive);
m.def("transmittance_from_alpha_backward_naive", transmittance_from_alpha_backward_naive);
m.def("weight_from_sigma_forward_naive", weight_from_sigma_forward_naive);
m.def("weight_from_sigma_backward_naive", weight_from_sigma_backward_naive);
m.def("weight_from_alpha_forward_naive", weight_from_alpha_forward_naive);
m.def("weight_from_alpha_backward_naive", weight_from_alpha_backward_naive);
// pack & unpack
m.def("unpack_data", &unpack_data);
m.def("unpack_info", &unpack_info);
m.def("unpack_info_to_mask", &unpack_info_to_mask);
// pdf
m.def("pdf_sampling", &pdf_sampling);
m.def("pdf_readout", &pdf_readout);
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
#include "include/helpers_math.h"
#include "include/helpers_contraction.h"
inline __device__ __host__ float calc_dt(
const float t, const float cone_angle,
const float dt_min, const float dt_max)
{
return clamp(t * cone_angle, dt_min, dt_max);
}
inline __device__ __host__ int mip_level(
const float3 xyz,
const float3 roi_min, const float3 roi_max,
const ContractionType type)
{
if (type != ContractionType::AABB)
{
// mip level should always be zero when not using AABB
return 0;
}
float3 xyz_unit = apply_contraction(
xyz, roi_min, roi_max, ContractionType::AABB);
float3 scale = fabs(xyz_unit - 0.5);
float maxval = fmaxf(fmaxf(scale.x, scale.y), scale.z);
// if maxval is almost zero, it will trigger frexpf to output 0
// for exponent, which is not what we want.
maxval = fmaxf(maxval, 0.1);
int exponent;
frexpf(maxval, &exponent);
int mip = max(0, exponent + 1);
return mip;
}
inline __device__ __host__ int grid_idx_at(
const float3 xyz_unit, const int3 grid_res)
{
// xyz should always be in [0, 1]^3.
int3 ixyz = make_int3(xyz_unit * make_float3(grid_res));
ixyz = clamp(ixyz, make_int3(0, 0, 0), grid_res - 1);
int3 grid_offset = make_int3(grid_res.y * grid_res.z, grid_res.z, 1);
int idx = dot(ixyz, grid_offset);
return idx;
}
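// Illustrative: with grid_res = (2, 4, 8) the linear offset of cell (1, 2, 3) is
// 1 * (4 * 8) + 2 * 8 + 3 = 51, i.e. a row-major layout over (x, y, z) with z varying fastest.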
template <typename scalar_t>
inline __device__ __host__ scalar_t grid_occupied_at(
const float3 xyz,
const float3 roi_min, const float3 roi_max,
ContractionType type, int mip,
const int grid_nlvl, const int3 grid_res, const scalar_t *grid_value)
{
if (type == ContractionType::AABB && mip >= grid_nlvl)
{
return false;
}
float3 xyz_unit = apply_contraction(
xyz, roi_min, roi_max, type);
xyz_unit = (xyz_unit - 0.5) * scalbnf(1.0f, -mip) + 0.5;
int idx = grid_idx_at(xyz_unit, grid_res) + mip * grid_res.x * grid_res.y * grid_res.z;
return grid_value[idx];
}
// dda like step
inline __device__ __host__ float distance_to_next_voxel(
const float3 xyz, const float3 dir, const float3 inv_dir, int mip,
const float3 roi_min, const float3 roi_max, const int3 grid_res)
{
float scaling = scalbnf(1.0f, mip);
float3 _roi_mid = (roi_min + roi_max) * 0.5;
float3 _roi_rad = (roi_max - roi_min) * 0.5;
float3 _roi_min = _roi_mid - _roi_rad * scaling;
float3 _roi_max = _roi_mid + _roi_rad * scaling;
float3 _occ_res = make_float3(grid_res);
float3 _xyz = roi_to_unit(xyz, _roi_min, _roi_max) * _occ_res;
float3 txyz = ((floorf(_xyz + 0.5f + 0.5f * sign(dir)) - _xyz) * inv_dir) / _occ_res * (_roi_max - _roi_min);
float t = min(min(txyz.x, txyz.y), txyz.z);
return fmaxf(t, 0.0f);
}
inline __device__ __host__ float advance_to_next_voxel(
const float t, const float dt_min,
const float3 xyz, const float3 dir, const float3 inv_dir, int mip,
const float3 roi_min, const float3 roi_max, const int3 grid_res, const float far)
{
// Regular stepping (may be slower but matches non-empty space)
float t_target = t + distance_to_next_voxel(
xyz, dir, inv_dir, mip, roi_min, roi_max, grid_res);
t_target = min(t_target, far);
float _t = t;
do
{
_t += dt_min;
} while (_t < t_target);
return _t;
}
// -------------------------------------------------------------------------------
// Raymarching
// -------------------------------------------------------------------------------
__global__ void ray_marching_kernel(
// rays info
const uint32_t n_rays,
const float *rays_o, // shape (n_rays, 3)
const float *rays_d, // shape (n_rays, 3)
const float *t_min, // shape (n_rays,)
const float *t_max, // shape (n_rays,)
// occupancy grid & contraction
const float *roi,
const int grid_nlvl,
const int3 grid_res,
const bool *grid_binary, // shape (reso_x, reso_y, reso_z)
const ContractionType type,
// sampling
const float step_size,
const float cone_angle,
const int *packed_info,
// first round outputs
int *num_steps,
// second round outputs
int64_t *ray_indices,
float *t_starts,
float *t_ends)
{
CUDA_GET_THREAD_ID(i, n_rays);
bool is_first_round = (packed_info == nullptr);
// locate
rays_o += i * 3;
rays_d += i * 3;
t_min += i;
t_max += i;
if (is_first_round)
{
num_steps += i;
}
else
{
int base = packed_info[i * 2 + 0];
int steps = packed_info[i * 2 + 1];
t_starts += base;
t_ends += base;
ray_indices += base;
}
const float3 origin = make_float3(rays_o[0], rays_o[1], rays_o[2]);
const float3 dir = make_float3(rays_d[0], rays_d[1], rays_d[2]);
const float3 inv_dir = 1.0f / dir;
const float near = t_min[0], far = t_max[0];
const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
const float grid_cell_scale = fmaxf(fmaxf(
(roi_max.x - roi_min.x) / grid_res.x * scalbnf(1.732f, grid_nlvl - 1),
(roi_max.y - roi_min.y) / grid_res.y * scalbnf(1.732f, grid_nlvl - 1)),
(roi_max.z - roi_min.z) / grid_res.z * scalbnf(1.732f, grid_nlvl - 1));
// TODO: compute dt_max from occ resolution.
float dt_min = step_size;
float dt_max;
if (type == ContractionType::AABB) {
// compute dt_max from occ grid resolution.
dt_max = grid_cell_scale;
} else {
dt_max = 1e10f;
}
int j = 0;
float t0 = near;
float dt = calc_dt(t0, cone_angle, dt_min, dt_max);
float t1 = t0 + dt;
float t_mid = (t0 + t1) * 0.5f;
while (t_mid < far)
{
// current center
const float3 xyz = origin + t_mid * dir;
// current mip level
const int mip = mip_level(xyz, roi_min, roi_max, type);
if (mip >= grid_nlvl) {
// out of grid
break;
}
if (grid_occupied_at(xyz, roi_min, roi_max, type, mip, grid_nlvl, grid_res, grid_binary))
{
if (!is_first_round)
{
t_starts[j] = t0;
t_ends[j] = t1;
ray_indices[j] = i;
}
++j;
// march to next sample
t0 = t1;
t1 = t0 + calc_dt(t0, cone_angle, dt_min, dt_max);
t_mid = (t0 + t1) * 0.5f;
}
else
{
// march to next sample
switch (type)
{
case ContractionType::AABB:
// no contraction
t_mid = advance_to_next_voxel(
t_mid, dt_min, xyz, dir, inv_dir, mip, roi_min, roi_max, grid_res, far);
dt = calc_dt(t_mid, cone_angle, dt_min, dt_max);
t0 = t_mid - dt * 0.5f;
t1 = t_mid + dt * 0.5f;
break;
default:
// any type of scene contraction does not work with DDA.
t0 = t1;
t1 = t0 + calc_dt(t0, cone_angle, dt_min, dt_max);
t_mid = (t0 + t1) * 0.5f;
break;
}
}
}
if (is_first_round)
{
*num_steps = j;
}
return;
}
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_binary,
const ContractionType type,
// sampling
const float step_size,
const float cone_angle)
{
DEVICE_GUARD(rays_o);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(t_min);
CHECK_INPUT(t_max);
CHECK_INPUT(roi);
CHECK_INPUT(grid_binary);
TORCH_CHECK(rays_o.ndimension() == 2 && rays_o.size(1) == 3);
TORCH_CHECK(rays_d.ndimension() == 2 && rays_d.size(1) == 3);
TORCH_CHECK(t_min.ndimension() == 1);
TORCH_CHECK(t_max.ndimension() == 1);
TORCH_CHECK(roi.ndimension() == 1 && roi.size(0) == 6);
TORCH_CHECK(grid_binary.ndimension() == 4);
const int n_rays = rays_o.size(0);
const int grid_nlvl = grid_binary.size(0);
const int3 grid_res = make_int3(
grid_binary.size(1), grid_binary.size(2), grid_binary.size(3));
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// helper counter
torch::Tensor num_steps = torch::empty(
{n_rays}, rays_o.options().dtype(torch::kInt32));
// count number of samples per ray
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// occupancy grid & contraction
roi.data_ptr<float>(),
grid_nlvl,
grid_res,
grid_binary.data_ptr<bool>(),
type,
// sampling
step_size,
cone_angle,
nullptr, /* packed_info */
// outputs
num_steps.data_ptr<int>(),
nullptr, /* ray_indices */
nullptr, /* t_starts */
nullptr /* t_ends */);
torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
torch::Tensor packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
// output samples starts and ends
int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
torch::Tensor t_starts = torch::empty({total_steps, 1}, rays_o.options());
torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options());
torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options().dtype(torch::kLong));
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// occupancy grid & contraction
roi.data_ptr<float>(),
grid_nlvl,
grid_res,
grid_binary.data_ptr<bool>(),
type,
// sampling
step_size,
cone_angle,
packed_info.data_ptr<int>(),
// outputs
nullptr, /* num_steps */
ray_indices.data_ptr<int64_t>(),
t_starts.data_ptr<float>(),
t_ends.data_ptr<float>());
return {packed_info, ray_indices, t_starts, t_ends};
}
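// Usage sketch (illustrative; the step_size and cone_angle values are arbitrary):
//
//   auto outs = ray_marching(rays_o, rays_d, t_min, t_max, roi, grid_binary,
//                            ContractionType::AABB, /*step_size=*/1e-3f, /*cone_angle=*/0.f);
//   // outs = {packed_info [n_rays, 2], ray_indices [total], t_starts [total, 1], t_ends [total, 1]}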
// ----------------------------------------------------------------------------
// Query the occupancy grid
// ----------------------------------------------------------------------------
template <typename scalar_t>
__global__ void query_occ_kernel(
// rays info
const uint32_t n_samples,
const float *samples, // shape (n_samples, 3)
// occupancy grid & contraction
const float *roi,
const int grid_nlvl,
const int3 grid_res,
const scalar_t *grid_value, // shape (reso_x, reso_y, reso_z)
const ContractionType type,
// outputs
scalar_t *occs)
{
CUDA_GET_THREAD_ID(i, n_samples);
// locate
samples += i * 3;
occs += i;
const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
const float3 xyz = make_float3(samples[0], samples[1], samples[2]);
const int mip = mip_level(xyz, roi_min, roi_max, type);
*occs = grid_occupied_at(xyz, roi_min, roi_max, type, mip, grid_nlvl, grid_res, grid_value);
return;
}
torch::Tensor grid_query(
const torch::Tensor samples,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_value,
const ContractionType type)
{
DEVICE_GUARD(samples);
CHECK_INPUT(samples);
CHECK_INPUT(roi);
CHECK_INPUT(grid_value);
const int n_samples = samples.size(0);
const int grid_nlvl = grid_value.size(0);
const int3 grid_res = make_int3(
grid_value.size(1), grid_value.size(2), grid_value.size(3));
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads);
torch::Tensor occs = torch::empty({n_samples}, grid_value.options());
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Bool,
occs.scalar_type(),
"grid_query",
([&]
{ query_occ_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_samples,
samples.data_ptr<float>(),
// grid
roi.data_ptr<float>(),
grid_nlvl,
grid_res,
grid_value.data_ptr<scalar_t>(),
type,
// outputs
occs.data_ptr<scalar_t>()); }));
return occs;
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
__global__ void transmittance_from_sigma_forward_kernel(
const uint32_t n_rays,
// inputs
const int *packed_info,
const float *starts,
const float *ends,
const float *sigmas,
// outputs
float *transmittance)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
starts += base;
ends += base;
sigmas += base;
transmittance += base;
// accumulation
float cumsum = 0.0f;
for (int j = 0; j < steps; ++j)
{
transmittance[j] = __expf(-cumsum);
cumsum += sigmas[j] * (ends[j] - starts[j]);
}
// // another way to impl:
// float T = 1.f;
// for (int j = 0; j < steps; ++j)
// {
// const float delta = ends[j] - starts[j];
// const float alpha = 1.f - __expf(-sigmas[j] * delta);
// transmittance[j] = T;
// T *= (1.f - alpha);
// }
return;
}
__global__ void transmittance_from_sigma_backward_kernel(
const uint32_t n_rays,
// inputs
const int *packed_info,
const float *starts,
const float *ends,
const float *transmittance,
const float *transmittance_grad,
// outputs
float *sigmas_grad)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
transmittance += base;
transmittance_grad += base;
starts += base;
ends += base;
sigmas_grad += base;
// accumulation
float cumsum = 0.0f;
for (int j = steps - 1; j >= 0; --j)
{
sigmas_grad[j] = cumsum * (ends[j] - starts[j]);
cumsum += -transmittance_grad[j] * transmittance[j];
}
return;
}
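// Derivation for the backward pass above: with delta_j = ends[j] - starts[j],
// the forward pass computes T_j = exp(-sum_{k<j} sigma_k * delta_k), so
//   dT_j / dsigma_i = -delta_i * T_j   for i < j (and 0 otherwise),
// hence dL/dsigma_i = -delta_i * sum_{j>i} (dL/dT_j) * T_j. The reverse loop
// accumulates that suffix sum in `cumsum`.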
__global__ void transmittance_from_alpha_forward_kernel(
const uint32_t n_rays,
// inputs
const int *packed_info,
const float *alphas,
// outputs
float *transmittance)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
alphas += base;
transmittance += base;
// accumulation
float T = 1.0f;
for (int j = 0; j < steps; ++j)
{
transmittance[j] = T;
T *= (1.0f - alphas[j]);
}
return;
}
__global__ void transmittance_from_alpha_backward_kernel(
const uint32_t n_rays,
// inputs
const int *packed_info,
const float *alphas,
const float *transmittance,
const float *transmittance_grad,
// outputs
float *alphas_grad)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
alphas += base;
transmittance += base;
transmittance_grad += base;
alphas_grad += base;
// accumulation
float cumsum = 0.0f;
for (int j = steps - 1; j >= 0; --j)
{
alphas_grad[j] = cumsum / fmax(1.0f - alphas[j], 1e-10f);
cumsum += -transmittance_grad[j] * transmittance[j];
}
return;
}
torch::Tensor transmittance_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 && sigmas.size(1) == 1);
const uint32_t n_samples = sigmas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor transmittance = torch::empty_like(sigmas);
// parallel across rays
transmittance_from_sigma_forward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<float>(),
ends.data_ptr<float>(),
sigmas.data_ptr<float>(),
// outputs
transmittance.data_ptr<float>());
return transmittance;
}
torch::Tensor transmittance_from_sigma_backward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(transmittance);
CHECK_INPUT(transmittance_grad);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(transmittance.ndimension() == 2 && transmittance.size(1) == 1);
TORCH_CHECK(transmittance_grad.ndimension() == 2 && transmittance_grad.size(1) == 1);
const uint32_t n_samples = transmittance.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor sigmas_grad = torch::empty_like(transmittance);
// parallel across rays
transmittance_from_sigma_backward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<float>(),
ends.data_ptr<float>(),
transmittance.data_ptr<float>(),
transmittance_grad.data_ptr<float>(),
// outputs
sigmas_grad.data_ptr<float>());
return sigmas_grad;
}
torch::Tensor transmittance_from_alpha_forward_naive(
torch::Tensor packed_info, torch::Tensor alphas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(alphas);
TORCH_CHECK(alphas.ndimension() == 2 && alphas.size(1) == 1);
TORCH_CHECK(packed_info.ndimension() == 2);
const uint32_t n_samples = alphas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor transmittance = torch::empty_like(alphas);
// parallel across rays
transmittance_from_alpha_forward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
alphas.data_ptr<float>(),
// outputs
transmittance.data_ptr<float>());
return transmittance;
}
torch::Tensor transmittance_from_alpha_backward_naive(
torch::Tensor packed_info,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(alphas);
CHECK_INPUT(transmittance);
CHECK_INPUT(transmittance_grad);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(alphas.ndimension() == 2 && alphas.size(1) == 1);
TORCH_CHECK(transmittance.ndimension() == 2 && transmittance.size(1) == 1);
TORCH_CHECK(transmittance_grad.ndimension() == 2 && transmittance_grad.size(1) == 1);
const uint32_t n_samples = transmittance.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor alphas_grad = torch::empty_like(alphas);
// parallel across rays
transmittance_from_alpha_backward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
alphas.data_ptr<float>(),
transmittance.data_ptr<float>(),
transmittance_grad.data_ptr<float>(),
// outputs
alphas_grad.data_ptr<float>());
return alphas_grad;
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
// CUB is supported in CUDA >= 11.0
// ExclusiveScanByKey is supported in CUB >= 1.15.0 (CUDA >= 11.6)
// See: https://github.com/NVIDIA/cub/tree/main#releases
#include "include/helpers_cuda.h"
#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <cub/cub.cuh>
#endif
struct Product
{
template <typename T>
__host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const { return a * b; }
};
#if CUB_SUPPORTS_SCAN_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_sum_by_key(
KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
TORCH_CHECK(num_items <= std::numeric_limits<int64_t>::max(),
"cub ExclusiveSumByKey does not support more than LONG_MAX elements");
CUB_WRAPPER(cub::DeviceScan::ExclusiveSumByKey, keys, input, output,
num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_prod_by_key(
KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
TORCH_CHECK(num_items <= std::numeric_limits<int64_t>::max(),
"cub ExclusiveScanByKey does not support more than LONG_MAX elements");
CUB_WRAPPER(cub::DeviceScan::ExclusiveScanByKey, keys, input, output, Product(), 1.0f,
num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
#endif
torch::Tensor transmittance_from_sigma_forward_cub(
torch::Tensor ray_indices,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas)
{
DEVICE_GUARD(ray_indices);
CHECK_INPUT(ray_indices);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(ray_indices.ndimension() == 1);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 && sigmas.size(1) == 1);
const uint32_t n_samples = sigmas.size(0);
// parallel across samples
torch::Tensor sigmas_dt = sigmas * (ends - starts);
torch::Tensor sigmas_dt_cumsum = torch::empty_like(sigmas);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
ray_indices.data_ptr<int64_t>(),
sigmas_dt.data_ptr<float>(),
sigmas_dt_cumsum.data_ptr<float>(),
n_samples);
#else
throw std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
torch::Tensor transmittance = (-sigmas_dt_cumsum).exp();
return transmittance;
}
torch::Tensor transmittance_from_sigma_backward_cub(
torch::Tensor ray_indices,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad)
{
DEVICE_GUARD(ray_indices);
CHECK_INPUT(ray_indices);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(transmittance);
CHECK_INPUT(transmittance_grad);
TORCH_CHECK(ray_indices.ndimension() == 1);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(transmittance.ndimension() == 2 && transmittance.size(1) == 1);
TORCH_CHECK(transmittance_grad.ndimension() == 2 && transmittance_grad.size(1) == 1);
const uint32_t n_samples = transmittance.size(0);
// parallel across samples
torch::Tensor sigmas_dt_cumsum_grad = -transmittance_grad * transmittance;
torch::Tensor sigmas_dt_grad = torch::empty_like(transmittance_grad);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
thrust::make_reverse_iterator(ray_indices.data_ptr<int64_t>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_cumsum_grad.data_ptr<float>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_grad.data_ptr<float>() + n_samples),
n_samples);
#else
throw std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
torch::Tensor sigmas_grad = sigmas_dt_grad * (ends - starts);
return sigmas_grad;
}
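// Why the reverse iterators above: the forward pass is a per-ray exclusive
// cumsum y_j = sum_{k<j} x_k, whose gradient is the suffix sum
// dL/dx_i = sum_{j>i} dL/dy_j. Running ExclusiveSumByKey over the reversed
// buffers (keys reversed as well, so ray segments stay aligned) computes
// exactly that suffix sum.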
torch::Tensor transmittance_from_alpha_forward_cub(
torch::Tensor ray_indices, torch::Tensor alphas)
{
DEVICE_GUARD(ray_indices);
CHECK_INPUT(ray_indices);
CHECK_INPUT(alphas);
TORCH_CHECK(alphas.ndimension() == 2 && alphas.size(1) == 1);
TORCH_CHECK(ray_indices.ndimension() == 1);
const uint32_t n_samples = alphas.size(0);
// parallel across samples
torch::Tensor transmittance = torch::empty_like(alphas);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_prod_by_key(
ray_indices.data_ptr<int64_t>(),
(1.0f - alphas).data_ptr<float>(),
transmittance.data_ptr<float>(),
n_samples);
#else
throw std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
return transmittance;
}
torch::Tensor transmittance_from_alpha_backward_cub(
torch::Tensor ray_indices,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad)
{
DEVICE_GUARD(ray_indices);
CHECK_INPUT(ray_indices);
CHECK_INPUT(alphas);
CHECK_INPUT(transmittance);
CHECK_INPUT(transmittance_grad);
TORCH_CHECK(ray_indices.ndimension() == 1);
TORCH_CHECK(alphas.ndimension() == 2 && alphas.size(1) == 1);
TORCH_CHECK(transmittance.ndimension() == 2 && transmittance.size(1) == 1);
TORCH_CHECK(transmittance_grad.ndimension() == 2 && transmittance_grad.size(1) == 1);
const uint32_t n_samples = transmittance.size(0);
// parallel across samples
torch::Tensor sigmas_dt_cumsum_grad = -transmittance_grad * transmittance;
torch::Tensor sigmas_dt_grad = torch::empty_like(transmittance_grad);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
thrust::make_reverse_iterator(ray_indices.data_ptr<int64_t>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_cumsum_grad.data_ptr<float>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_grad.data_ptr<float>() + n_samples),
n_samples);
#else
throw std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
torch::Tensor alphas_grad = sigmas_dt_grad / (1.0f - alphas).clamp_min(1e-10f);
return alphas_grad;
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
__global__ void weight_from_sigma_forward_kernel(
const uint32_t n_rays,
const int *packed_info,
const float *starts,
const float *ends,
const float *sigmas,
// outputs
float *weights)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
starts += base;
ends += base;
sigmas += base;
weights += base;
// accumulation
float T = 1.f;
for (int j = 0; j < steps; ++j)
{
const float delta = ends[j] - starts[j];
const float alpha = 1.f - __expf(-sigmas[j] * delta);
weights[j] = alpha * T;
T *= (1.f - alpha);
}
return;
}
__global__ void weight_from_sigma_backward_kernel(
const uint32_t n_rays,
const int *packed_info,
const float *starts,
const float *ends,
const float *sigmas,
const float *weights,
const float *grad_weights,
// outputs
float *grad_sigmas)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
starts += base;
ends += base;
sigmas += base;
weights += base;
grad_weights += base;
grad_sigmas += base;
float accum = 0;
for (int j = 0; j < steps; ++j)
{
accum += grad_weights[j] * weights[j];
}
// accumulation
float T = 1.f;
for (int j = 0; j < steps; ++j)
{
const float delta = ends[j] - starts[j];
const float alpha = 1.f - __expf(-sigmas[j] * delta);
grad_sigmas[j] = (grad_weights[j] * T - accum) * delta;
accum -= grad_weights[j] * weights[j];
T *= (1.f - alpha);
}
return;
}
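// Derivation for the backward pass above: with alpha_j = 1 - exp(-sigma_j * delta_j)
// and T_j = prod_{k<j} (1 - alpha_k), the weights are w_j = alpha_j * T_j. Then
//   dw_j / dsigma_j = delta_j * (1 - alpha_j) * T_j,
//   dw_k / dsigma_j = -delta_j * w_k            for k > j,
// which combine into
//   dL/dsigma_j = delta_j * ((dL/dw_j) * T_j - sum_{k>=j} (dL/dw_k) * w_k).
// `accum` holds the suffix sum over k >= j, matching the expression in the loop.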
__global__ void weight_from_alpha_forward_kernel(
const uint32_t n_rays,
const int *packed_info,
const float *alphas,
// outputs
float *weights)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
alphas += base;
weights += base;
// accumulation
float T = 1.f;
for (int j = 0; j < steps; ++j)
{
const float alpha = alphas[j];
weights[j] = alpha * T;
T *= (1.f - alpha);
}
return;
}
__global__ void weight_from_alpha_backward_kernel(
const uint32_t n_rays,
const int *packed_info,
const float *alphas,
const float *weights,
const float *grad_weights,
// outputs
float *grad_alphas)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
alphas += base;
weights += base;
grad_weights += base;
grad_alphas += base;
float accum = 0;
for (int j = 0; j < steps; ++j)
{
accum += grad_weights[j] * weights[j];
}
// accumulation
float T = 1.f;
for (int j = 0; j < steps; ++j)
{
const float alpha = alphas[j];
grad_alphas[j] = (grad_weights[j] * T - accum) / fmaxf(1.f - alpha, 1e-10f);
accum -= grad_weights[j] * weights[j];
T *= (1.f - alpha);
}
return;
}
torch::Tensor weight_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 && sigmas.size(1) == 1);
const uint32_t n_samples = sigmas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor weights = torch::empty_like(sigmas);
weight_from_sigma_forward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<float>(),
ends.data_ptr<float>(),
sigmas.data_ptr<float>(),
// outputs
weights.data_ptr<float>());
return weights;
}
torch::Tensor weight_from_sigma_backward_naive(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(weights);
CHECK_INPUT(grad_weights);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 && sigmas.size(1) == 1);
TORCH_CHECK(weights.ndimension() == 2 && weights.size(1) == 1);
TORCH_CHECK(grad_weights.ndimension() == 2 && grad_weights.size(1) == 1);
const uint32_t n_samples = sigmas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor grad_sigmas = torch::empty_like(sigmas);
weight_from_sigma_backward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<float>(),
ends.data_ptr<float>(),
sigmas.data_ptr<float>(),
weights.data_ptr<float>(),
grad_weights.data_ptr<float>(),
// outputs
grad_sigmas.data_ptr<float>());
return grad_sigmas;
}
torch::Tensor weight_from_alpha_forward_naive(
torch::Tensor packed_info, torch::Tensor alphas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(alphas);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(alphas.ndimension() == 2 && alphas.size(1) == 1);
const uint32_t n_samples = alphas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor weights = torch::empty_like(alphas);
weight_from_alpha_forward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
alphas.data_ptr<float>(),
// outputs
weights.data_ptr<float>());
return weights;
}
torch::Tensor weight_from_alpha_backward_naive(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(alphas);
CHECK_INPUT(weights);
CHECK_INPUT(grad_weights);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(alphas.ndimension() == 2 && alphas.size(1) == 1);
TORCH_CHECK(weights.ndimension() == 2 && weights.size(1) == 1);
TORCH_CHECK(grad_weights.ndimension() == 2 && grad_weights.size(1) == 1);
const uint32_t n_samples = alphas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor grad_alphas = torch::empty_like(alphas);
weight_from_alpha_backward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
alphas.data_ptr<float>(),
weights.data_ptr<float>(),
grad_weights.data_ptr<float>(),
// outputs
grad_alphas.data_ptr<float>());
return grad_alphas;
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include <thrust/iterator/reverse_iterator.h>
#include "include/utils_scan.cuh"
#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <cub/cub.cuh>
#endif
namespace {
namespace device {
#if CUB_SUPPORTS_SCAN_BY_KEY()
struct Product
{
template <typename T>
__host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const { return a * b; }
};
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_sum_by_key(
KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
TORCH_CHECK(num_items <= std::numeric_limits<int64_t>::max(),
"cub ExclusiveSumByKey does not support more than LONG_MAX elements");
CUB_WRAPPER(cub::DeviceScan::ExclusiveSumByKey, keys, input, output,
num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_prod_by_key(
KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
TORCH_CHECK(num_items <= std::numeric_limits<int64_t>::max(),
"cub ExclusiveScanByKey does not support more than LONG_MAX elements");
CUB_WRAPPER(cub::DeviceScan::ExclusiveScanByKey, keys, input, output, Product(), 1.0f,
num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
#endif
} // namespace device
} // namespace
torch::Tensor exclusive_sum_by_key(
torch::Tensor indices,
torch::Tensor inputs,
bool backward)
{
DEVICE_GUARD(inputs);
torch::Tensor outputs = torch::empty_like(inputs);
int64_t n_items = inputs.size(0);
#if CUB_SUPPORTS_SCAN_BY_KEY()
if (backward)
device::exclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_items),
thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_items),
thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_items),
n_items);
else
device::exclusive_sum_by_key(
indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(),
outputs.data_ptr<float>(),
n_items);
#else
throw std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
cudaGetLastError();
return outputs;
}
torch::Tensor inclusive_sum(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
bool normalize,
bool backward)
{
DEVICE_GUARD(inputs);
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
CHECK_INPUT(inputs);
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(inputs.ndimension() == 1);
TORCH_CHECK(chunk_starts.size(0) == chunk_cnts.size(0));
if (backward)
TORCH_CHECK(!normalize); // backward does not support normalize yet.
uint32_t n_rays = chunk_cnts.size(0);
int64_t n_edges = inputs.size(0);
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_blocks = 65535;
dim3 threads = dim3(16, 32);
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));
torch::Tensor outputs = torch::empty_like(inputs);
if (backward) {
chunk_starts = n_edges - (chunk_starts + chunk_cnts);
device::inclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
n_rays,
thrust::make_reverse_iterator(chunk_starts.data_ptr<int64_t>() + n_rays),
thrust::make_reverse_iterator(chunk_cnts.data_ptr<int64_t>() + n_rays),
0.f,
std::plus<float>(),
normalize);
} else {
device::inclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
outputs.data_ptr<float>(),
inputs.data_ptr<float>(),
n_rays,
chunk_starts.data_ptr<int64_t>(),
chunk_cnts.data_ptr<int64_t>(),
0.f,
std::plus<float>(),
normalize);
}
cudaGetLastError();
return outputs;
}
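// Note on the `backward` branch above (and in exclusive_sum below): the gradient of a
// per-chunk cumsum is the same cumsum applied to the upstream gradient in reverse order,
// so the kernel is launched on reverse iterators, with chunk_starts remapped to
// n_edges - (chunk_starts + chunk_cnts), i.e. where each chunk begins in the reversed buffer.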
torch::Tensor exclusive_sum(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
bool normalize,
bool backward)
{
DEVICE_GUARD(inputs);
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
CHECK_INPUT(inputs);
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(inputs.ndimension() == 1);
TORCH_CHECK(chunk_starts.size(0) == chunk_cnts.size(0));
if (backward)
TORCH_CHECK(!normalize); // backward does not support normalize yet.
uint32_t n_rays = chunk_cnts.size(0);
int64_t n_edges = inputs.size(0);
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_blocks = 65535;
dim3 threads = dim3(16, 32);
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));
torch::Tensor outputs = torch::empty_like(inputs);
if (backward) {
chunk_starts = n_edges - (chunk_starts + chunk_cnts);
device::exclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
n_rays,
thrust::make_reverse_iterator(chunk_starts.data_ptr<int64_t>() + n_rays),
thrust::make_reverse_iterator(chunk_cnts.data_ptr<int64_t>() + n_rays),
0.f,
std::plus<float>(),
normalize);
} else {
device::exclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
outputs.data_ptr<float>(),
inputs.data_ptr<float>(),
n_rays,
chunk_starts.data_ptr<int64_t>(),
chunk_cnts.data_ptr<int64_t>(),
0.f,
std::plus<float>(),
normalize);
}
cudaGetLastError();
return outputs;
}
torch::Tensor inclusive_prod_forward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs)
{
DEVICE_GUARD(inputs);
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
CHECK_INPUT(inputs);
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(inputs.ndimension() == 1);
TORCH_CHECK(chunk_starts.size(0) == chunk_cnts.size(0));
uint32_t n_rays = chunk_cnts.size(0);
int64_t n_edges = inputs.size(0);
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_blocks = 65535;
dim3 threads = dim3(16, 32);
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));
torch::Tensor outputs = torch::empty_like(inputs);
device::inclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
outputs.data_ptr<float>(),
inputs.data_ptr<float>(),
n_rays,
chunk_starts.data_ptr<int64_t>(),
chunk_cnts.data_ptr<int64_t>(),
1.f,
std::multiplies<float>(),
false);
cudaGetLastError();
return outputs;
}
torch::Tensor inclusive_prod_backward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
torch::Tensor outputs,
torch::Tensor grad_outputs)
{
DEVICE_GUARD(grad_outputs);
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
CHECK_INPUT(grad_outputs);
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(inputs.ndimension() == 1);
TORCH_CHECK(chunk_starts.size(0) == chunk_cnts.size(0));
uint32_t n_rays = chunk_cnts.size(0);
int64_t n_edges = inputs.size(0);
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_blocks = 65535;
dim3 threads = dim3(16, 32);
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));
torch::Tensor grad_inputs = torch::empty_like(grad_outputs);
chunk_starts = n_edges - (chunk_starts + chunk_cnts);
device::inclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
n_rays,
thrust::make_reverse_iterator(chunk_starts.data_ptr<int64_t>() + n_rays),
thrust::make_reverse_iterator(chunk_cnts.data_ptr<int64_t>() + n_rays),
0.f,
std::plus<float>(),
false);
// FIXME: the grad is not correct when inputs are zero!!
grad_inputs = grad_inputs / inputs.clamp_min(1e-10f);
cudaGetLastError();
return grad_inputs;
}
torch::Tensor exclusive_prod_forward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs)
{
DEVICE_GUARD(inputs);
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
CHECK_INPUT(inputs);
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(inputs.ndimension() == 1);
TORCH_CHECK(chunk_starts.size(0) == chunk_cnts.size(0));
uint32_t n_rays = chunk_cnts.size(0);
int64_t n_edges = inputs.size(0);
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_blocks = 65535;
dim3 threads = dim3(16, 32);
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));
torch::Tensor outputs = torch::empty_like(inputs);
device::exclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
outputs.data_ptr<float>(),
inputs.data_ptr<float>(),
n_rays,
chunk_starts.data_ptr<int64_t>(),
chunk_cnts.data_ptr<int64_t>(),
1.f,
std::multiplies<float>(),
false);
cudaGetLastError();
return outputs;
}
torch::Tensor exclusive_prod_backward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
torch::Tensor outputs,
torch::Tensor grad_outputs)
{
DEVICE_GUARD(grad_outputs);
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
CHECK_INPUT(grad_outputs);
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(inputs.ndimension() == 1);
TORCH_CHECK(chunk_starts.size(0) == chunk_cnts.size(0));
uint32_t n_rays = chunk_cnts.size(0);
int64_t n_edges = inputs.size(0);
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_blocks = 65535;
dim3 threads = dim3(16, 32);
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));
torch::Tensor grad_inputs = torch::empty_like(grad_outputs);
chunk_starts = n_edges - (chunk_starts + chunk_cnts);
device::exclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
n_rays,
thrust::make_reverse_iterator(chunk_starts.data_ptr<int64_t>() + n_rays),
thrust::make_reverse_iterator(chunk_cnts.data_ptr<int64_t>() + n_rays),
0.f,
std::plus<float>(),
false);
// FIXME: the grad is not correct when inputs are zero!!
grad_inputs = grad_inputs / inputs.clamp_min(1e-10f);
cudaGetLastError();
return grad_inputs;
}
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from dataclasses import dataclass
from typing import Optional
import torch
from . import cuda as _C
@dataclass
class RaySamples:
"""Ray samples that supports batched and flattened data.
Note:
When `vals` is flattened, either `packed_info` or `ray_indices` must
be provided.
Args:
vals: Batched data with shape (n_rays, n_samples) or flattened data
with shape (all_samples,)
packed_info: Optional. A tensor of shape (n_rays, 2) that specifies
the start and count of each chunk in the flattened `vals`, with
n_rays chunks in total. Only needed when `vals` is flattened.
ray_indices: Optional. A tensor of shape (all_samples,) that specifies
the ray index of each sample. Only needed when `vals` is flattened.
Examples:
.. code-block:: python
>>> # Batched data
>>> ray_samples = RaySamples(torch.rand(10, 100))
>>> # Flattened data
>>> ray_samples = RaySamples(
>>> torch.rand(1000),
>>> packed_info=torch.tensor([[0, 100], [100, 200], [300, 700]]),
>>> )
"""
vals: torch.Tensor
packed_info: Optional[torch.Tensor] = None
ray_indices: Optional[torch.Tensor] = None
def _to_cpp(self):
"""
Generate object to pass to C++
"""
spec = _C.RaySegmentsSpec()
spec.vals = self.vals.contiguous()
if self.packed_info is not None:
    spec.chunk_starts = self.packed_info[:, 0].contiguous()
    spec.chunk_cnts = self.packed_info[:, 1].contiguous()
if self.ray_indices is not None:
spec.ray_indices = self.ray_indices.contiguous()
return spec
@classmethod
def _from_cpp(cls, spec):
"""
Generate object from C++
"""
if spec.chunk_starts is not None and spec.chunk_cnts is not None:
packed_info = torch.stack([spec.chunk_starts, spec.chunk_cnts], -1)
else:
packed_info = None
ray_indices = spec.ray_indices
return cls(
vals=spec.vals, packed_info=packed_info, ray_indices=ray_indices
)
@property
def device(self) -> torch.device:
return self.vals.device
@dataclass
class RayIntervals:
"""Ray intervals that supports batched and flattened data.
Each interval is defined by two edges (left and right). The attribute `vals`
stores the edges of all intervals along the rays. The attributes `is_left`
and `is_right` are for indicating whether each edge is a left or right edge.
This class unifies the representation of both continuous and non-continuous ray
intervals.
Note:
When `vals` is flattened, either `packed_info` or `ray_indices` must
be provided. Also both `is_left` and `is_right` must be provided.
Args:
vals: Batched data with shape (n_rays, n_edges) or flattened data
with shape (all_edges,)
packed_info: Optional. A tensor of shape (n_rays, 2) that specifies
the start and count of each chunk in the flattened `vals`, with
n_rays chunks in total. Only needed when `vals` is flattened.
ray_indices: Optional. A tensor of shape (all_edges,) that specifies
the ray index of each edge. Only needed when `vals` is flattened.
is_left: Optional. A boolean tensor of shape (all_edges,) that specifies
whether each edge is a left edge. Only needed when `vals` is flattened.
is_right: Optional. A boolean tensor of shape (all_edges,) that specifies
whether each edge is a right edge. Only needed when `vals` is flattened.
Examples:
.. code-block:: python
>>> # Batched data
>>> ray_intervals = RayIntervals(torch.rand(10, 100))
>>> # Flattened data
>>> ray_intervals = RayIntervals(
>>> torch.rand(6),
>>> packed_info=torch.tensor([[0, 2], [2, 0], [2, 4]]),
>>> is_left=torch.tensor([True, False, True, True, True, False]),
>>> is_right=torch.tensor([False, True, False, True, True, True]),
>>> )
"""
vals: torch.Tensor
packed_info: Optional[torch.Tensor] = None
ray_indices: Optional[torch.Tensor] = None
is_left: Optional[torch.Tensor] = None
is_right: Optional[torch.Tensor] = None
def _to_cpp(self):
"""
Generate object to pass to C++
"""
spec = _C.RaySegmentsSpec()
spec.vals = self.vals.contiguous()
if self.packed_info is not None:
    spec.chunk_starts = self.packed_info[:, 0].contiguous()
    spec.chunk_cnts = self.packed_info[:, 1].contiguous()
if self.ray_indices is not None:
spec.ray_indices = self.ray_indices.contiguous()
if self.is_left is not None:
spec.is_left = self.is_left.contiguous()
if self.is_right is not None:
spec.is_right = self.is_right.contiguous()
return spec
@classmethod
def _from_cpp(cls, spec):
"""
Generate object from C++
"""
if spec.chunk_starts is not None and spec.chunk_cnts is not None:
packed_info = torch.stack([spec.chunk_starts, spec.chunk_cnts], -1)
else:
packed_info = None
ray_indices = spec.ray_indices
is_left = spec.is_left
is_right = spec.is_right
return cls(
vals=spec.vals,
packed_info=packed_info,
ray_indices=ray_indices,
is_left=is_left,
is_right=is_right,
)
@property
def device(self) -> torch.device:
return self.vals.device
from typing import Any
import torch
import torch.nn as nn
class AbstractEstimator(nn.Module):
"""An abstract Transmittance Estimator class for Sampling."""
def __init__(self) -> None:
super().__init__()
self.register_buffer("_dummy", torch.empty(0), persistent=False)
@property
def device(self) -> torch.device:
return self._dummy.device
def sampling(self, *args, **kwargs) -> Any:
raise NotImplementedError
def update_every_n_steps(self, *args, **kwargs) -> None:
raise NotImplementedError
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor
from ..grid import _enlarge_aabb, traverse_grids
from ..volrend import (
render_visibility_from_alpha,
render_visibility_from_density,
)
from .base import AbstractEstimator
class OccGridEstimator(AbstractEstimator):
"""Occupancy grid transmittance estimator for spatial skipping.
References: "Instant Neural Graphics Primitives."
Args:
roi_aabb: The axis-aligned bounding box of the region of interest. Useful for mapping
the 3D space to the grid.
resolution: The resolution of the grid. If an integer is given, the grid is assumed to
be a cube. Otherwise, a list or a tensor of shape (3,) is expected. Default: 128.
levels: The number of levels of the grid. Default: 1.
"""
DIM: int = 3
def __init__(
self,
roi_aabb: Union[List[int], Tensor],
resolution: Union[int, List[int], Tensor] = 128,
levels: int = 1,
**kwargs,
) -> None:
super().__init__()
if "contraction_type" in kwargs:
raise ValueError(
"`contraction_type` is not supported anymore for nerfacc >= 0.4.0."
)
# check the resolution is legal
if isinstance(resolution, int):
resolution = [resolution] * self.DIM
if isinstance(resolution, (list, tuple)):
resolution = torch.tensor(resolution, dtype=torch.int32)
assert isinstance(resolution, Tensor), f"Invalid type: {resolution}!"
assert resolution.shape[0] == self.DIM, f"Invalid shape: {resolution}!"
# check the roi_aabb is legal
if isinstance(roi_aabb, (list, tuple)):
roi_aabb = torch.tensor(roi_aabb, dtype=torch.float32)
assert isinstance(roi_aabb, Tensor), f"Invalid type: {roi_aabb}!"
assert roi_aabb.shape[0] == self.DIM * 2, f"Invalid shape: {roi_aabb}!"
# multiple levels of aabbs
aabbs = torch.stack(
[_enlarge_aabb(roi_aabb, 2**i) for i in range(levels)], dim=0
)
# total number of voxels
self.cells_per_lvl = int(resolution.prod().item())
self.levels = levels
# Buffers
self.register_buffer("resolution", resolution) # [3]
self.register_buffer("aabbs", aabbs) # [n_aabbs, 6]
self.register_buffer(
"occs", torch.zeros(self.levels * self.cells_per_lvl)
)
self.register_buffer(
"binaries",
torch.zeros([levels] + resolution.tolist(), dtype=torch.bool),
)
# Grid coords & indices
grid_coords = _meshgrid3d(resolution).reshape(
self.cells_per_lvl, self.DIM
)
self.register_buffer("grid_coords", grid_coords, persistent=False)
grid_indices = torch.arange(self.cells_per_lvl)
self.register_buffer("grid_indices", grid_indices, persistent=False)
@torch.no_grad()
def sampling(
self,
# rays
rays_o: Tensor, # [n_rays, 3]
rays_d: Tensor, # [n_rays, 3]
# sigma/alpha function for skipping invisible space
sigma_fn: Optional[Callable] = None,
alpha_fn: Optional[Callable] = None,
near_plane: float = 0.0,
far_plane: float = 1e10,
# rendering options
render_step_size: float = 1e-3,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
stratified: bool = False,
cone_angle: float = 0.0,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Sampling with spatial skipping.
Note:
This function is not differentiable to any inputs.
Args:
rays_o: Ray origins of shape (n_rays, 3).
rays_d: Normalized ray directions of shape (n_rays, 3).
sigma_fn: Optional. If provided, the marching will skip the invisible space
by evaluating the density along the ray with `sigma_fn`. It should be a
function that takes in samples {t_starts (N,), t_ends (N,),
ray indices (N,)} and returns the post-activation density values (N,).
You should only provide either `sigma_fn` or `alpha_fn`.
alpha_fn: Optional. If provided, the marching will skip the invisible space
by evaluating the opacity along the ray with `alpha_fn`. It should be a
function that takes in samples {t_starts (N,), t_ends (N,),
ray indices (N,)} and returns the post-activation opacity values (N,).
You should only provide either `sigma_fn` or `alpha_fn`.
near_plane: Optional. Near plane distance. Default: 0.0.
far_plane: Optional. Far plane distance. Default: 1e10.
render_step_size: Step size for marching. Default: 1e-3.
early_stop_eps: Early stop threshold for skipping invisible space. Default: 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
stratified: Whether to use stratified sampling. Default: False.
cone_angle: Cone angle for linearly-increased step size. 0. means
constant step size. Default: 0.0.
Returns:
A tuple of {LongTensor, Tensor, Tensor}:
- **ray_indices**: Ray index of each sample. LongTensor with shape (n_samples,).
- **t_starts**: Per-sample start distance. Tensor with shape (n_samples,).
- **t_ends**: Per-sample end distance. Tensor with shape (n_samples,).
Examples:
.. code-block:: python
>>> ray_indices, t_starts, t_ends = grid.sampling(
>>> rays_o, rays_d, render_step_size=1e-3)
>>> t_mid = (t_starts + t_ends) / 2.0
>>> sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
"""
near_planes = torch.full_like(rays_o[..., 0], fill_value=near_plane)
far_planes = torch.full_like(rays_o[..., 0], fill_value=far_plane)
if stratified:
near_planes += torch.rand_like(near_planes) * render_step_size
intervals, samples = traverse_grids(
rays_o,
rays_d,
self.binaries,
self.aabbs,
near_planes=near_planes,
far_planes=far_planes,
step_size=render_step_size,
cone_angle=cone_angle,
)
t_starts = intervals.vals[intervals.is_left]
t_ends = intervals.vals[intervals.is_right]
ray_indices = samples.ray_indices
packed_info = samples.packed_info
# skip invisible space
if (alpha_thre > 0.0 or early_stop_eps > 0.0) and (
sigma_fn is not None or alpha_fn is not None
):
alpha_thre = min(alpha_thre, self.occs.mean().item())
# Compute visibility of the samples, and filter out invisible samples
if sigma_fn is not None:
sigmas = sigma_fn(t_starts, t_ends, ray_indices)
assert (
sigmas.shape == t_starts.shape
), "sigmas must have shape of (N,)! Got {}".format(sigmas.shape)
masks = render_visibility_from_density(
t_starts=t_starts,
t_ends=t_ends,
sigmas=sigmas,
packed_info=packed_info,
early_stop_eps=early_stop_eps,
alpha_thre=alpha_thre,
)
elif alpha_fn is not None:
alphas = alpha_fn(t_starts, t_ends, ray_indices)
assert (
alphas.shape == t_starts.shape
), "alphas must have shape of (N,)! Got {}".format(alphas.shape)
masks = render_visibility_from_alpha(
alphas=alphas,
packed_info=packed_info,
early_stop_eps=early_stop_eps,
alpha_thre=alpha_thre,
)
ray_indices, t_starts, t_ends = (
ray_indices[masks],
t_starts[masks],
t_ends[masks],
)
return ray_indices, t_starts, t_ends
@torch.no_grad()
def update_every_n_steps(
self,
step: int,
occ_eval_fn: Callable,
occ_thre: float = 1e-2,
ema_decay: float = 0.95,
warmup_steps: int = 256,
n: int = 16,
) -> None:
"""Update the estimator every n steps during training.
Args:
step: Current training step.
occ_eval_fn: A function that takes in sample locations :math:`(N, 3)` and
returns the occupancy values :math:`(N, 1)` at those locations.
occ_thre: Threshold used to binarize the occupancy grid. Default: 1e-2.
ema_decay: The decay rate for EMA updates. Default: 0.95.
warmup_steps: Sample all cells during the warmup stage. After the warmup
stage we change the sampling strategy to 1/4 uniformly sampled cells
together with 1/4 occupied cells. Default: 256.
n: Update the grid every n steps. Default: 16.
"""
if not self.training:
raise RuntimeError(
"You should only call this function only during training. "
"Please call _update() directly if you want to update the "
"field during inference."
)
if step % n == 0 and self.training:
self._update(
step=step,
occ_eval_fn=occ_eval_fn,
occ_thre=occ_thre,
ema_decay=ema_decay,
warmup_steps=warmup_steps,
)
@torch.no_grad()
def _get_all_cells(self) -> List[Tensor]:
"""Returns all cells of the grid."""
return [self.grid_indices] * self.levels
@torch.no_grad()
def _sample_uniform_and_occupied_cells(self, n: int) -> List[Tensor]:
"""Samples both n uniform and occupied cells."""
lvl_indices = []
for lvl in range(self.levels):
uniform_indices = torch.randint(
self.cells_per_lvl, (n,), device=self.device
)
occupied_indices = torch.nonzero(self.binaries[lvl].flatten())[:, 0]
if n < len(occupied_indices):
selector = torch.randint(
len(occupied_indices), (n,), device=self.device
)
occupied_indices = occupied_indices[selector]
indices = torch.cat([uniform_indices, occupied_indices], dim=0)
lvl_indices.append(indices)
return lvl_indices
@torch.no_grad()
def _update(
self,
step: int,
occ_eval_fn: Callable,
occ_thre: float = 0.01,
ema_decay: float = 0.95,
warmup_steps: int = 256,
) -> None:
"""Update the occ field in the EMA way."""
# sample cells
if step < warmup_steps:
lvl_indices = self._get_all_cells()
else:
N = self.cells_per_lvl // 4
lvl_indices = self._sample_uniform_and_occupied_cells(N)
for lvl, indices in enumerate(lvl_indices):
# infer occupancy: density * step_size
grid_coords = self.grid_coords[indices]
x = (
grid_coords + torch.rand_like(grid_coords, dtype=torch.float32)
) / self.resolution
# voxel coordinates [0, 1]^3 -> world
x = self.aabbs[lvl, :3] + x * (
self.aabbs[lvl, 3:] - self.aabbs[lvl, :3]
)
occ = occ_eval_fn(x).squeeze(-1)
# ema update
cell_ids = lvl * self.cells_per_lvl + indices
self.occs[cell_ids] = torch.maximum(
self.occs[cell_ids] * ema_decay, occ
)
# supposed to use scatter_max, but empirically it is almost the same.
# self.occs, _ = scatter_max(
# occ, indices, dim=0, out=self.occs * ema_decay
# )
self.binaries = (
self.occs > torch.clamp(self.occs.mean(), max=occ_thre)
).view(self.binaries.shape)
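# A minimal usage sketch (illustrative only, not part of the library API):
# `radiance_field` below is a placeholder for the user's own model, assumed to
# expose a `query_density(x)` method returning per-point densities.
def _example_occ_grid_training_step(estimator, radiance_field, rays_o, rays_d, step):
    """Illustrative only: one training step with occupancy-grid skipping."""

    def occ_eval_fn(x):
        # occupancy ~= density * step size (shape (N, 1) per the docstring above)
        return radiance_field.query_density(x) * 1e-3

    # keep the occupancy grid up to date (internally updates every n steps);
    # assumes the estimator is in training mode.
    estimator.update_every_n_steps(step=step, occ_eval_fn=occ_eval_fn)

    def sigma_fn(t_starts, t_ends, ray_indices):
        t_mid = (t_starts + t_ends) / 2.0
        x = rays_o[ray_indices] + t_mid[:, None] * rays_d[ray_indices]
        # must return shape (N,)
        return radiance_field.query_density(x).squeeze(-1)

    # march rays, skipping cells marked as empty in the grid
    ray_indices, t_starts, t_ends = estimator.sampling(
        rays_o, rays_d, sigma_fn=sigma_fn, render_step_size=1e-3
    )
    return ray_indices, t_starts, t_ends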
def _meshgrid3d(
res: Tensor, device: Union[torch.device, str] = "cpu"
) -> Tensor:
"""Create 3D grid coordinates."""
assert len(res) == 3
res = res.tolist()
return torch.stack(
torch.meshgrid(
[
torch.arange(res[0], dtype=torch.long),
torch.arange(res[1], dtype=torch.long),
torch.arange(res[2], dtype=torch.long),
],
indexing="ij",
),
dim=-1,
).to(device)
from typing import Callable, List, Literal, Optional, Tuple
import torch
from torch import Tensor
from ..data_specs import RayIntervals
from ..pdf import importance_sampling, searchsorted
from ..volrend import render_transmittance_from_density
from .base import AbstractEstimator
class PropNetEstimator(AbstractEstimator):
"""Proposal network transmittance estimator.
References: "Mip-NeRF 360: Unbounded Anti-Aliased Neural Radiance Fields."
Args:
optimizer: The optimizer to use for the proposal networks.
scheduler: The learning rate scheduler to use for the proposal networks.
"""
def __init__(
self,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
super().__init__()
self.optimizer = optimizer
self.scheduler = scheduler
self.prop_cache: List = []
@torch.no_grad()
def sampling(
self,
prop_sigma_fns: List[Callable],
prop_samples: List[int],
num_samples: int,
# rendering options
n_rays: int,
near_plane: float,
far_plane: float,
sampling_type: Literal["uniform", "lindisp"] = "lindisp",
# training options
stratified: bool = False,
requires_grad: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Sampling with CDFs from proposal networks.
Note:
When `requires_grad` is `True`, the gradients are allowed to flow
through the proposal networks, and the outputs of the proposal
networks are cached to update them later when calling `update_every_n_steps()`
Args:
prop_sigma_fns: Proposal network evaluate functions. It should be a list
of functions that take in samples {t_starts (n_rays, n_samples),
t_ends (n_rays, n_samples)} and returns the post-activation densities
(n_rays, n_samples).
prop_samples: Number of samples to draw from each proposal network. Should
be the same length as `prop_sigma_fns`.
num_samples: Number of samples to draw in the end.
n_rays: Number of rays.
near_plane: Near plane.
far_plane: Far plane.
sampling_type: Sampling type. Either "uniform" or "lindisp". Default to
"lindisp".
stratified: Whether to use stratified sampling. Default to `False`.
requires_grad: Whether to allow gradients to flow through the proposal
networks. Default to `False`.
Returns:
A tuple of {Tensor, Tensor}:
- **t_starts**: The starts of the samples. Shape (n_rays, num_samples).
- **t_ends**: The ends of the samples. Shape (n_rays, num_samples).
"""
assert len(prop_sigma_fns) == len(prop_samples), (
"The number of proposal networks and the number of samples "
"should be the same."
)
cdfs = torch.cat(
[
torch.zeros((n_rays, 1), device=self.device),
torch.ones((n_rays, 1), device=self.device),
],
dim=-1,
)
intervals = RayIntervals(vals=cdfs)
for level_fn, level_samples in zip(prop_sigma_fns, prop_samples):
intervals, _ = importance_sampling(
intervals, cdfs, level_samples, stratified
)
t_vals = _transform_stot(
sampling_type, intervals.vals, near_plane, far_plane
)
t_starts = t_vals[..., :-1]
t_ends = t_vals[..., 1:]
with torch.set_grad_enabled(requires_grad):
sigmas = level_fn(t_starts, t_ends)
assert sigmas.shape == t_starts.shape
trans, _ = render_transmittance_from_density(
t_starts, t_ends, sigmas
)
cdfs = 1.0 - torch.cat(
[trans, torch.zeros_like(trans[:, :1])], dim=-1
)
if requires_grad:
self.prop_cache.append((intervals, cdfs))
intervals, _ = importance_sampling(
intervals, cdfs, num_samples, stratified
)
t_vals = _transform_stot(
sampling_type, intervals.vals, near_plane, far_plane
)
t_starts = t_vals[..., :-1]
t_ends = t_vals[..., 1:]
if requires_grad:
self.prop_cache.append((intervals, None))
return t_starts, t_ends
@torch.enable_grad()
def compute_loss(self, trans: Tensor, loss_scaler: float = 1.0) -> Tensor:
"""Compute the loss for the proposal networks.
Args:
trans: The transmittance of all samples. Shape (n_rays, num_samples).
loss_scaler: The loss scaler. Default to 1.0.
Returns:
The loss for the proposal networks.
"""
if len(self.prop_cache) == 0:
return torch.zeros((), device=self.device)
intervals, _ = self.prop_cache.pop()
# get cdfs at all edges of intervals
cdfs = 1.0 - torch.cat([trans, torch.zeros_like(trans[:, :1])], dim=-1)
cdfs = cdfs.detach()
loss = 0.0
while self.prop_cache:
prop_intervals, prop_cdfs = self.prop_cache.pop()
loss += _pdf_loss(intervals, cdfs, prop_intervals, prop_cdfs).mean()
return loss * loss_scaler
@torch.enable_grad()
def update_every_n_steps(
self,
trans: Tensor,
requires_grad: bool = False,
loss_scaler: float = 1.0,
) -> float:
"""Update the estimator every n steps during training.
Args:
trans: The transmittance of all samples. Shape (n_rays, num_samples).
requires_grad: Whether to allow gradients to flow through the proposal
networks. Default to `False`.
loss_scaler: The loss scaler to use. Default to 1.0.
Returns:
The loss of the proposal networks for logging (a float scalar).
"""
if requires_grad:
return self._update(trans=trans, loss_scaler=loss_scaler)
else:
if self.scheduler is not None:
self.scheduler.step()
return 0.0
@torch.enable_grad()
def _update(self, trans: Tensor, loss_scaler: float = 1.0) -> float:
assert len(self.prop_cache) > 0
assert self.optimizer is not None, "No optimizer is provided."
loss = self.compute_loss(trans, loss_scaler)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
if self.scheduler is not None:
self.scheduler.step()
return loss.item()
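# A minimal usage sketch (illustrative only, not part of the library API):
# `prop_net` below is a placeholder for a user-defined proposal density network
# mapping (n_rays, n_samples) sample positions to densities of the same shape.
def _example_propnet_sampling(estimator, prop_net, n_rays, requires_grad):
    """Illustrative only: proposal sampling plus the deferred update.

    `requires_grad` would typically come from a closure created once via
    `get_proposal_requires_grad_fn()` and queried with the current step.
    """

    def prop_sigma_fn(t_starts, t_ends):
        # densities at segment midpoints, shape (n_rays, n_samples)
        return prop_net((t_starts + t_ends) / 2.0)

    t_starts, t_ends = estimator.sampling(
        prop_sigma_fns=[prop_sigma_fn],
        prop_samples=[64],
        num_samples=32,
        n_rays=n_rays,
        near_plane=0.2,
        far_plane=1e3,
        requires_grad=requires_grad,
    )
    # ...render the main field to get per-sample transmittance `trans`, then:
    # estimator.update_every_n_steps(trans, requires_grad, loss_scaler=1.0)
    return t_starts, t_ends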
def get_proposal_requires_grad_fn(
target: float = 5.0, num_steps: int = 1000
) -> Callable:
schedule = lambda s: min(s / num_steps, 1.0) * target
steps_since_last_grad = 0
def proposal_requires_grad_fn(step: int) -> bool:
nonlocal steps_since_last_grad
target_steps_since_last_grad = schedule(step)
requires_grad = steps_since_last_grad > target_steps_since_last_grad
if requires_grad:
steps_since_last_grad = 0
steps_since_last_grad += 1
return requires_grad
return proposal_requires_grad_fn
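# Note on cadence: with the defaults (target=5.0, num_steps=1000) the returned
# closure asks for gradients on nearly every step early in training (while the
# scheduled threshold is below 1) and settles to roughly once every
# `target + 1` steps once `step >= num_steps`.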
def _transform_stot(
transform_type: Literal["uniform", "lindisp"],
s_vals: torch.Tensor,
t_min: torch.Tensor,
t_max: torch.Tensor,
) -> torch.Tensor:
if transform_type == "uniform":
_contract_fn, _icontract_fn = lambda x: x, lambda x: x
elif transform_type == "lindisp":
_contract_fn, _icontract_fn = lambda x: 1 / x, lambda x: 1 / x
else:
raise ValueError(f"Unknown transform_type: {transform_type}")
s_min, s_max = _contract_fn(t_min), _contract_fn(t_max)
icontract_fn = lambda s: _icontract_fn(s * s_max + (1 - s) * s_min)
return icontract_fn(s_vals)
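# Worked example for `_transform_stot`: with t_min=0.1, t_max=100.0 and s=0.5,
#   "uniform": 0.5 * 100.0 + 0.5 * 0.1              = 50.05
#   "lindisp": 1.0 / (0.5 / 100.0 + 0.5 / 0.1)      ~= 0.1998
# i.e. "lindisp" spaces samples uniformly in inverse depth (disparity), which
# concentrates them near the camera, as used by Mip-NeRF 360 for unbounded scenes.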
def _pdf_loss(
segments_query: RayIntervals,
cdfs_query: torch.Tensor,
segments_key: RayIntervals,
cdfs_key: torch.Tensor,
eps: float = 1e-7,
) -> torch.Tensor:
ids_left, ids_right = searchsorted(segments_key, segments_query)
if segments_query.vals.dim() > 1:
w = cdfs_query[..., 1:] - cdfs_query[..., :-1]
ids_left = ids_left[..., :-1]
ids_right = ids_right[..., 1:]
else:
# TODO: not tested for this branch.
assert segments_query.is_left is not None
assert segments_query.is_right is not None
w = (
cdfs_query[segments_query.is_right]
- cdfs_query[segments_query.is_left]
)
ids_left = ids_left[segments_query.is_left]
ids_right = ids_right[segments_query.is_right]
w_outer = cdfs_key.gather(-1, ids_right) - cdfs_key.gather(-1, ids_left)
return torch.clip(w - w_outer, min=0) ** 2 / (w + eps)
def _outer(
t0_starts: torch.Tensor,
t0_ends: torch.Tensor,
t1_starts: torch.Tensor,
t1_ends: torch.Tensor,
y1: torch.Tensor,
) -> torch.Tensor:
"""
Args:
t0_starts: (..., S0).
t0_ends: (..., S0).
t1_starts: (..., S1).
t1_ends: (..., S1).
y1: (..., S1).
"""
cy1 = torch.cat(
[torch.zeros_like(y1[..., :1]), torch.cumsum(y1, dim=-1)], dim=-1
)
idx_lo = (
torch.searchsorted(
t1_starts.contiguous(), t0_starts.contiguous(), side="right"
)
- 1
)
idx_lo = torch.clamp(idx_lo, min=0, max=y1.shape[-1] - 1)
idx_hi = torch.searchsorted(
t1_ends.contiguous(), t0_ends.contiguous(), side="right"
)
idx_hi = torch.clamp(idx_hi, min=0, max=y1.shape[-1] - 1)
cy1_lo = torch.take_along_dim(cy1[..., :-1], idx_lo, dim=-1)
cy1_hi = torch.take_along_dim(cy1[..., 1:], idx_hi, dim=-1)
y0_outer = cy1_hi - cy1_lo
return y0_outer
def _lossfun_outer(
t: torch.Tensor,
w: torch.Tensor,
t_env: torch.Tensor,
w_env: torch.Tensor,
):
"""
Args:
t: interval edges, (..., S + 1).
w: weights, (..., S).
t_env: interval edges of the upper-bound enveloping histogram, (..., S + 1).
w_env: weights that should upper bound the inner (t,w) histogram, (..., S).
"""
eps = torch.finfo(t.dtype).eps
w_outer = _outer(
t[..., :-1], t[..., 1:], t_env[..., :-1], t_env[..., 1:], w_env
)
return torch.clip(w - w_outer, min=0) ** 2 / (w + eps)
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Callable, List, Union
from typing import Optional, Tuple
import torch
import torch.nn as nn
import nerfacc.cuda as _C
from torch import Tensor
from .contraction import ContractionType, contract_inv
# TODO: check torch.scatter_reduce_
# from torch_scatter import scatter_max
from . import cuda as _C
from .data_specs import RayIntervals, RaySamples
@torch.no_grad()
def query_grid(
    samples: torch.Tensor,
    grid_roi: torch.Tensor,
    grid_values: torch.Tensor,
    grid_type: ContractionType,
):
    """Query grid values given coordinates.
    Args:
        samples: (n_samples, 3) tensor of coordinates.
        grid_roi: (6,) region of interest of the grid. Usually it should be
            acquired from the grid itself using `grid.roi_aabb`.
        grid_values: A 4D tensor of grid values in the shape of (nlvl, resx, resy, resz).
        grid_type: Contraction type of the grid. Usually it should be
            acquired from the grid itself using `grid.contraction_type`.
    Returns:
        (n_samples,) values for those samples queried from the grid.
    """
    assert samples.dim() == 2 and samples.size(-1) == 3
    assert grid_roi.dim() == 1 and grid_roi.size(0) == 6
    assert grid_values.dim() == 4
    assert isinstance(grid_type, ContractionType)
    return _C.grid_query(
        samples.contiguous(),
        grid_roi.contiguous(),
        grid_values.contiguous(),
        grid_type.to_cpp_version(),
    )
def ray_aabb_intersect(
    rays_o: Tensor,
    rays_d: Tensor,
    aabbs: Tensor,
    near_plane: float = -float("inf"),
    far_plane: float = float("inf"),
    miss_value: float = float("inf"),
) -> Tuple[Tensor, Tensor, Tensor]:
    """Ray-AABB intersection.
    Args:
        rays_o: (n_rays, 3) Ray origins.
        rays_d: (n_rays, 3) Normalized ray directions.
        aabbs: (m, 6) Axis-aligned bounding boxes {xmin, ymin, zmin, xmax, ymax, zmax}.
        near_plane: Optional. Near plane. Default to -infinity.
        far_plane: Optional. Far plane. Default to infinity.
        miss_value: Optional. Value to use for tmin and tmax when there is no intersection.
            Default to infinity.
    Returns:
        t_mins: (n_rays, m) tmin for each ray-AABB pair.
        t_maxs: (n_rays, m) tmax for each ray-AABB pair.
        hits: (n_rays, m) whether each ray-AABB pair intersects.
    """
    assert rays_o.ndim == 2 and rays_o.shape[-1] == 3
    assert rays_d.ndim == 2 and rays_d.shape[-1] == 3
    assert aabbs.ndim == 2 and aabbs.shape[-1] == 6
    t_mins, t_maxs, hits = _C.ray_aabb_intersect(
        rays_o.contiguous(),
        rays_d.contiguous(),
        aabbs.contiguous(),
        near_plane,
        far_plane,
        miss_value,
    )
    return t_mins, t_maxs, hits
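# Example (a sketch; `rays_o` / `rays_d` are placeholders for a (n_rays, 3) ray batch):
#   aabbs = torch.tensor([[-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]], device=rays_o.device)
#   t_mins, t_maxs, hits = ray_aabb_intersect(rays_o, rays_d, aabbs)
#   # t_mins / t_maxs / hits have shape (n_rays, 1); misses get `miss_value`.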
def _ray_aabb_intersect(
    rays_o: Tensor,
    rays_d: Tensor,
    aabbs: Tensor,
    near_plane: float = -float("inf"),
    far_plane: float = float("inf"),
    miss_value: float = float("inf"),
) -> Tuple[Tensor, Tensor, Tensor]:
    """Ray-AABB intersection.

    Functionally the same as `ray_aabb_intersect()`, but slower because it is pure PyTorch.
    """
    # Compute the minimum and maximum bounds of the AABBs
    aabb_min = aabbs[:, :3]
    aabb_max = aabbs[:, 3:]

    # Compute the intersection distances between the ray and each of the six AABB planes
    t1 = (aabb_min[None, :, :] - rays_o[:, None, :]) / rays_d[:, None, :]
    t2 = (aabb_max[None, :, :] - rays_o[:, None, :]) / rays_d[:, None, :]

    # Compute the maximum tmin and minimum tmax for each AABB
    t_mins = torch.max(torch.min(t1, t2), dim=-1)[0]
    t_maxs = torch.min(torch.max(t1, t2), dim=-1)[0]

    # Compute whether each ray-AABB pair intersects
    hits = (t_maxs > t_mins) & (t_maxs > 0)

    # Clip the tmin and tmax values to the near and far planes
    t_mins = torch.clamp(t_mins, min=near_plane, max=far_plane)
    t_maxs = torch.clamp(t_maxs, min=near_plane, max=far_plane)

    # Set the tmin and tmax values to miss_value if there is no intersection
    t_mins = torch.where(hits, t_mins, miss_value)
    t_maxs = torch.where(hits, t_maxs, miss_value)

    return t_mins, t_maxs, hits


class Grid(nn.Module):
    """An abstract Grid class.

    The grid is used as a cache of the 3D space to indicate whether each voxel
    area is important or not for the differentiable rendering process. The
    ray marching function (see :func:`nerfacc.ray_marching`) would use the
    grid to skip the unimportant voxel areas.

    To work with :func:`nerfacc.ray_marching`, three attributes must exist:

    - :attr:`roi_aabb`: The axis-aligned bounding box of the region of interest.
    - :attr:`binary`: A 4D binarized tensor of shape {nlvl, resx, resy, resz}, \
        with torch.bool data type.
    - :attr:`contraction_type`: The contraction type of the grid, indicating how \
        the 3D space is mapped to the grid.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.register_buffer("_dummy", torch.empty(0), persistent=False)

    @property
    def device(self) -> torch.device:
        return self._dummy.device

    @property
    def roi_aabb(self) -> torch.Tensor:
        """The axis-aligned bounding box of the region of interest.

        It is a shape (6,) tensor in the format of {minx, miny, minz, maxx, maxy, maxz}.
        """
        if hasattr(self, "_roi_aabb"):
            return getattr(self, "_roi_aabb")
        else:
            raise NotImplementedError("please set an attribute named _roi_aabb")

    @property
    def binary(self) -> torch.Tensor:
        """A 4-dim binarized tensor with torch.bool data type.

        The tensor is of shape (nlvl, resx, resy, resz), in which each boolean value
        represents whether the corresponding voxel should be kept or not.
        """
        if hasattr(self, "_binary"):
            return getattr(self, "_binary")
        else:
            raise NotImplementedError("please set an attribute named _binary")

    @property
    def contraction_type(self) -> ContractionType:
        """The contraction type of the grid.

        The contraction type is an indicator of how the 3D space is contracted
        to this voxel grid. See :class:`nerfacc.ContractionType` for more details.
        """
        if hasattr(self, "_contraction_type"):
            return getattr(self, "_contraction_type")
        else:
            raise NotImplementedError(
                "please set an attribute named _contraction_type"
            )
@torch.no_grad()
def traverse_grids(
    # rays
    rays_o: Tensor,  # [n_rays, 3]
    rays_d: Tensor,  # [n_rays, 3]
    # grids
    binaries: Tensor,  # [m, resx, resy, resz]
    aabbs: Tensor,  # [m, 6]
    # options
    near_planes: Optional[Tensor] = None,  # [n_rays]
    far_planes: Optional[Tensor] = None,  # [n_rays]
    step_size: Optional[float] = 1e-3,
    cone_angle: Optional[float] = 0.0,
) -> Tuple[RayIntervals, RaySamples]:
    """Ray Traversal within Multiple Grids.

    Note:
        This function is not differentiable to any inputs.

    Args:
        rays_o: (n_rays, 3) Ray origins.
        rays_d: (n_rays, 3) Normalized ray directions.
        binaries: (m, resx, resy, resz) Multiple binary grids with the same resolution.
        aabbs: (m, 6) Axis-aligned bounding boxes {xmin, ymin, zmin, xmax, ymax, zmax}.
        near_planes: Optional. (n_rays,) Near planes for the traversal to start. Default to 0.
        far_planes: Optional. (n_rays,) Far planes for the traversal to end. Default to infinity.
        step_size: Optional. Step size for ray traversal. Default to 1e-3.
        cone_angle: Optional. Cone angle for linearly-increased step size. 0. means
            constant step size. Default: 0.0.

    Returns:
        A :class:`RayIntervals` object containing the intervals of the ray traversal, and
        a :class:`RaySamples` object containing the samples within each interval.
    """
    # Compute ray-AABB intersection for all levels of the grid. [n_rays, m]
    t_mins, t_maxs, hits = ray_aabb_intersect(rays_o, rays_d, aabbs)

    if near_planes is None:
        near_planes = torch.zeros_like(rays_o[:, 0])
    if far_planes is None:
        far_planes = torch.full_like(rays_o[:, 0], float("inf"))

    intervals, samples = _C.traverse_grids(
        # rays
        rays_o.contiguous(),  # [n_rays, 3]
        rays_d.contiguous(),  # [n_rays, 3]
        # grids
        binaries.contiguous(),  # [m, resx, resy, resz]
        aabbs.contiguous(),  # [m, 6]
        # intersections
        t_mins.contiguous(),  # [n_rays, m]
        t_maxs.contiguous(),  # [n_rays, m]
        hits.contiguous(),  # [n_rays, m]
        # options
        near_planes.contiguous(),  # [n_rays]
        far_planes.contiguous(),  # [n_rays]
        step_size,
        cone_angle,
        True,
        True,
    )
    return RayIntervals._from_cpp(intervals), RaySamples._from_cpp(samples)
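

# Illustrative sketch (not part of the library): traversing a single,
# fully-occupied binary grid. Requires a CUDA build of the extension; the grid
# and rays are made up for illustration.
def _demo_traverse_grids():
    device = "cuda:0"
    rays_o = torch.rand((128, 3), device=device)
    rays_d = torch.randn((128, 3), device=device)
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
    binaries = torch.ones((1, 64, 64, 64), dtype=torch.bool, device=device)
    aabbs = torch.tensor([[0.0, 0.0, 0.0, 1.0, 1.0, 1.0]], device=device)
    intervals, samples = traverse_grids(
        rays_o, rays_d, binaries, aabbs, step_size=1e-2
    )
    # samples.vals holds the flattened per-sample distances along each ray;
    # samples.ray_indices maps each sample back to its ray.
    print(samples.vals.shape, samples.ray_indices.shape)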
def _enlarge_aabb(aabb, factor: float) -> Tensor:
    center = (aabb[:3] + aabb[3:]) / 2
    extent = (aabb[3:] - aabb[:3]) / 2
    return torch.cat([center - extent * factor, center + extent * factor])


class OccupancyGrid(Grid):
    """Occupancy grid: whether each voxel area is occupied or not.
Args:
roi_aabb: The axis-aligned bounding box of the region of interest. Useful for mapping
the 3D space to the grid.
resolution: The resolution of the grid. If an integer is given, the grid is assumed to
be a cube. Otherwise, a list or a tensor of shape (3,) is expected. Default: 128.
contraction_type: The contraction type of the grid. See :class:`nerfacc.ContractionType`
for more details. Default: :attr:`nerfacc.ContractionType.AABB`.
levels: The number of levels of the grid. Default: 1.
"""
NUM_DIM: int = 3
def __init__(
self,
roi_aabb: Union[List[int], torch.Tensor],
resolution: Union[int, List[int], torch.Tensor] = 128,
contraction_type: ContractionType = ContractionType.AABB,
levels: int = 1,
) -> None:
super().__init__()
if isinstance(resolution, int):
resolution = [resolution] * self.NUM_DIM
if isinstance(resolution, (list, tuple)):
resolution = torch.tensor(resolution, dtype=torch.int32)
assert isinstance(
resolution, torch.Tensor
), f"Invalid type: {type(resolution)}"
assert resolution.shape == (
self.NUM_DIM,
), f"Invalid shape: {resolution.shape}"
if isinstance(roi_aabb, (list, tuple)):
roi_aabb = torch.tensor(roi_aabb, dtype=torch.float32)
assert isinstance(
roi_aabb, torch.Tensor
), f"Invalid type: {type(roi_aabb)}"
assert roi_aabb.shape == torch.Size(
[self.NUM_DIM * 2]
), f"Invalid shape: {roi_aabb.shape}"
if levels > 1:
assert (
contraction_type == ContractionType.AABB
), "For multi-res occupancy grid, contraction is not supported yet."
# total number of voxels
self.num_cells_per_lvl = int(resolution.prod().item())
self.levels = levels
# required attributes
self.register_buffer("_roi_aabb", roi_aabb)
self.register_buffer(
"_binary",
torch.zeros([levels] + resolution.tolist(), dtype=torch.bool),
)
self._contraction_type = contraction_type
# helper attributes
self.register_buffer("resolution", resolution)
self.register_buffer(
"occs", torch.zeros(self.levels * self.num_cells_per_lvl)
)
# Grid coords & indices
grid_coords = _meshgrid3d(resolution).reshape(
self.num_cells_per_lvl, self.NUM_DIM
)
self.register_buffer("grid_coords", grid_coords, persistent=False)
grid_indices = torch.arange(self.num_cells_per_lvl)
self.register_buffer("grid_indices", grid_indices, persistent=False)
@torch.no_grad()
def _get_all_cells(self) -> List[torch.Tensor]:
"""Returns all cells of the grid."""
return [self.grid_indices] * self.levels
@torch.no_grad()
def _sample_uniform_and_occupied_cells(self, n: int) -> List[torch.Tensor]:
"""Samples both n uniform and occupied cells."""
lvl_indices = []
for lvl in range(self.levels):
uniform_indices = torch.randint(
self.num_cells_per_lvl, (n,), device=self.device
)
occupied_indices = torch.nonzero(self._binary[lvl].flatten())[:, 0]
if n < len(occupied_indices):
selector = torch.randint(
len(occupied_indices), (n,), device=self.device
)
occupied_indices = occupied_indices[selector]
indices = torch.cat([uniform_indices, occupied_indices], dim=0)
lvl_indices.append(indices)
return lvl_indices
@torch.no_grad()
def _update(
self,
step: int,
occ_eval_fn: Callable,
occ_thre: float = 0.01,
ema_decay: float = 0.95,
warmup_steps: int = 256,
) -> None:
"""Update the occ field in the EMA way."""
# sample cells
if step < warmup_steps:
lvl_indices = self._get_all_cells()
else:
N = self.num_cells_per_lvl // 4
lvl_indices = self._sample_uniform_and_occupied_cells(N)
for lvl, indices in enumerate(lvl_indices):
# infer occupancy: density * step_size
grid_coords = self.grid_coords[indices]
x = (
grid_coords + torch.rand_like(grid_coords, dtype=torch.float32)
) / self.resolution
if self._contraction_type == ContractionType.UN_BOUNDED_SPHERE:
# only the points inside the sphere are valid
mask = (x - 0.5).norm(dim=1) < 0.5
x = x[mask]
indices = indices[mask]
# voxel coordinates [0, 1]^3 -> world
x = contract_inv(
(x - 0.5) * (2**lvl) + 0.5,
roi=self._roi_aabb,
type=self._contraction_type,
)
occ = occ_eval_fn(x).squeeze(-1)
# ema update
cell_ids = lvl * self.num_cells_per_lvl + indices
self.occs[cell_ids] = torch.maximum(
self.occs[cell_ids] * ema_decay, occ
)
        # supposed to use scatter_max, but empirically it is almost the same.
# self.occs, _ = scatter_max(
# occ, indices, dim=0, out=self.occs * ema_decay
# )
self._binary = (
self.occs > torch.clamp(self.occs.mean(), max=occ_thre)
).view(self._binary.shape)
@torch.no_grad()
def every_n_step(
self,
step: int,
occ_eval_fn: Callable,
occ_thre: float = 1e-2,
ema_decay: float = 0.95,
warmup_steps: int = 256,
n: int = 16,
) -> None:
"""Update the grid every n steps during training.
Args:
step: Current training step.
occ_eval_fn: A function that takes in sample locations :math:`(N, 3)` and
returns the occupancy values :math:`(N, 1)` at those locations.
occ_thre: Threshold used to binarize the occupancy grid. Default: 1e-2.
ema_decay: The decay rate for EMA updates. Default: 0.95.
warmup_steps: Sample all cells during the warmup stage. After the warmup
stage we change the sampling strategy to 1/4 uniformly sampled cells
together with 1/4 occupied cells. Default: 256.
n: Update the grid every n steps. Default: 16.
"""
if not self.training:
raise RuntimeError(
"You should only call this function only during training. "
"Please call _update() directly if you want to update the "
"field during inference."
)
if step % n == 0 and self.training:
self._update(
step=step,
occ_eval_fn=occ_eval_fn,
occ_thre=occ_thre,
ema_decay=ema_decay,
warmup_steps=warmup_steps,
)
@torch.no_grad()
def query_occ(self, samples: torch.Tensor) -> torch.Tensor:
"""Query the occupancy field at the given samples.
Args:
samples: Samples in the world coordinates. (n_samples, 3)
Returns:
Occupancy values at the given samples. (n_samples,)
"""
return query_grid(
samples,
self._roi_aabb,
self.binary,
self.contraction_type,
)
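

# Illustrative sketch (not part of the library): driving the occupancy grid
# during training. A dummy density function stands in for a real radiance
# field, and rendering/optimization are omitted; requires CUDA.
def _demo_occupancy_grid(max_steps: int = 32, render_step_size: float = 5e-3):
    grid = OccupancyGrid(
        roi_aabb=[0.0, 0.0, 0.0, 1.0, 1.0, 1.0], resolution=64, levels=1
    ).cuda()

    def occ_eval_fn(x):
        # Stand-in for a real density query (e.g. a radiance field):
        # a soft ball of density centered in the box, times the step size.
        density = torch.exp(-((x - 0.5) ** 2).sum(dim=-1, keepdim=True) * 10.0)
        return density * render_step_size

    grid.train()
    for step in range(max_steps):
        # Updates the occupancy field every 16 steps (the default `n`).
        grid.every_n_step(step=step, occ_eval_fn=occ_eval_fn)
    print(grid.binary.float().mean())  # fraction of voxels marked occupied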
def _meshgrid3d(
res: torch.Tensor, device: Union[torch.device, str] = "cpu"
) -> torch.Tensor:
"""Create 3D grid coordinates."""
assert len(res) == 3
res = res.tolist()
return torch.stack(
torch.meshgrid(
[
torch.arange(res[0], dtype=torch.long),
torch.arange(res[1], dtype=torch.long),
torch.arange(res[2], dtype=torch.long),
],
indexing="ij",
),
dim=-1,
).to(device)
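

# Illustrative sketch (not part of the library): `_meshgrid3d` simply
# enumerates integer voxel coordinates.
def _demo_meshgrid3d():
    coords = _meshgrid3d(torch.tensor([2, 2, 2]))
    print(coords.shape)  # torch.Size([2, 2, 2, 3])
    print(coords.reshape(-1, 3)[:3])  # tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0]])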
def _query(x: Tensor, data: Tensor, base_aabb: Tensor) -> Tensor:
    """
    Query the grid values at the given points.

    This function assumes the aabbs of multiple grids are 2x scaled.

    Args:
        x: (N, 3) tensor of points to query.
        data: (m, resx, resy, resz) tensor of grid values
        base_aabb: (6,) aabb of base level grid.
    """
    # normalize so that the base_aabb is [0, 1]^3
    aabb_min, aabb_max = torch.split(base_aabb, 3, dim=0)
    x_norm = (x - aabb_min) / (aabb_max - aabb_min)

    # if maxval is almost zero, it will trigger frexpf to output 0
    # for exponent, which is not what we want.
    maxval = (x_norm - 0.5).abs().max(dim=-1).values
    maxval = torch.clamp(maxval, min=0.1)

    # compute the mip level
    exponent = torch.frexp(maxval)[1].long()
    mip = torch.clamp(exponent + 1, min=0)
    selector = mip < data.shape[0]

    # use the mip to re-normalize all points to [0, 1].
    scale = 2**mip
    x_unit = (x_norm - 0.5) / scale[:, None] + 0.5

    # map to the grid index
    resolution = torch.tensor(data.shape[1:], device=x.device)
    ix = (x_unit * resolution).long()
    ix = torch.clamp(ix, max=resolution - 1)
    mip = torch.clamp(mip, max=data.shape[0] - 1)

    return data[mip, ix[:, 0], ix[:, 1], ix[:, 2]] * selector, selector
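

# Illustrative sketch (not part of the library): points farther from the base
# aabb fall into coarser mip levels in `_query` above. The grid data is random
# and the points are made up for illustration.
def _demo_query_mip():
    base_aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
    data = torch.rand(3, 16, 16, 16)  # 3 levels, each covering a 2x wider box
    x = torch.tensor(
        [
            [0.5, 0.5, 0.5],  # inside the base aabb -> level 0
            [1.4, 0.5, 0.5],  # one doubling away    -> level 1
            [2.4, 0.5, 0.5],  # two doublings away   -> level 2
        ]
    )
    vals, selector = _query(x, data, base_aabb)
    print(vals.shape, selector)  # torch.Size([3]) tensor([True, True, True])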
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Tuple
import torch
from torch import Tensor
import nerfacc.cuda as _C
@torch.no_grad()
def ray_aabb_intersect(
rays_o: Tensor, rays_d: Tensor, aabb: Tensor
) -> Tuple[Tensor, Tensor]:
"""Ray AABB Test.
Note:
this function is not differentiable to any inputs.
Args:
rays_o: Ray origins of shape (n_rays, 3).
rays_d: Normalized ray directions of shape (n_rays, 3).
        aabb: Scene bounding box {xmin, ymin, zmin, xmax, ymax, zmax}. \
            Tensor with shape (6,).
    Returns:
        Ray AABB intersection {t_min, t_max}, each with shape (n_rays,). \
            Note that t_min is clipped to a minimum of zero. 1e10 means no intersection.
Examples:
.. code-block:: python
aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], device="cuda:0")
rays_o = torch.rand((128, 3), device="cuda:0")
rays_d = torch.randn((128, 3), device="cuda:0")
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
t_min, t_max = ray_aabb_intersect(rays_o, rays_d, aabb)
"""
if rays_o.is_cuda and rays_d.is_cuda and aabb.is_cuda:
rays_o = rays_o.contiguous()
rays_d = rays_d.contiguous()
aabb = aabb.contiguous()
t_min, t_max = _C.ray_aabb_intersect(rays_o, rays_d, aabb)
else:
raise NotImplementedError("Only support cuda inputs.")
return t_min, t_max
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Optional, Tuple
import torch
from torch import Tensor
import nerfacc.cuda as _C
def pack_data(data: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
    """Pack per-ray data (n_rays, n_samples, D) to (all_samples, D) based on mask.

    Args:
        data: Tensor with shape (n_rays, n_samples, D).
        mask: Boolean tensor with shape (n_rays, n_samples).

    Returns:
        Tuple of Tensors including packed data (all_samples, D), \
        and packed_info (n_rays, 2) which stores the start index of the sample, \
        and the number of samples kept for each ray.

    Example:

    .. code-block:: python

        data = torch.rand((10, 3, 4), device="cuda:0")
        mask = torch.rand((10, 3), device="cuda:0") > 0.5
        packed_data, packed_info = pack_data(data, mask)
        print(packed_data.shape, packed_info.shape)

    """
    assert data.dim() == 3, "data must be with shape of (n_rays, n_samples, D)."
    assert (
        mask.shape == data.shape[:2]
    ), "mask must be with shape of (n_rays, n_samples)."
    assert mask.dtype == torch.bool, "mask must be a boolean tensor."
    packed_data = data[mask]
    num_steps = mask.sum(dim=-1, dtype=torch.int32)
    cum_steps = num_steps.cumsum(dim=0, dtype=torch.int32)
    packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1)
    return packed_data, packed_info


@torch.no_grad()
def pack_info(ray_indices: Tensor, n_rays: Optional[int] = None) -> Tensor:
    """Pack `ray_indices` to `packed_info`. Useful for converting per sample data to per ray data.

    Note:
        this function is not differentiable to any inputs.

    Args:
        ray_indices: Ray index of each sample. LongTensor with shape (n_sample).
        n_rays: Number of rays. If None, it is inferred from `ray_indices`. Default is None.

    Returns:
        packed_info: A Tensor of shape (n_rays, 2) that specifies the start index
        and the number of samples of each chunk in the flattened input tensor,
        with in total n_rays chunks. See :func:`nerfacc.ray_marching` for details.

    Example:

    .. code-block:: python

        >>> ray_indices = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2], device="cuda")
        >>> packed_info = pack_info(ray_indices, n_rays=3)
        >>> packed_info
        tensor([[0, 2], [2, 3], [5, 4]], device='cuda:0')

    """
    assert (
        ray_indices.dim() == 1
    ), "ray_indices must be a 1D tensor with shape (n_samples)."
    if ray_indices.is_cuda:
        device = ray_indices.device
        dtype = ray_indices.dtype
        if n_rays is None:
            n_rays = int(ray_indices.max()) + 1
        # else:
        #     assert n_rays > ray_indices.max()
        chunk_cnts = torch.zeros((n_rays,), device=device, dtype=dtype)
        chunk_cnts.index_add_(0, ray_indices, torch.ones_like(ray_indices))
        chunk_starts = chunk_cnts.cumsum(dim=0, dtype=dtype) - chunk_cnts
        packed_info = torch.stack([chunk_starts, chunk_cnts], dim=-1)
        # An equivalent variant using scatter_add_ with int32 counters:
        # src = torch.ones_like(ray_indices, dtype=torch.int)
        # num_steps = torch.zeros((n_rays,), device=device, dtype=torch.int)
        # num_steps.scatter_add_(0, ray_indices, src)
        # cum_steps = num_steps.cumsum(dim=0, dtype=torch.int)
        # packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1)
    else:
        raise NotImplementedError("Only support cuda inputs.")
    return packed_info
@torch.no_grad()
def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor:
"""Unpack `packed_info` to `ray_indices`. Useful for converting per ray data to per sample data.
Note:
this function is not differentiable to any inputs.
Args:
packed_info: Stores information on which samples belong to the same ray. \
See :func:`nerfacc.ray_marching` for details. IntTensor with shape (n_rays, 2).
n_samples: Total number of samples.
Returns:
Ray index of each sample. LongTensor with shape (n_sample).
Examples:
.. code-block:: python
rays_o = torch.rand((128, 3), device="cuda:0")
rays_d = torch.randn((128, 3), device="cuda:0")
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
# Ray marching with near far plane.
packed_info, t_starts, t_ends = ray_marching(
rays_o, rays_d, near_plane=0.1, far_plane=1.0, render_step_size=1e-3
)
# torch.Size([128, 2]) torch.Size([115200, 1]) torch.Size([115200, 1])
print(packed_info.shape, t_starts.shape, t_ends.shape)
# Unpack per-ray info to per-sample info.
ray_indices = unpack_info(packed_info, t_starts.shape[0])
# torch.Size([115200]) torch.int64
print(ray_indices.shape, ray_indices.dtype)
"""
assert (
packed_info.dim() == 2 and packed_info.shape[-1] == 2
), "packed_info must be a 2D tensor with shape (n_rays, 2)."
if packed_info.is_cuda:
ray_indices = _C.unpack_info(packed_info.contiguous(), n_samples)
else:
raise NotImplementedError("Only support cuda inputs.")
return ray_indices
def unpack_data(
packed_info: Tensor,
data: Tensor,
n_samples: Optional[int] = None,
pad_value: float = 0.0,
) -> Tensor:
"""Unpack packed data (all_samples, D) to per-ray data (n_rays, n_samples, D).
Args:
packed_info (Tensor): Stores information on which samples belong to the same ray. \
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
data: Packed data to unpack. Tensor with shape (n_samples, D).
        n_samples (int): Optional. Number of samples per ray. If not provided, it \
will be inferred from the packed_info.
pad_value: Value to pad the unpacked data.
Returns:
Unpacked data (n_rays, n_samples, D).
Examples:
.. code-block:: python
rays_o = torch.rand((128, 3), device="cuda:0")
rays_d = torch.randn((128, 3), device="cuda:0")
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
# Ray marching with aabb.
scene_aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], device="cuda:0")
packed_info, t_starts, t_ends = ray_marching(
rays_o, rays_d, scene_aabb=scene_aabb, render_step_size=1e-2
)
print(t_starts.shape) # torch.Size([all_samples, 1])
t_starts = unpack_data(packed_info, t_starts, n_samples=1024)
print(t_starts.shape) # torch.Size([128, 1024, 1])
"""
assert (
packed_info.dim() == 2 and packed_info.shape[-1] == 2
), "packed_info must be a 2D tensor with shape (n_rays, 2)."
assert (
data.dim() == 2
), "data must be a 2D tensor with shape (n_samples, D)."
if n_samples is None:
n_samples = packed_info[:, 1].max().item()
return _UnpackData.apply(packed_info, data, n_samples, pad_value)
class _UnpackData(torch.autograd.Function):
"""Unpack packed data (all_samples, D) to per-ray data (n_rays, n_samples, D)."""
@staticmethod
def forward(
ctx,
packed_info: Tensor,
data: Tensor,
n_samples: int,
pad_value: float = 0.0,
) -> Tensor:
# shape of the data should be (all_samples, D)
packed_info = packed_info.contiguous()
data = data.contiguous()
if ctx.needs_input_grad[1]:
ctx.save_for_backward(packed_info)
ctx.n_samples = n_samples
return _C.unpack_data(packed_info, data, n_samples, pad_value)
@staticmethod
def backward(ctx, grad: Tensor):
# shape of the grad should be (n_rays, n_samples, D)
packed_info = ctx.saved_tensors[0]
n_samples = ctx.n_samples
mask = _C.unpack_info_to_mask(packed_info, n_samples)
packed_grad = grad[mask].contiguous()
return None, packed_grad, None, None
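

# Illustrative sketch (not part of the library): a pure-PyTorch reference for
# what `unpack_data` computes. The CUDA kernel above is the real, batched
# implementation; this loop is only meant to document the semantics.
def _unpack_data_reference(
    packed_info: Tensor, data: Tensor, n_samples: int, pad_value: float = 0.0
) -> Tensor:
    n_rays = packed_info.shape[0]
    out = data.new_full((n_rays, n_samples, data.shape[-1]), pad_value)
    for i, (start, cnt) in enumerate(packed_info.tolist()):
        out[i, :cnt] = data[start : start + cnt]
    return out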
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Optional, Tuple, Union

import torch
from torch import Tensor

from . import cuda as _C
from .data_specs import RayIntervals, RaySamples
def searchsorted(
sorted_sequence: Union[RayIntervals, RaySamples],
values: Union[RayIntervals, RaySamples],
) -> Tuple[Tensor, Tensor]:
"""Searchsorted that supports flattened tensor.
This function returns {`ids_left`, `ids_right`} such that:
`sorted_sequence.vals.gather(-1, ids_left) <= values.vals < sorted_sequence.vals.gather(-1, ids_right)`
Note:
        When a value is out of range of `sorted_sequence`, we return the
        corresponding ids as if the value were clipped to the range of
        `sorted_sequence`. See the example below.
Args:
sorted_sequence: A :class:`RayIntervals` or :class:`RaySamples` object. We assume
            the `sorted_sequence.vals` is ascendingly sorted for each ray.
values: A :class:`RayIntervals` or :class:`RaySamples` object.
Returns:
A tuple of LongTensor:
- **ids_left**: A LongTensor with the same shape as `values.vals`.
- **ids_right**: A LongTensor with the same shape as `values.vals`.
Example:
>>> sorted_sequence = RayIntervals(
... vals=torch.tensor([0.0, 1.0, 0.0, 1.0, 2.0], device="cuda"),
... packed_info=torch.tensor([[0, 2], [2, 3]], device="cuda"),
... )
>>> values = RayIntervals(
... vals=torch.tensor([0.5, 1.5, 2.5], device="cuda"),
... packed_info=torch.tensor([[0, 1], [1, 2]], device="cuda"),
... )
>>> ids_left, ids_right = searchsorted(sorted_sequence, values)
>>> ids_left
tensor([0, 3, 3], device='cuda:0')
>>> ids_right
tensor([1, 4, 4], device='cuda:0')
>>> sorted_sequence.vals.gather(-1, ids_left)
tensor([0., 1., 1.], device='cuda:0')
>>> sorted_sequence.vals.gather(-1, ids_right)
tensor([1., 2., 2.], device='cuda:0')
"""
ids_left, ids_right = _C.searchsorted(
values._to_cpp(), sorted_sequence._to_cpp()
)
return ids_left, ids_right
def importance_sampling(
intervals: RayIntervals,
cdfs: Tensor,
n_intervals_per_ray: Union[Tensor, int],
stratified: bool = False,
) -> Tuple[RayIntervals, RaySamples]:
"""Importance sampling that supports flattened tensor.
Given a set of intervals and the corresponding CDFs at the interval edges,
this function performs inverse transform sampling to create a new set of
intervals and samples. Stratified sampling is also supported.
Args:
intervals: A :class:`RayIntervals` object that specifies the edges of the
intervals along the rays.
cdfs: The CDFs at the interval edges. It has the same shape as
`intervals.vals`.
n_intervals_per_ray: Resample each ray to have this many intervals.
If it is a tensor, it must be of shape (n_rays,). If it is an int,
it is broadcasted to all rays.
stratified: If True, perform stratified sampling.
Returns:
A tuple of {:class:`RayIntervals`, :class:`RaySamples`}:
        - **intervals**: A :class:`RayIntervals` object. If `n_intervals_per_ray` is an int, \
            `intervals.vals` will have the shape of (n_rays, n_intervals_per_ray + 1). \
            If `n_intervals_per_ray` is a tensor, we assume each ray results \
            in a different number of intervals. In this case, `intervals.vals` \
            will have the shape of (all_edges,), and the attributes `packed_info`, \
            `ray_indices`, `is_left` and `is_right` will be accessible.
        - **samples**: A :class:`RaySamples` object. If `n_intervals_per_ray` is an int, \
            `samples.vals` will have the shape of (n_rays, n_intervals_per_ray). \
            If `n_intervals_per_ray` is a tensor, we assume each ray results \
            in a different number of intervals. In this case, `samples.vals` \
            will have the shape of (all_samples,), and the attributes `packed_info` and \
            `ray_indices` will be accessible.
Example:
.. code-block:: python
>>> intervals = RayIntervals(
... vals=torch.tensor([0.0, 1.0, 0.0, 1.0, 2.0], device="cuda"),
... packed_info=torch.tensor([[0, 2], [2, 3]], device="cuda"),
... )
>>> cdfs = torch.tensor([0.0, 0.5, 0.0, 0.5, 1.0], device="cuda")
>>> n_intervals_per_ray = 2
>>> intervals, samples = importance_sampling(intervals, cdfs, n_intervals_per_ray)
>>> intervals.vals
tensor([[0.0000, 0.5000, 1.0000],
[0.0000, 1.0000, 2.0000]], device='cuda:0')
>>> samples.vals
tensor([[0.2500, 0.7500],
[0.5000, 1.5000]], device='cuda:0')
"""
if isinstance(n_intervals_per_ray, Tensor):
n_intervals_per_ray = n_intervals_per_ray.contiguous()
intervals, samples = _C.importance_sampling(
intervals._to_cpp(),
cdfs.contiguous(),
n_intervals_per_ray,
stratified,
)
return RayIntervals._from_cpp(intervals), RaySamples._from_cpp(samples)
def _sample_from_weighted(
    bins: Tensor,
    weights: Tensor,
    num_samples: int,
    stratified: bool = False,
    vmin: float = -torch.inf,
    vmax: float = torch.inf,
) -> Tuple[Tensor, Tensor]:
    """
    Args:
        bins: (..., B + 1).
        weights: (..., B).

    Returns:
        samples: (..., S + 1).
    """
    import torch.nn.functional as F

    B = weights.shape[-1]
    S = num_samples
    assert bins.shape[-1] == B + 1
    dtype, device = bins.dtype, bins.device
    eps = torch.finfo(weights.dtype).eps

    # (..., B).
    pdf = F.normalize(weights, p=1, dim=-1)
    # (..., B + 1).
    cdf = torch.cat(
        [
            torch.zeros_like(pdf[..., :1]),
            torch.cumsum(pdf[..., :-1], dim=-1),
            torch.ones_like(pdf[..., :1]),
        ],
        dim=-1,
    )

    # (..., S). Sample positions between [0, 1).
    if not stratified:
        pad = 1 / (2 * S)
        # Get the center of each pdf bins.
        u = torch.linspace(pad, 1 - pad - eps, S, dtype=dtype, device=device)
        u = u.broadcast_to(bins.shape[:-1] + (S,))
    else:
        # `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u_max = eps + (1 - eps) / S
        max_jitter = (1 - u_max) / (S - 1) - eps
        # Only perform one jittering per ray (`single_jitter` in the original
        # implementation.)
        u = (
            torch.linspace(0, 1 - u_max, S, dtype=dtype, device=device)
            + torch.rand(
                *bins.shape[:-1],
                1,
                dtype=dtype,
                device=device,
            )
            * max_jitter
        )

    # (..., S).
    ceil = torch.searchsorted(cdf.contiguous(), u.contiguous(), side="right")
    floor = ceil - 1
    # (..., S * 2).
    inds = torch.cat([floor, ceil], dim=-1)

    # (..., S).
    cdf0, cdf1 = cdf.gather(-1, inds).split(S, dim=-1)
    b0, b1 = bins.gather(-1, inds).split(S, dim=-1)

    # (..., S). Linear interpolation in 1D.
    t = (u - cdf0) / torch.clamp(cdf1 - cdf0, min=eps)
    # Sample centers.
    centers = b0 + t * (b1 - b0)

    samples = (centers[..., 1:] + centers[..., :-1]) / 2
    samples = torch.cat(
        [
            (2 * centers[..., :1] - samples[..., :1]).clamp_min(vmin),
            samples,
            (2 * centers[..., -1:] - samples[..., -1:]).clamp_max(vmax),
        ],
        dim=-1,
    )

    return samples, centers


class PDFOuter(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        ts: Tensor,
        weights: Tensor,
        masks: Optional[Tensor],
        ts_query: Tensor,
        masks_query: Optional[Tensor],
    ):
        assert ts.dim() == weights.dim() == ts_query.dim() == 2
        assert ts.shape[0] == weights.shape[0] == ts_query.shape[0]
        assert ts.shape[1] == weights.shape[1] + 1
        ts = ts.contiguous()
        weights = weights.contiguous()
        ts_query = ts_query.contiguous()
        masks = masks.contiguous() if masks is not None else None
        masks_query = (
            masks_query.contiguous() if masks_query is not None else None
        )
        weights_query = _C.pdf_readout(
            ts, weights, masks, ts_query, masks_query
        )
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(ts, masks, ts_query, masks_query)
        return weights_query

    @staticmethod
    def backward(ctx, weights_query_grads: Tensor):
        weights_query_grads = weights_query_grads.contiguous()
        ts, masks, ts_query, masks_query = ctx.saved_tensors
        weights_grads = _C.pdf_readout(
            ts_query, weights_query_grads, masks_query, ts, masks
        )
        return None, weights_grads, None, None, None


pdf_outer = PDFOuter.apply


@torch.no_grad()
def pdf_sampling(
    t: torch.Tensor,
    weights: torch.Tensor,
    n_samples: int,
    padding: float = 0.01,
    stratified: bool = False,
    single_jitter: bool = False,
    masks: Optional[torch.Tensor] = None,
):
    assert t.shape[0] == weights.shape[0]
    assert t.shape[1] == weights.shape[1] + 1
    if masks is not None:
        assert t.shape[0] == masks.shape[0]
    t_new = _C.pdf_sampling(
        t.contiguous(),
        weights.contiguous(),
        n_samples + 1,  # be careful here!
        padding,
        stratified,
        single_jitter,
        masks.contiguous() if masks is not None else None,
    )
    return t_new  # [n_ray, n_samples+1]
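

# Illustrative sketch (not part of the library): inverse-CDF resampling with
# `_sample_from_weighted` above. Most of the weight sits in the middle bin, so
# most samples land there; the numbers are made up for illustration.
def _demo_sample_from_weighted():
    bins = torch.tensor([[0.0, 1.0, 2.0, 3.0]])
    weights = torch.tensor([[0.1, 0.8, 0.1]])
    samples, centers = _sample_from_weighted(
        bins, weights, num_samples=4, stratified=False
    )
    print(samples.shape, centers.shape)  # torch.Size([1, 5]) torch.Size([1, 4])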