Unverified commit 8340e19d, authored by Ruilong Li (李瑞龙) and committed by GitHub

0.5.0: Rewrite all the underlying CUDA. Speedup and Benchmarking. (#182)

* importance_sampling with test

* package importance_sampling

* compute_intervals tested and packaged

* compute_intervals_v2

* bicycle is failing

* fix cut in compute_intervals_v2, test pass for rendering

* hacky way to get opaque_bkgd work

* reorg ING

* PackedRaySegmentsSpec

* chunk_ids -> ray_ids

* binary -> occupied

* test_traverse_grid_basic checked

* fix traverse_grid with step size, checked

* support max_step_size, not verified

* _cuda and cuda; upgrade ray_marching

* inclusive scan

* test_exclusive_sum but seems to have numeric error

* inclusive_sum_backward verified

* exclusive sum backward

* merge fwd and bwd for scan

* inclusive & exclusive prod verified

* support normal scan with torch funcs

* rendering and tests

* a bit clean up

* importance_sampling verified

* stratified for importance_sampling

* importance_sampling in pdf.py

* RaySegmentsSpec in data_specs; fix various bugs

* verified with _proposal_packed.py

* importance sampling support batch input/output. need to verify

* prop script with batch samples

* try to use cumsum  instead of cumprod

* searchsorted

* benchmarking prop

* ray_aabb_intersect untested

* update prop benchmark numbers

* minor fixes

* batched ray_aabb_intersect

* ray_aabb_intersect and traverse with grid(s)

* tiny optimize for traverse_grids kernels

* traverse_grids return intervals and samples

* cub not verified

* cleanup

* propnet and occgrid as estimators

* training print iters 10k

* prop is good now

* benchmark in google sheet.

* really cleanup: scan.py and test

* pack.py and test

* rendering and test

* data_specs.py and pdf.py docs

* data_specs.py and pdf.py docs

* init and headers

* grid.py and test for it

* occ grid docs

* generated docs

* example docs for pack and scan function.

* doc fix for volrend.py

* doc fix for pdf.py

* fix doc for rendering function

* docs

* propnet docs

* update scripts

* docs: index.rst

* methodology docs

* docs for examples

* mlp nerf script

* update t-nerf script

* rename dnerf to tnerf

* misc update

* bug fix: pdf_loss with test

* minor fix

* update readme with submodules

* fix format

* update gitignore file

* fix doc failure. teaser png to jpg

* docs in examples/
parent e547490c
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*
* Modified from
* https://github.com/pytorch/pytorch/blob/06a64f7eaa47ce430a3fa61016010075b59b18a7/aten/src/ATen/native/cuda/ScanUtils.cuh
*/
#include "utils_cuda.cuh"
// CUB support for scan by key was added in CUB 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif
// https://github.com/pytorch/pytorch/blob/233305a852e1cd7f319b15b5137074c9eac455f6/aten/src/ATen/cuda/cub.cuh#L38-L46
#define CUB_WRAPPER(func, ...) do { \
size_t temp_storage_bytes = 0; \
func(nullptr, temp_storage_bytes, __VA_ARGS__); \
auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \
auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \
func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \
AT_CUDA_CHECK(cudaGetLastError()); \
} while (false)
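// Illustrative usage (a hedged sketch, not part of the original file): expanding
// CUB_WRAPPER by hand for cub::DeviceScan::InclusiveSum, with placeholder names
// d_in, d_out, n, and stream. The macro implements the standard two-phase CUB
// pattern: a first call with a null temporary buffer that only reports
// temp_storage_bytes, an allocation from PyTorch's caching allocator, and a
// second call that performs the actual scan.
//
//   size_t temp_storage_bytes = 0;
//   cub::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes, d_in, d_out, n, stream);
//   auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get();
//   auto temp_storage = caching_allocator.allocate(temp_storage_bytes);
//   cub::DeviceScan::InclusiveSum(temp_storage.get(), temp_storage_bytes, d_in, d_out, n, stream);
//   AT_CUDA_CHECK(cudaGetLastError());
//
// which is what CUB_WRAPPER(cub::DeviceScan::InclusiveSum, d_in, d_out, n, stream) expands to.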
namespace {
namespace device {
/* Perform an inclusive scan for a flattened tensor.
*
* - num_rows is the size of the outer dimensions;
* - {chunk_starts, chunk_cnts} defines the regions of the flattened tensor to be scanned.
*
* Each thread block processes one or more sets of contiguous rows (processing multiple rows
* per thread block is quicker than processing a single row, especially for short rows).
*/
template<
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__device__ void inclusive_scan_impl(
T* row_buf, DataIteratorT tgt_, DataIteratorT src_,
const uint32_t num_rows,
// const uint32_t row_size,
IdxIteratorT chunk_starts, IdxIteratorT chunk_cnts,
T init, BinaryFunction binary_op,
bool normalize = false){
for (uint32_t block_row = blockIdx.x * blockDim.y;
block_row < num_rows;
block_row += blockDim.y * gridDim.x) {
uint32_t row = block_row + threadIdx.y;
T block_total = init;
if (row >= num_rows) continue;
DataIteratorT row_src = src_ + chunk_starts[row];
DataIteratorT row_tgt = tgt_ + chunk_starts[row];
uint32_t row_size = chunk_cnts[row];
if (row_size == 0) continue;
// Perform scan on one block at a time, keeping track of the total value of
// all blocks processed so far.
for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
// Load data into shared memory (two values per thread).
uint32_t col1 = block_col + threadIdx.x;
uint32_t col2 = block_col + num_threads_x + threadIdx.x;
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = row_src[col1];
} else {
row_buf[threadIdx.x] = init;
}
if (col2 < row_size) {
row_buf[num_threads_x + threadIdx.x] = row_src[col2];
} else {
row_buf[num_threads_x + threadIdx.x] = init;
}
// Add the total value of all previous blocks to the first value of this block.
if (threadIdx.x == 0) {
row_buf[0] = binary_op(row_buf[0], block_total);
}
}
__syncthreads();
// Parallel reduction (up-sweep).
for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < num_rows && threadIdx.x < s) {
uint32_t offset = (2 * threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Down-sweep.
for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < num_rows && threadIdx.x < s - 1) {
uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Write back to output.
if (row < num_rows) {
if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];
if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];
}
block_total = row_buf[2 * num_threads_x - 1];
__syncthreads();
}
// Normalize with the last value: should only be used by scan_sum
if (normalize) {
for (uint32_t block_col = 0; block_col < row_size; block_col += num_threads_x)
{
uint32_t col = block_col + threadIdx.x;
if (col < row_size) {
row_tgt[col] /= fmaxf(block_total, 1e-10f);
}
}
}
}
}
template <
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__global__ void
inclusive_scan_kernel(
DataIteratorT tgt_,
DataIteratorT src_,
const uint32_t num_rows,
IdxIteratorT chunk_starts,
IdxIteratorT chunk_cnts,
T init,
BinaryFunction binary_op,
bool normalize = false) {
__shared__ T sbuf[num_threads_y][2 * num_threads_x];
T* row_buf = sbuf[threadIdx.y];
inclusive_scan_impl<T, num_threads_x, num_threads_y>(
row_buf, tgt_, src_, num_rows, chunk_starts, chunk_cnts, init, binary_op, normalize);
}
/* Perform an exclusive scan for a flattened tensor.
*
* - num_rows is the size of the outer dimensions;
* - {chunk_starts, chunk_cnts} defines the regions of the flattened tensor to be scanned.
*
* Each thread block processes one or more sets of contiguous rows (processing multiple rows
* per thread block is quicker than processing a single row, especially for short rows).
*/
template<
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__device__ void exclusive_scan_impl(
T* row_buf, DataIteratorT tgt_, DataIteratorT src_,
const uint32_t num_rows,
// const uint32_t row_size,
IdxIteratorT chunk_starts, IdxIteratorT chunk_cnts,
T init, BinaryFunction binary_op,
bool normalize = false){
for (uint32_t block_row = blockIdx.x * blockDim.y;
block_row < num_rows;
block_row += blockDim.y * gridDim.x) {
uint32_t row = block_row + threadIdx.y;
T block_total = init;
if (row >= num_rows) continue;
DataIteratorT row_src = src_ + chunk_starts[row];
DataIteratorT row_tgt = tgt_ + chunk_starts[row];
uint32_t row_size = chunk_cnts[row];
if (row_size == 0) continue;
row_tgt[0] = init;
// Perform scan on one block at a time, keeping track of the total value of
// all blocks processed so far.
for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
// Load data into shared memory (two values per thread).
uint32_t col1 = block_col + threadIdx.x;
uint32_t col2 = block_col + num_threads_x + threadIdx.x;
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = row_src[col1];
} else {
row_buf[threadIdx.x] = init;
}
if (col2 < row_size) {
row_buf[num_threads_x + threadIdx.x] = row_src[col2];
} else {
row_buf[num_threads_x + threadIdx.x] = init;
}
// Add the total value of all previous blocks to the first value of this block.
if (threadIdx.x == 0) {
row_buf[0] = binary_op(row_buf[0], block_total);
}
}
__syncthreads();
// Parallel reduction (up-sweep).
for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < num_rows && threadIdx.x < s) {
uint32_t offset = (2 * threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Down-sweep.
for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < num_rows && threadIdx.x < s - 1) {
uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Write back to output.
if (row < num_rows) {
if (col1 < row_size - 1) row_tgt[col1 + 1] = row_buf[threadIdx.x];
if (col2 < row_size - 1) row_tgt[col2 + 1] = row_buf[num_threads_x + threadIdx.x];
}
block_total = row_buf[2 * num_threads_x - 1];
__syncthreads();
}
// Normalize with the last value: should only be used by scan_sum
if (normalize) {
for (uint32_t block_col = 0; block_col < row_size; block_col += num_threads_x)
{
uint32_t col = block_col + threadIdx.x;
if (col < row_size - 1) {
row_tgt[col + 1] /= fmaxf(block_total, 1e-10f);
}
}
}
}
}
template <
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__global__ void
exclusive_scan_kernel(
DataIteratorT tgt_,
DataIteratorT src_,
const uint32_t num_rows,
IdxIteratorT chunk_starts,
IdxIteratorT chunk_cnts,
T init,
BinaryFunction binary_op,
bool normalize = false) {
__shared__ T sbuf[num_threads_y][2 * num_threads_x];
T* row_buf = sbuf[threadIdx.y];
exclusive_scan_impl<T, num_threads_x, num_threads_y>(
row_buf, tgt_, src_, num_rows, chunk_starts, chunk_cnts, init, binary_op, normalize);
}
} // namespace device
} // namespace
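/*
 * Illustrative CPU reference (a hedged sketch, not used by the kernels above):
 * what the chunked inclusive/exclusive sum computes for one flattened tensor.
 * chunk_starts[i] / chunk_cnts[i] delimit row i inside `src`; when `normalize`
 * is true every prefix sum is divided by the row total, mirroring the
 * `normalize` branch of the kernels. All names below are illustrative only.
 */
#include <algorithm>
#include <cstdint>
#include <vector>

inline std::vector<float> chunked_sum_reference(
    const std::vector<float> &src,
    const std::vector<int64_t> &chunk_starts,
    const std::vector<int64_t> &chunk_cnts,
    bool exclusive,
    bool normalize)
{
    std::vector<float> tgt(src.size(), 0.f);
    for (size_t row = 0; row < chunk_cnts.size(); ++row) {
        const int64_t start = chunk_starts[row];
        const int64_t cnt = chunk_cnts[row];
        if (cnt == 0) continue;
        float running = 0.f;
        for (int64_t j = 0; j < cnt; ++j) {
            if (exclusive) tgt[start + j] = running;   // sum of elements before j
            running += src[start + j];
            if (!exclusive) tgt[start + j] = running;  // sum of elements up to j
        }
        if (normalize) {
            const float total = std::max(running, 1e-10f);  // same epsilon as the kernels
            for (int64_t j = 0; j < cnt; ++j) tgt[start + j] /= total;
        }
    }
    return tgt;
}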
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
template <typename scalar_t>
inline __host__ __device__ void _swap(scalar_t &a, scalar_t &b)
{
scalar_t c = a;
a = b;
b = c;
}
template <typename scalar_t>
inline __host__ __device__ void _ray_aabb_intersect(
const scalar_t *rays_o,
const scalar_t *rays_d,
const scalar_t *aabb,
scalar_t *near,
scalar_t *far)
{
// aabb is [xmin, ymin, zmin, xmax, ymax, zmax]
scalar_t tmin = (aabb[0] - rays_o[0]) / rays_d[0];
scalar_t tmax = (aabb[3] - rays_o[0]) / rays_d[0];
if (tmin > tmax)
_swap(tmin, tmax);
scalar_t tymin = (aabb[1] - rays_o[1]) / rays_d[1];
scalar_t tymax = (aabb[4] - rays_o[1]) / rays_d[1];
if (tymin > tymax)
_swap(tymin, tymax);
if (tmin > tymax || tymin > tmax)
{
*near = 1e10;
*far = 1e10;
return;
}
if (tymin > tmin)
tmin = tymin;
if (tymax < tmax)
tmax = tymax;
scalar_t tzmin = (aabb[2] - rays_o[2]) / rays_d[2];
scalar_t tzmax = (aabb[5] - rays_o[2]) / rays_d[2];
if (tzmin > tzmax)
_swap(tzmin, tzmax);
if (tmin > tzmax || tzmin > tmax)
{
*near = 1e10;
*far = 1e10;
return;
}
if (tzmin > tmin)
tmin = tzmin;
if (tzmax < tmax)
tmax = tzmax;
*near = tmin;
*far = tmax;
return;
}
template <typename scalar_t>
__global__ void ray_aabb_intersect_kernel(
const int N,
const scalar_t *rays_o,
const scalar_t *rays_d,
const scalar_t *aabb,
scalar_t *t_min,
scalar_t *t_max)
{
// aabb is [xmin, ymin, zmin, xmax, ymax, zmax]
CUDA_GET_THREAD_ID(thread_id, N);
// locate
rays_o += thread_id * 3;
rays_d += thread_id * 3;
t_min += thread_id;
t_max += thread_id;
_ray_aabb_intersect<scalar_t>(rays_o, rays_d, aabb, t_min, t_max);
scalar_t zero = static_cast<scalar_t>(0.f);
*t_min = *t_min > zero ? *t_min : zero;
return;
}
/**
* @brief Ray AABB Test
*
* @param rays_o Ray origins. Tensor with shape [N, 3].
* @param rays_d Normalized ray directions. Tensor with shape [N, 3].
* @param aabb Scene AABB [xmin, ymin, zmin, xmax, ymax, zmax]. Tensor with shape [6].
* @return std::vector<torch::Tensor>
* Ray AABB intersection {t_min, t_max} with shape [N] respectively. Note the t_min is
* clipped to minimum zero. 1e10 is returned if no intersection.
*/
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o, const torch::Tensor rays_d, const torch::Tensor aabb)
{
DEVICE_GUARD(rays_o);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(aabb);
TORCH_CHECK(rays_o.ndimension() == 2 && rays_o.size(1) == 3);
TORCH_CHECK(rays_d.ndimension() == 2 && rays_d.size(1) == 3);
TORCH_CHECK(aabb.ndimension() == 1 && aabb.size(0) == 6);
const int N = rays_o.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(N, threads);
torch::Tensor t_min = torch::empty({N}, rays_o.options());
torch::Tensor t_max = torch::empty({N}, rays_o.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rays_o.scalar_type(), "ray_aabb_intersect",
([&]
{ ray_aabb_intersect_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
N,
rays_o.data_ptr<scalar_t>(),
rays_d.data_ptr<scalar_t>(),
aabb.data_ptr<scalar_t>(),
t_min.data_ptr<scalar_t>(),
t_max.data_ptr<scalar_t>()); }));
return {t_min, t_max};
}
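/*
 * Host-side usage sketch (hedged, not part of the original file): because
 * _ray_aabb_intersect is __host__ __device__, the slab test above can also be
 * exercised on the CPU. The hypothetical helper below just demonstrates the
 * {near, far} convention documented above.
 */
inline void example_ray_aabb_usage()
{
    const float aabb[6] = {0.f, 0.f, 0.f, 1.f, 1.f, 1.f};  // unit cube
    const float rays_o[3] = {-1.f, 0.5f, 0.5f};
    const float rays_d[3] = {1.f, 0.f, 0.f};                // enters at t=1, exits at t=2
    float near_t, far_t;
    _ray_aabb_intersect<float>(rays_o, rays_d, aabb, &near_t, &far_t);
    // For this ray: near_t == 1.f and far_t == 2.f. A ray whose line misses the
    // box entirely gets near_t == far_t == 1e10, matching the doc comment above,
    // and the kernel additionally clamps the returned t_min to be non-negative.
}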
// This file contains only Python bindings
#include "include/data_spec.hpp"
#include <torch/extension.h>
bool is_cub_available() {
// FIXME: why return false?
return (bool) CUB_SUPPORTS_SCAN_BY_KEY();
}
// scan
torch::Tensor exclusive_sum_by_key(
torch::Tensor indices,
torch::Tensor inputs,
bool backward);
torch::Tensor inclusive_sum(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
bool normalize,
bool backward);
torch::Tensor exclusive_sum(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
bool normalize,
bool backward);
torch::Tensor inclusive_prod_forward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs);
torch::Tensor inclusive_prod_backward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
torch::Tensor outputs,
torch::Tensor grad_outputs);
torch::Tensor exclusive_prod_forward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs);
torch::Tensor exclusive_prod_backward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
torch::Tensor outputs,
torch::Tensor grad_outputs);
// grid
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
const torch::Tensor aabbs, // [n_aabbs, 6]
const float near_plane,
const float far_plane,
const float miss_value);
std::vector<RaySegmentsSpec> traverse_grids(
// rays
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
// grids
const torch::Tensor binaries, // [n_grids, resx, resy, resz]
const torch::Tensor aabbs, // [n_grids, 6]
// intersections
const torch::Tensor t_mins, // [n_rays, n_grids]
const torch::Tensor t_maxs, // [n_rays, n_grids]
const torch::Tensor hits, // [n_rays, n_grids]
// options
const torch::Tensor near_planes,
const torch::Tensor far_planes,
const float step_size,
const float cone_angle,
const bool compute_intervals,
const bool compute_samples);
// pdf
std::vector<RaySegmentsSpec> importance_sampling(
RaySegmentsSpec ray_segments,
torch::Tensor cdfs,
torch::Tensor n_intervels_per_ray,
bool stratified);
std::vector<RaySegmentsSpec> importance_sampling(
RaySegmentsSpec ray_segments,
torch::Tensor cdfs,
int64_t n_intervels_per_ray,
bool stratified);
std::vector<torch::Tensor> searchsorted(
RaySegmentsSpec query,
RaySegmentsSpec key);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#define _REG_FUNC(funname) m.def(#funname, &funname)
_REG_FUNC(is_cub_available); // TODO: check this function
_REG_FUNC(exclusive_sum_by_key);
_REG_FUNC(inclusive_sum);
_REG_FUNC(exclusive_sum);
_REG_FUNC(inclusive_prod_forward);
_REG_FUNC(inclusive_prod_backward);
_REG_FUNC(exclusive_prod_forward);
_REG_FUNC(exclusive_prod_backward);
_REG_FUNC(ray_aabb_intersect);
_REG_FUNC(traverse_grids);
_REG_FUNC(searchsorted);
#undef _REG_FUNC
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, torch::Tensor, bool>(&importance_sampling));
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, int64_t, bool>(&importance_sampling));
py::class_<MultiScaleGridSpec>(m, "MultiScaleGridSpec")
.def(py::init<>())
.def_readwrite("data", &MultiScaleGridSpec::data)
.def_readwrite("occupied", &MultiScaleGridSpec::occupied)
.def_readwrite("base_aabb", &MultiScaleGridSpec::base_aabb);
py::class_<RaysSpec>(m, "RaysSpec")
.def(py::init<>())
.def_readwrite("origins", &RaysSpec::origins)
.def_readwrite("dirs", &RaysSpec::dirs);
py::class_<RaySegmentsSpec>(m, "RaySegmentsSpec")
.def(py::init<>())
.def_readwrite("vals", &RaySegmentsSpec::vals)
.def_readwrite("is_left", &RaySegmentsSpec::is_left)
.def_readwrite("is_right", &RaySegmentsSpec::is_right)
.def_readwrite("chunk_starts", &RaySegmentsSpec::chunk_starts)
.def_readwrite("chunk_cnts", &RaySegmentsSpec::chunk_cnts)
.def_readwrite("ray_indices", &RaySegmentsSpec::ray_indices);
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
__global__ void unpack_info_kernel(
// input
const int n_rays,
const int *packed_info,
// output
int64_t *ray_indices)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
ray_indices += base;
for (int j = 0; j < steps; ++j)
{
ray_indices[j] = i;
}
}
__global__ void unpack_info_to_mask_kernel(
// input
const int n_rays,
const int *packed_info,
const int n_samples,
// output
bool *masks) // [n_rays, n_samples]
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
masks += i * n_samples;
for (int j = 0; j < steps; ++j)
{
masks[j] = true;
}
}
template <typename scalar_t>
__global__ void unpack_data_kernel(
const uint32_t n_rays,
const int *packed_info, // input ray & point indices.
const int data_dim,
const scalar_t *data,
const int n_sampler_per_ray,
scalar_t *unpacked_data) // (n_rays, n_sampler_per_ray, data_dim)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
data += base * data_dim;
unpacked_data += i * n_sampler_per_ray * data_dim;
for (int j = 0; j < steps; j++)
{
for (int k = 0; k < data_dim; k++)
{
unpacked_data[j * data_dim + k] = data[j * data_dim + k];
}
}
return;
}
torch::Tensor unpack_info(const torch::Tensor packed_info, const int n_samples)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
const int n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
torch::Tensor ray_indices = torch::empty(
{n_samples}, packed_info.options().dtype(torch::kLong));
unpack_info_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
ray_indices.data_ptr<int64_t>());
return ray_indices;
}
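/*
 * Illustrative CPU reference (a hedged sketch, not part of the original file):
 * packed_info[i] = {base, steps} records where ray i's samples live in the
 * flattened sample buffer; expanding it into one ray index per sample is
 * exactly what unpack_info_kernel does in parallel. Names are illustrative.
 */
#include <cstdint>
#include <utility>
#include <vector>

inline std::vector<int64_t> unpack_info_reference(
    const std::vector<std::pair<int, int>> &packed_info, int n_samples)
{
    std::vector<int64_t> ray_indices(n_samples);
    for (size_t i = 0; i < packed_info.size(); ++i) {
        const int base = packed_info[i].first;    // index of the first sample of ray i
        const int steps = packed_info[i].second;  // number of samples of ray i
        for (int j = 0; j < steps; ++j)
            ray_indices[base + j] = static_cast<int64_t>(i);
    }
    return ray_indices;
}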
torch::Tensor unpack_info_to_mask(
const torch::Tensor packed_info, const int n_samples)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
const int n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor masks = torch::zeros(
{n_rays, n_samples}, packed_info.options().dtype(torch::kBool));
unpack_info_to_mask_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
n_samples,
masks.data_ptr<bool>());
return masks;
}
torch::Tensor unpack_data(
torch::Tensor packed_info,
torch::Tensor data,
int n_samples_per_ray,
float pad_value)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(data);
TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
TORCH_CHECK(data.ndimension() == 2);
const int n_rays = packed_info.size(0);
const int n_samples = data.size(0);
const int data_dim = data.size(1);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor unpacked_data = torch::full(
{n_rays, n_samples_per_ray, data_dim}, pad_value, data.options());
AT_DISPATCH_ALL_TYPES(
data.scalar_type(),
"unpack_data",
([&]
{ unpack_data_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
data_dim,
data.data_ptr<scalar_t>(),
n_samples_per_ray,
// outputs
unpacked_data.data_ptr<scalar_t>()); }));
return unpacked_data;
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include <torch/extension.h>
#include <ATen/NumericUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include "include/helpers_cuda.h"
#include "include/data_spec.hpp"
#include "include/data_spec_packed.cuh"
#include "include/utils_cuda.cuh"
#include "include/utils_grid.cuh"
#include "include/utils_math.cuh"
namespace F = torch::nn::functional;
static constexpr uint32_t MAX_GRID_LEVELS = 8;
template <typename scalar_t>
inline __device__ __host__ scalar_t ceil_div(scalar_t a, scalar_t b)
{
return (a + b - 1) / b;
}
namespace {
namespace device {
// Taken from:
// https://github.com/pytorch/pytorch/blob/8f1c3c68d3aba5c8898bfb3144988aab6776d549/aten/src/ATen/native/cuda/Bucketization.cu
template <typename scalar_t>
__device__ int64_t upper_bound(
    const scalar_t *data_ss, int64_t start, int64_t end,
    const scalar_t val, const int64_t *data_sort)
{
    // Binary search for the first index p in [start, end) with data_ss[p] > val
    // (the "right" searchsorted convention). data_sort optionally provides an
    // indirect ordering and is passed as nullptr by every caller in this file.
    while (start < end)
    {
        const int64_t mid = start + ((end - start) >> 1);
        const scalar_t mid_val = data_sort ? data_ss[data_sort[mid]] : data_ss[mid];
        if (!(mid_val > val))
            start = mid + 1;
        else
            end = mid;
    }
    return start;
}
template <typename scalar_t>
__global__ void pdf_sampling_kernel(
at::PhiloxCudaState philox_args,
const int64_t n_samples_in, // n_samples_in or whatever (not used)
const int64_t *info_ts, // nullptr or [n_rays, 2]
const scalar_t *ts, // [n_rays, n_samples_in] or packed [all_samples_in]
const scalar_t *accum_weights, // [n_rays, n_samples_in] or packed [all_samples_in]
const bool *masks, // [n_rays]
const bool stratified,
const bool single_jitter,
// outputs
const int64_t numel,
const int64_t n_samples_out,
scalar_t *ts_out) // [n_rays, n_samples_out]
inline __device__ int32_t binary_search_chunk_id(
const int64_t item_id,
const int32_t n_chunks,
const int64_t *chunk_starts)
{
int64_t n_bins_out = n_samples_out - 1;
scalar_t u_pad, u_interval;
if (stratified) {
u_interval = 1.0f / n_bins_out;
u_pad = 0.0f;
} else {
u_interval = 1.0f / n_bins_out;
u_pad = 0.0f;
}
// = stratified ? 1.0f / n_samples_out : (1.0f - 2 * pad) / (n_samples_out - 1);
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < numel; tid += blockDim.x * gridDim.x)
{
int64_t ray_id = tid / n_samples_out;
int64_t sample_id = tid % n_samples_out;
if (masks != nullptr && !masks[ray_id]) {
// This ray is to be skipped.
// Be careful the ts needs to be initialized properly.
continue;
}
int64_t start_bd, end_bd;
if (info_ts == nullptr)
int32_t start = 0;
int32_t end = n_chunks;
while (start < end)
{
// no packing, the input is already [n_rays, n_samples_in]
start_bd = ray_id * n_samples_in;
end_bd = start_bd + n_samples_in;
const int32_t mid = start + ((end - start) >> 1);
const int64_t mid_val = chunk_starts[mid];
if (!(mid_val > item_id)) start = mid + 1;
else end = mid;
}
else
return start;
}
/* kernels for importance_sampling */
__global__ void compute_ray_ids_kernel(
const int64_t n_rays,
const int64_t n_items,
const int64_t *chunk_starts,
// outputs
int64_t *ray_indices)
{
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n_items; tid += blockDim.x * gridDim.x)
{
// packed input, the input is [all_samples_in]
start_bd = info_ts[ray_id * 2];
end_bd = start_bd + info_ts[ray_id * 2 + 1];
if (start_bd == end_bd) {
// This ray is empty, so there is nothing to sample from.
// Be careful the ts needs to be initialized properly.
continue;
}
ray_indices[tid] = binary_search_chunk_id(tid, n_rays, chunk_starts) - 1;
}
}
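/*
 * Illustrative CPU reference (a hedged sketch, not part of the original file):
 * chunk_starts holds the sorted first flat index of every ray, so the ray that
 * owns flat item item_id is the last chunk whose start is <= item_id, i.e.
 * binary_search_chunk_id(item_id, n_chunks, chunk_starts) - 1. The linear scan
 * below mirrors that binary search for clarity.
 */
inline int64_t ray_id_of_item_reference(
    const int64_t item_id, const int32_t n_chunks, const int64_t *chunk_starts)
{
    int32_t p = 0;
    while (p < n_chunks && !(chunk_starts[p] > item_id)) ++p;  // first start > item_id
    return static_cast<int64_t>(p) - 1;                        // minus one, as in the kernel
}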
scalar_t u = u_pad + sample_id * u_interval;
if (stratified)
__global__ void importance_sampling_kernel(
// cdfs
PackedRaySegmentsSpec ray_segments,
const float *cdfs,
// jittering
bool stratified,
at::PhiloxCudaState philox_args,
// outputs
PackedRaySegmentsSpec samples)
{
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < samples.n_edges; tid += blockDim.x * gridDim.x)
{
auto seeds = at::cuda::philox::unpack(philox_args);
curandStatePhilox4_32_10_t state;
int64_t rand_seq_id = single_jitter ? ray_id : tid;
curand_init(std::get<0>(seeds), rand_seq_id, std::get<1>(seeds), &state);
float rand = curand_uniform(&state);
u -= rand * u_interval;
u = max(u, static_cast<scalar_t>(0.0f));
int32_t ray_id;
int64_t n_samples, sid;
if (samples.is_batched) {
ray_id = tid / samples.n_edges_per_ray;
n_samples = samples.n_edges_per_ray;
sid = tid - ray_id * samples.n_edges_per_ray;
} else {
ray_id = binary_search_chunk_id(tid, samples.n_rays, samples.chunk_starts) - 1;
samples.ray_indices[tid] = ray_id;
n_samples = samples.chunk_cnts[ray_id];
sid = tid - samples.chunk_starts[ray_id];
}
int64_t base, last;
if (ray_segments.is_batched) {
base = ray_id * ray_segments.n_edges_per_ray;
last = base + ray_segments.n_edges_per_ray - 1;
} else {
base = ray_segments.chunk_starts[ray_id];
last = base + ray_segments.chunk_cnts[ray_id] - 1;
}
float u_floor = cdfs[base];
float u_ceil = cdfs[last];
float u_step = (u_ceil - u_floor) / n_samples;
float bias = 0.5f;
if (stratified) {
auto seeds = at::cuda::philox::unpack(philox_args);
curandStatePhilox4_32_10_t state;
curand_init(std::get<0>(seeds), ray_id, std::get<1>(seeds), &state);
bias = curand_uniform(&state);
}
float u = u_floor + (sid + bias) * u_step;
// searchsorted with "right" option:
// i.e. cdfs[p - 1] <= u < cdfs[p]
int64_t p = upper_bound<float>(cdfs, base, last, u, nullptr);
int64_t p0 = max(min(p - 1, last), base);
int64_t p1 = max(min(p, last), base);
float u_lower = cdfs[p0];
float u_upper = cdfs[p1];
float t_lower = ray_segments.vals[p0];
float t_upper = ray_segments.vals[p1];
float t;
if (u_upper - u_lower < 1e-10f) {
t = (t_lower + t_upper) * 0.5f;
} else {
float scaling = (t_upper - t_lower) / (u_upper - u_lower);
t = (u - u_lower) * scaling + t_lower;
}
samples.vals[tid] = t;
}
}
// searchsorted with "right" option:
// i.e. accum_weights[pos - 1] <= u < accum_weights[pos]
int64_t pos = upper_bound<scalar_t>(accum_weights, start_bd, end_bd, u, nullptr);
int64_t p0 = min(max(pos - 1, start_bd), end_bd - 1);
int64_t p1 = min(max(pos, start_bd), end_bd - 1);
scalar_t start_u = accum_weights[p0];
scalar_t end_u = accum_weights[p1];
scalar_t start_t = ts[p0];
scalar_t end_t = ts[p1];
if (p0 == p1) {
if (p0 == end_bd - 1)
ts_out[tid] = end_t;
else
ts_out[tid] = start_t;
} else if (end_u - start_u < 1e-20f) {
ts_out[tid] = (start_t + end_t) * 0.5f;
} else {
scalar_t scaling = (end_t - start_t) / (end_u - start_u);
scalar_t t = (u - start_u) * scaling + start_t;
ts_out[tid] = t;
__global__ void compute_intervels_kernel(
PackedRaySegmentsSpec ray_segments,
PackedRaySegmentsSpec samples,
// outputs
PackedRaySegmentsSpec intervals)
{
// parallelize over samples
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < samples.n_edges; tid += blockDim.x * gridDim.x)
{
int32_t ray_id;
int64_t n_samples, sid;
if (samples.is_batched) {
ray_id = tid / samples.n_edges_per_ray;
n_samples = samples.n_edges_per_ray;
sid = tid - ray_id * samples.n_edges_per_ray;
} else {
ray_id = samples.ray_indices[tid];
n_samples = samples.chunk_cnts[ray_id];
sid = tid - samples.chunk_starts[ray_id];
}
int64_t base, last;
if (ray_segments.is_batched) {
base = ray_id * ray_segments.n_edges_per_ray;
last = base + ray_segments.n_edges_per_ray - 1;
} else {
base = ray_segments.chunk_starts[ray_id];
last = base + ray_segments.chunk_cnts[ray_id] - 1;
}
int64_t base_out;
if (intervals.is_batched) {
base_out = ray_id * intervals.n_edges_per_ray;
} else {
base_out = intervals.chunk_starts[ray_id];
}
float t_min = ray_segments.vals[base];
float t_max = ray_segments.vals[last];
if (sid == 0) {
float t = samples.vals[tid];
float t_next = samples.vals[tid + 1]; // FIXME: out of bounds?
float half_width = (t_next - t) * 0.5f;
intervals.vals[base_out] = fmaxf(t - half_width, t_min);
if (!intervals.is_batched) {
intervals.ray_indices[base_out] = ray_id;
intervals.is_left[base_out] = true;
intervals.is_right[base_out] = false;
}
} else {
float t = samples.vals[tid];
float t_prev = samples.vals[tid - 1];
float t_edge = (t + t_prev) * 0.5f;
int64_t idx = base_out + sid;
intervals.vals[idx] = t_edge;
if (!intervals.is_batched) {
intervals.ray_indices[idx] = ray_id;
intervals.is_left[idx] = true;
intervals.is_right[idx] = true;
}
if (sid == n_samples - 1) {
float half_width = (t - t_prev) * 0.5f;
intervals.vals[idx + 1] = fminf(t + half_width, t_max);
if (!intervals.is_batched) {
intervals.ray_indices[idx + 1] = ray_id;
intervals.is_left[idx + 1] = false;
intervals.is_right[idx + 1] = true;
}
}
}
}
}
}
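/*
 * Summary of the interval construction above (descriptive note, not original
 * code): for a ray with samples t_0 < t_1 < ... < t_{n-1} produced by
 * importance_sampling_kernel (assuming n >= 2; see the FIXME above for the
 * single-sample case), the interval edges written here are
 *   edge_0 = max(t_0 - 0.5 * (t_1 - t_0), t_min)
 *   edge_k = 0.5 * (t_{k-1} + t_k)                          for 1 <= k <= n-1
 *   edge_n = min(t_{n-1} + 0.5 * (t_{n-1} - t_{n-2}), t_max)
 * so every sample lies inside its own interval and the outermost edges are
 * clamped to the source segment's [t_min, t_max] range.
 */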
torch::Tensor pdf_sampling(
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_samples_in - 1]
int64_t n_samples, // n_samples_out
float padding,
bool stratified,
bool single_jitter,
c10::optional<torch::Tensor> masks_opt) // [n_rays]
/* kernels for searchsorted */
__global__ void searchsorted_kernel(
PackedRaySegmentsSpec query,
PackedRaySegmentsSpec key,
// outputs
int64_t *ids_left,
int64_t *ids_right)
{
DEVICE_GUARD(ts);
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < query.n_edges; tid += blockDim.x * gridDim.x)
{
int32_t ray_id;
if (query.is_batched) {
ray_id = tid / query.n_edges_per_ray;
} else {
if (query.ray_indices == nullptr) {
ray_id = binary_search_chunk_id(tid, query.n_rays, query.chunk_starts) - 1;
} else {
ray_id = query.ray_indices[tid];
}
}
int64_t base, last;
if (key.is_batched) {
base = ray_id * key.n_edges_per_ray;
last = base + key.n_edges_per_ray - 1;
} else {
base = key.chunk_starts[ray_id];
last = base + key.chunk_cnts[ray_id] - 1;
}
// searchsorted with "right" option:
// i.e. key.vals[p - 1] <= query.vals[tid] < key.vals[p]
int64_t p = upper_bound<float>(key.vals, base, last, query.vals[tid], nullptr);
if (query.is_batched) {
ids_left[tid] = max(min(p - 1, last), base) - base;
ids_right[tid] = max(min(p, last), base) - base;
} else {
ids_left[tid] = max(min(p - 1, last), base);
ids_right[tid] = max(min(p, last), base);
}
}
}
CHECK_INPUT(ts);
CHECK_INPUT(weights);
TORCH_CHECK(ts.ndimension() == 2);
TORCH_CHECK(weights.ndimension() == 2);
TORCH_CHECK(ts.size(1) == weights.size(1) + 1);
} // namespace device
} // namespace
c10::MaybeOwned<torch::Tensor> masks_maybe_owned = at::borrow_from_optional_tensor(masks_opt);
const torch::Tensor& masks = *masks_maybe_owned;
if (padding > 0.f)
{
weights = weights + padding;
}
weights = F::normalize(weights, F::NormalizeFuncOptions().p(1).dim(-1));
torch::Tensor accum_weights = torch::cat({torch::zeros({weights.size(0), 1}, weights.options()),
weights.cumsum(1, weights.scalar_type())},
1);
torch::Tensor ts_out = torch::full({ts.size(0), n_samples}, -1.0f, ts.options());
int64_t numel = ts_out.numel();
int64_t maxThread = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
int64_t maxGrid = 1024;
dim3 block = dim3(min(maxThread, numel));
dim3 grid = dim3(min(maxGrid, ceil_div<int64_t>(numel, block.x)));
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
// For jittering
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_cuda_state(4);
}
// Return a flattened RaySegmentsSpec because n_intervels_per_ray is defined per ray.
std::vector<RaySegmentsSpec> importance_sampling(
RaySegmentsSpec ray_segments, // [..., n_edges_per_ray] or flattened
torch::Tensor cdfs, // [..., n_edges_per_ray] or flattened
torch::Tensor n_intervels_per_ray, // [...] or flattened
bool stratified)
{
DEVICE_GUARD(cdfs);
ray_segments.check();
CHECK_INPUT(cdfs);
CHECK_INPUT(n_intervels_per_ray);
TORCH_CHECK(cdfs.numel() == ray_segments.vals.numel());
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int64_t max_threads = 512; // at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
int64_t max_blocks = 65535;
dim3 threads, blocks;
// For jittering
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_cuda_state(4);
}
AT_DISPATCH_ALL_TYPES(
ts.scalar_type(),
"pdf_sampling",
([&]
{ pdf_sampling_kernel<scalar_t><<<grid, block, 0, stream>>>(
rng_engine_inputs,
ts.size(1), /* n_samples_in */
nullptr, /* info_ts */
ts.data_ptr<scalar_t>(), /* ts */
accum_weights.data_ptr<scalar_t>(), /* accum_weights */
masks.defined() ? masks.data_ptr<bool>() : nullptr, /* masks */
stratified,
single_jitter,
numel, /* numel */
ts_out.size(1), /* n_samples_out */
ts_out.data_ptr<scalar_t>() /* ts_out */
); }));
return ts_out; // [n_rays, n_samples_out]
// output samples
RaySegmentsSpec samples;
samples.chunk_cnts = n_intervels_per_ray.to(n_intervels_per_ray.options().dtype(torch::kLong));
samples.memalloc_data(false, false); // no boolean masks needed, no zero-initialization needed.
int64_t n_samples = samples.vals.numel();
// step 1. compute the ray_indices and samples
threads = dim3(min(max_threads, n_samples));
blocks = dim3(min(max_blocks, ceil_div<int64_t>(n_samples, threads.x)));
device::importance_sampling_kernel<<<blocks, threads, 0, stream>>>(
// cdfs
device::PackedRaySegmentsSpec(ray_segments),
cdfs.data_ptr<float>(),
// jittering
stratified,
rng_engine_inputs,
// output samples
device::PackedRaySegmentsSpec(samples));
// output ray segments
RaySegmentsSpec intervals;
intervals.chunk_cnts = (
(samples.chunk_cnts + 1) * (samples.chunk_cnts > 0)).to(samples.chunk_cnts.options());
intervals.memalloc_data(true, true); // boolean masks needed, zero-initialization needed.
// step 2. compute the intervals.
device::compute_intervels_kernel<<<blocks, threads, 0, stream>>>(
// samples
device::PackedRaySegmentsSpec(ray_segments),
device::PackedRaySegmentsSpec(samples),
// output intervals
device::PackedRaySegmentsSpec(intervals));
return {intervals, samples};
}
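/*
 * Illustrative CPU reference (a hedged sketch, not part of the original file):
 * the inverse-CDF sampling done by importance_sampling_kernel for a single ray,
 * with bias fixed at 0.5 (i.e. no stratified jittering). `t` holds the ray's
 * segment edges, `cdf` the accumulated weights at those edges (same length,
 * non-decreasing). Standard headers are assumed available via the includes at
 * the top of this file; all names are illustrative only.
 */
inline std::vector<float> importance_sampling_reference(
    const std::vector<float> &t, const std::vector<float> &cdf, int64_t n_samples)
{
    const int64_t last = static_cast<int64_t>(t.size()) - 1;
    const float u_floor = cdf.front(), u_ceil = cdf.back();
    const float u_step = (u_ceil - u_floor) / static_cast<float>(n_samples);
    std::vector<float> samples(n_samples);
    for (int64_t sid = 0; sid < n_samples; ++sid) {
        const float u = u_floor + (static_cast<float>(sid) + 0.5f) * u_step;
        // searchsorted with the "right" convention: cdf[p - 1] <= u < cdf[p]
        const int64_t p =
            std::upper_bound(cdf.begin(), cdf.begin() + last, u) - cdf.begin();
        const int64_t p0 = std::max(std::min(p - 1, last), int64_t(0));
        const int64_t p1 = std::max(std::min(p, last), int64_t(0));
        const float u_lo = cdf[p0], u_hi = cdf[p1];
        // linearly interpolate t inside the bracketing CDF segment
        samples[sid] = (u_hi - u_lo < 1e-10f)
                           ? 0.5f * (t[p0] + t[p1])
                           : t[p0] + (u - u_lo) / (u_hi - u_lo) * (t[p1] - t[p0]);
    }
    return samples;
}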
template <typename scalar_t>
__global__ void pdf_readout_kernel(
const int64_t n_rays,
// keys
const int64_t n_samples_in,
const scalar_t *ts, // [n_rays, n_samples_in]
const scalar_t *accum_weights, // [n_rays, n_samples_in]
const bool *masks, // [n_rays]
// query
const int64_t n_samples_out,
const scalar_t *ts_out, // [n_rays, n_samples_out]
const bool *masks_out, // [n_rays]
scalar_t *weights_out) // [n_rays, n_samples_out - 1]
// Return a batched RaySegmentsSpec because n_intervels_per_ray is the same across rays.
std::vector<RaySegmentsSpec> importance_sampling(
RaySegmentsSpec ray_segments, // [..., n_edges_per_ray] or flattened
torch::Tensor cdfs, // [..., n_edges_per_ray] or flattened
int64_t n_intervels_per_ray,
bool stratified)
{
int64_t n_bins_out = n_samples_out - 1;
int64_t numel = n_bins_out * n_rays;
// parallelize over outputs
for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < numel; tid += blockDim.x * gridDim.x)
{
int64_t ray_id = tid / n_bins_out;
if (masks_out != nullptr && !masks_out[ray_id]) {
// We don't care about this query ray.
weights_out[tid] = 0.0f;
continue;
}
if (masks != nullptr && !masks[ray_id]) {
// We don't have the values for the key ray. In this case we consider the key ray
// is all-zero.
weights_out[tid] = 0.0f;
continue;
DEVICE_GUARD(cdfs);
ray_segments.check();
CHECK_INPUT(cdfs);
TORCH_CHECK(cdfs.numel() == ray_segments.vals.numel());
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int64_t max_threads = 512; // at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
int64_t max_blocks = 65535;
dim3 threads, blocks;
// For jittering
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_cuda_state(4);
}
// search range in ts
int64_t start_bd = ray_id * n_samples_in;
int64_t end_bd = start_bd + n_samples_in;
// index in ts_out
int64_t id0 = tid + ray_id;
int64_t id1 = id0 + 1;
RaySegmentsSpec samples, intervals;
if (ray_segments.vals.ndimension() > 1){ // batched input
auto data_size = ray_segments.vals.sizes().vec();
data_size.back() = n_intervels_per_ray;
samples.vals = torch::empty(data_size, cdfs.options());
data_size.back() = n_intervels_per_ray + 1;
intervals.vals = torch::empty(data_size, cdfs.options());
} else { // flattened input
int64_t n_rays = ray_segments.chunk_cnts.numel();
samples.vals = torch::empty({n_rays, n_intervels_per_ray}, cdfs.options());
intervals.vals = torch::empty({n_rays, n_intervels_per_ray + 1}, cdfs.options());
}
int64_t n_samples = samples.vals.numel();
// searchsorted with "right" option:
// i.e. accum_weights[pos - 1] <= u < accum_weights[pos]
int64_t pos0 = upper_bound<scalar_t>(ts, start_bd, end_bd, ts_out[id0], nullptr);
pos0 = max(min(pos0 - 1, end_bd-1), start_bd);
// searchsorted with "left" option:
// i.e. accum_weights[pos - 1] < u <= accum_weights[pos]
int64_t pos1 = lower_bound<scalar_t>(ts, start_bd, end_bd, ts_out[id1], nullptr);
pos1 = max(min(pos1, end_bd-1), start_bd);
// outer
scalar_t outer = accum_weights[pos1] - accum_weights[pos0];
weights_out[tid] = outer;
}
// step 1. compute the ray_indices and samples
threads = dim3(min(max_threads, n_samples));
blocks = dim3(min(max_blocks, ceil_div<int64_t>(n_samples, threads.x)));
device::importance_sampling_kernel<<<blocks, threads, 0, stream>>>(
// cdfs
device::PackedRaySegmentsSpec(ray_segments),
cdfs.data_ptr<float>(),
// jittering
stratified,
rng_engine_inputs,
// output samples
device::PackedRaySegmentsSpec(samples));
// step 2. compute the intervals.
device::compute_intervels_kernel<<<blocks, threads, 0, stream>>>(
// samples
device::PackedRaySegmentsSpec(ray_segments),
device::PackedRaySegmentsSpec(samples),
// output intervals
device::PackedRaySegmentsSpec(intervals));
return {intervals, samples};
}
torch::Tensor pdf_readout(
// keys
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_bins_in]
c10::optional<torch::Tensor> masks_opt, // [n_rays]
// query
torch::Tensor ts_out,
c10::optional<torch::Tensor> masks_out_opt) // [n_rays, n_samples_out]
// Find two indices {left, right} for each item in query,
// such that: key.vals[left] <= query.vals < key.vals[right]
std::vector<torch::Tensor> searchsorted(
RaySegmentsSpec query,
RaySegmentsSpec key)
{
DEVICE_GUARD(ts);
CHECK_INPUT(ts);
CHECK_INPUT(weights);
TORCH_CHECK(ts.ndimension() == 2);
TORCH_CHECK(weights.ndimension() == 2);
int64_t n_rays = ts.size(0);
int64_t n_samples_in = ts.size(1);
int64_t n_samples_out = ts_out.size(1);
int64_t n_bins_out = n_samples_out - 1;
c10::MaybeOwned<torch::Tensor> masks_maybe_owned = at::borrow_from_optional_tensor(masks_opt);
const torch::Tensor& masks = *masks_maybe_owned;
c10::MaybeOwned<torch::Tensor> masks_out_maybe_owned = at::borrow_from_optional_tensor(masks_out_opt);
const torch::Tensor& masks_out = *masks_out_maybe_owned;
// weights = F::normalize(weights, F::NormalizeFuncOptions().p(1).dim(-1));
torch::Tensor accum_weights = torch::cat({torch::zeros({weights.size(0), 1}, weights.options()),
weights.cumsum(1, weights.scalar_type())},
1);
torch::Tensor weights_out = torch::empty({n_rays, n_bins_out}, weights.options());
int64_t maxThread = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
int64_t maxGrid = 1024;
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int64_t numel = weights_out.numel();
dim3 block = dim3(min(maxThread, numel));
dim3 grid = dim3(min(maxGrid, ceil_div<int64_t>(numel, block.x)));
AT_DISPATCH_ALL_TYPES(
weights.scalar_type(),
"pdf_readout",
([&]
{ pdf_readout_kernel<scalar_t><<<grid, block, 0, stream>>>(
n_rays,
n_samples_in,
ts.data_ptr<scalar_t>(), /* ts */
accum_weights.data_ptr<scalar_t>(), /* accum_weights */
masks.defined() ? masks.data_ptr<bool>() : nullptr,
n_samples_out,
ts_out.data_ptr<scalar_t>(), /* ts_out */
masks_out.defined() ? masks_out.data_ptr<bool>() : nullptr,
weights_out.data_ptr<scalar_t>()
); }));
return weights_out; // [n_rays, n_bins_out]
DEVICE_GUARD(query.vals);
query.check();
key.check();
// outputs
int64_t n_edges = query.vals.numel();
torch::Tensor ids_left = torch::empty(
query.vals.sizes(), query.vals.options().dtype(torch::kLong));
torch::Tensor ids_right = torch::empty(
query.vals.sizes(), query.vals.options().dtype(torch::kLong));
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int64_t max_threads = 512; // at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
int64_t max_blocks = 65535;
dim3 threads = dim3(min(max_threads, n_edges));
dim3 blocks = dim3(min(max_blocks, ceil_div<int64_t>(n_edges, threads.x)));
device::searchsorted_kernel<<<blocks, threads, 0, stream>>>(
device::PackedRaySegmentsSpec(query),
device::PackedRaySegmentsSpec(key),
// outputs
ids_left.data_ptr<int64_t>(),
ids_right.data_ptr<int64_t>());
return {ids_left, ids_right};
}
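/*
 * Illustrative CPU reference (a hedged sketch, not part of the original file):
 * the per-value behaviour of searchsorted for one ray. Given that ray's sorted
 * key values, it returns the pair {left, right} with key[left] <= v <
 * key[right], clamped at both ends so out-of-range queries map onto the first
 * or last key interval. Standard headers are assumed available via the includes
 * at the top of this file; names are illustrative only.
 */
inline std::pair<int64_t, int64_t> searchsorted_reference(
    const std::vector<float> &key, const float v)
{
    const int64_t last = static_cast<int64_t>(key.size()) - 1;
    // "right" convention: first index whose key is strictly greater than v,
    // searched over key[0 .. last-1] exactly as in searchsorted_kernel.
    const int64_t p =
        std::upper_bound(key.begin(), key.begin() + last, v) - key.begin();
    const int64_t left = std::max(std::min(p - 1, last), int64_t(0));
    const int64_t right = std::max(std::min(p, last), int64_t(0));
    return {left, right};
}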
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
#include "include/helpers_math.h"
#include "include/helpers_contraction.h"
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor aabb);
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_binary,
const ContractionType type,
// sampling
const float step_size,
const float cone_angle);
torch::Tensor unpack_info(
const torch::Tensor packed_info, const int n_samples);
torch::Tensor unpack_info_to_mask(
const torch::Tensor packed_info, const int n_samples);
torch::Tensor grid_query(
const torch::Tensor samples,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_value,
const ContractionType type);
torch::Tensor contract(
const torch::Tensor samples,
// contraction
const torch::Tensor roi,
const ContractionType type);
torch::Tensor contract_inv(
const torch::Tensor samples,
// contraction
const torch::Tensor roi,
const ContractionType type);
torch::Tensor unpack_data(
torch::Tensor packed_info,
torch::Tensor data,
int n_samples_per_ray,
float pad_value);
// cub implementations: parallel across samples
bool is_cub_available() {
return (bool) CUB_SUPPORTS_SCAN_BY_KEY();
}
torch::Tensor transmittance_from_sigma_forward_cub(
torch::Tensor ray_indices,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor transmittance_from_sigma_backward_cub(
torch::Tensor ray_indices,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
torch::Tensor transmittance_from_alpha_forward_cub(
torch::Tensor ray_indices, torch::Tensor alphas);
torch::Tensor transmittance_from_alpha_backward_cub(
torch::Tensor ray_indices,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
// naive implementations: parallel across rays
torch::Tensor transmittance_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor transmittance_from_sigma_backward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
torch::Tensor transmittance_from_alpha_forward_naive(
torch::Tensor packed_info,
torch::Tensor alphas);
torch::Tensor transmittance_from_alpha_backward_naive(
torch::Tensor packed_info,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
torch::Tensor weight_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor weight_from_sigma_backward_naive(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor weight_from_alpha_forward_naive(
torch::Tensor packed_info,
torch::Tensor alphas);
torch::Tensor weight_from_alpha_backward_naive(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas);
// pdf
torch::Tensor pdf_sampling(
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_samples_in - 1]
int64_t n_samples, // n_samples_out
float padding,
bool stratified,
bool single_jitter,
c10::optional<torch::Tensor> masks_opt); // [n_rays]
torch::Tensor pdf_readout(
// keys
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_bins_in]
c10::optional<torch::Tensor> masks_opt, // [n_rays]
// query
torch::Tensor ts_out,
c10::optional<torch::Tensor> masks_out_opt); // [n_rays, n_samples_out]
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// contraction
py::enum_<ContractionType>(m, "ContractionType")
.value("AABB", ContractionType::AABB)
.value("UN_BOUNDED_TANH", ContractionType::UN_BOUNDED_TANH)
.value("UN_BOUNDED_SPHERE", ContractionType::UN_BOUNDED_SPHERE);
m.def("contract", &contract);
m.def("contract_inv", &contract_inv);
// grid
m.def("grid_query", &grid_query);
// marching
m.def("ray_aabb_intersect", &ray_aabb_intersect);
m.def("ray_marching", &ray_marching);
// rendering
m.def("is_cub_available", is_cub_available);
m.def("transmittance_from_sigma_forward_cub", transmittance_from_sigma_forward_cub);
m.def("transmittance_from_sigma_backward_cub", transmittance_from_sigma_backward_cub);
m.def("transmittance_from_alpha_forward_cub", transmittance_from_alpha_forward_cub);
m.def("transmittance_from_alpha_backward_cub", transmittance_from_alpha_backward_cub);
m.def("transmittance_from_sigma_forward_naive", transmittance_from_sigma_forward_naive);
m.def("transmittance_from_sigma_backward_naive", transmittance_from_sigma_backward_naive);
m.def("transmittance_from_alpha_forward_naive", transmittance_from_alpha_forward_naive);
m.def("transmittance_from_alpha_backward_naive", transmittance_from_alpha_backward_naive);
m.def("weight_from_sigma_forward_naive", weight_from_sigma_forward_naive);
m.def("weight_from_sigma_backward_naive", weight_from_sigma_backward_naive);
m.def("weight_from_alpha_forward_naive", weight_from_alpha_forward_naive);
m.def("weight_from_alpha_backward_naive", weight_from_alpha_backward_naive);
// pack & unpack
m.def("unpack_data", &unpack_data);
m.def("unpack_info", &unpack_info);
m.def("unpack_info_to_mask", &unpack_info_to_mask);
// pdf
m.def("pdf_sampling", &pdf_sampling);
m.def("pdf_readout", &pdf_readout);
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
#include "include/helpers_math.h"
#include "include/helpers_contraction.h"
inline __device__ __host__ float calc_dt(
const float t, const float cone_angle,
const float dt_min, const float dt_max)
{
return clamp(t * cone_angle, dt_min, dt_max);
}
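// Example values (illustrative only): with cone_angle = 0.004, dt_min = 0.01,
// and dt_max = 1.0, calc_dt(1.0, ...) = 0.01 (clamped to dt_min near the
// camera), calc_dt(10.0, ...) = 0.04 (the step grows linearly with distance),
// and calc_dt(500.0, ...) = 1.0 (clamped to dt_max far away).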
inline __device__ __host__ int mip_level(
const float3 xyz,
const float3 roi_min, const float3 roi_max,
const ContractionType type)
{
if (type != ContractionType::AABB)
{
// mip level should always be zero if not using AABB
return 0;
}
float3 xyz_unit = apply_contraction(
xyz, roi_min, roi_max, ContractionType::AABB);
float3 scale = fabs(xyz_unit - 0.5);
float maxval = fmaxf(fmaxf(scale.x, scale.y), scale.z);
// if maxval is almost zero, it will trigger frexpf to output 0
// for exponent, which is not what we want.
maxval = fmaxf(maxval, 0.1);
int exponent;
frexpf(maxval, &exponent);
int mip = max(0, exponent + 1);
return mip;
}
inline __device__ __host__ int grid_idx_at(
const float3 xyz_unit, const int3 grid_res)
{
// xyz should always be in [0, 1]^3.
int3 ixyz = make_int3(xyz_unit * make_float3(grid_res));
ixyz = clamp(ixyz, make_int3(0, 0, 0), grid_res - 1);
int3 grid_offset = make_int3(grid_res.y * grid_res.z, grid_res.z, 1);
int idx = dot(ixyz, grid_offset);
return idx;
}
template <typename scalar_t>
inline __device__ __host__ scalar_t grid_occupied_at(
const float3 xyz,
const float3 roi_min, const float3 roi_max,
ContractionType type, int mip,
const int grid_nlvl, const int3 grid_res, const scalar_t *grid_value)
{
if (type == ContractionType::AABB && mip >= grid_nlvl)
{
return false;
}
float3 xyz_unit = apply_contraction(
xyz, roi_min, roi_max, type);
xyz_unit = (xyz_unit - 0.5) * scalbnf(1.0f, -mip) + 0.5;
int idx = grid_idx_at(xyz_unit, grid_res) + mip * grid_res.x * grid_res.y * grid_res.z;
return grid_value[idx];
}
// dda like step
inline __device__ __host__ float distance_to_next_voxel(
const float3 xyz, const float3 dir, const float3 inv_dir, int mip,
const float3 roi_min, const float3 roi_max, const int3 grid_res)
{
float scaling = scalbnf(1.0f, mip);
float3 _roi_mid = (roi_min + roi_max) * 0.5;
float3 _roi_rad = (roi_max - roi_min) * 0.5;
float3 _roi_min = _roi_mid - _roi_rad * scaling;
float3 _roi_max = _roi_mid + _roi_rad * scaling;
float3 _occ_res = make_float3(grid_res);
float3 _xyz = roi_to_unit(xyz, _roi_min, _roi_max) * _occ_res;
float3 txyz = ((floorf(_xyz + 0.5f + 0.5f * sign(dir)) - _xyz) * inv_dir) / _occ_res * (_roi_max - _roi_min);
float t = min(min(txyz.x, txyz.y), txyz.z);
return fmaxf(t, 0.0f);
}
inline __device__ __host__ float advance_to_next_voxel(
const float t, const float dt_min,
const float3 xyz, const float3 dir, const float3 inv_dir, int mip,
const float3 roi_min, const float3 roi_max, const int3 grid_res, const float far)
{
// Regular stepping (may be slower but matches non-empty space)
float t_target = t + distance_to_next_voxel(
xyz, dir, inv_dir, mip, roi_min, roi_max, grid_res);
t_target = min(t_target, far);
float _t = t;
do
{
_t += dt_min;
} while (_t < t_target);
return _t;
}
// -------------------------------------------------------------------------------
// Raymarching
// -------------------------------------------------------------------------------
__global__ void ray_marching_kernel(
// rays info
const uint32_t n_rays,
const float *rays_o, // shape (n_rays, 3)
const float *rays_d, // shape (n_rays, 3)
const float *t_min, // shape (n_rays,)
const float *t_max, // shape (n_rays,)
// occupancy grid & contraction
const float *roi,
const int grid_nlvl,
const int3 grid_res,
const bool *grid_binary, // shape (reso_x, reso_y, reso_z)
const ContractionType type,
// sampling
const float step_size,
const float cone_angle,
const int *packed_info,
// first round outputs
int *num_steps,
// second round outputs
int64_t *ray_indices,
float *t_starts,
float *t_ends)
{
CUDA_GET_THREAD_ID(i, n_rays);
bool is_first_round = (packed_info == nullptr);
// locate
rays_o += i * 3;
rays_d += i * 3;
t_min += i;
t_max += i;
if (is_first_round)
{
num_steps += i;
}
else
{
int base = packed_info[i * 2 + 0];
int steps = packed_info[i * 2 + 1];
t_starts += base;
t_ends += base;
ray_indices += base;
}
const float3 origin = make_float3(rays_o[0], rays_o[1], rays_o[2]);
const float3 dir = make_float3(rays_d[0], rays_d[1], rays_d[2]);
const float3 inv_dir = 1.0f / dir;
const float near = t_min[0], far = t_max[0];
const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
const float grid_cell_scale = fmaxf(fmaxf(
(roi_max.x - roi_min.x) / grid_res.x * scalbnf(1.732f, grid_nlvl - 1),
(roi_max.y - roi_min.y) / grid_res.y * scalbnf(1.732f, grid_nlvl - 1)),
(roi_max.z - roi_min.z) / grid_res.z * scalbnf(1.732f, grid_nlvl - 1));
// TODO: compute dt_max from occ resolution.
float dt_min = step_size;
float dt_max;
if (type == ContractionType::AABB) {
// compute dt_max from occ grid resolution.
dt_max = grid_cell_scale;
} else {
dt_max = 1e10f;
}
int j = 0;
float t0 = near;
float dt = calc_dt(t0, cone_angle, dt_min, dt_max);
float t1 = t0 + dt;
float t_mid = (t0 + t1) * 0.5f;
while (t_mid < far)
{
// current center
const float3 xyz = origin + t_mid * dir;
// current mip level
const int mip = mip_level(xyz, roi_min, roi_max, type);
if (mip >= grid_nlvl) {
// out of grid
break;
}
if (grid_occupied_at(xyz, roi_min, roi_max, type, mip, grid_nlvl, grid_res, grid_binary))
{
if (!is_first_round)
{
t_starts[j] = t0;
t_ends[j] = t1;
ray_indices[j] = i;
}
++j;
// march to next sample
t0 = t1;
t1 = t0 + calc_dt(t0, cone_angle, dt_min, dt_max);
t_mid = (t0 + t1) * 0.5f;
}
else
{
// march to next sample
switch (type)
{
case ContractionType::AABB:
// no contraction
t_mid = advance_to_next_voxel(
t_mid, dt_min, xyz, dir, inv_dir, mip, roi_min, roi_max, grid_res, far);
dt = calc_dt(t_mid, cone_angle, dt_min, dt_max);
t0 = t_mid - dt * 0.5f;
t1 = t_mid + dt * 0.5f;
break;
default:
// any type of scene contraction does not work with DDA.
t0 = t1;
t1 = t0 + calc_dt(t0, cone_angle, dt_min, dt_max);
t_mid = (t0 + t1) * 0.5f;
break;
}
}
}
if (is_first_round)
{
*num_steps = j;
}
return;
}
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_binary,
const ContractionType type,
// sampling
const float step_size,
const float cone_angle)
{
DEVICE_GUARD(rays_o);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(t_min);
CHECK_INPUT(t_max);
CHECK_INPUT(roi);
CHECK_INPUT(grid_binary);
TORCH_CHECK(rays_o.ndimension() == 2 && rays_o.size(1) == 3);
TORCH_CHECK(rays_d.ndimension() == 2 && rays_d.size(1) == 3);
TORCH_CHECK(t_min.ndimension() == 1);
TORCH_CHECK(t_max.ndimension() == 1);
TORCH_CHECK(roi.ndimension() == 1 && roi.size(0) == 6);
TORCH_CHECK(grid_binary.ndimension() == 4);
const int n_rays = rays_o.size(0);
const int grid_nlvl = grid_binary.size(0);
const int3 grid_res = make_int3(
grid_binary.size(1), grid_binary.size(2), grid_binary.size(3));
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// helper counter
torch::Tensor num_steps = torch::empty(
{n_rays}, rays_o.options().dtype(torch::kInt32));
// count number of samples per ray
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// occupancy grid & contraction
roi.data_ptr<float>(),
grid_nlvl,
grid_res,
grid_binary.data_ptr<bool>(),
type,
// sampling
step_size,
cone_angle,
nullptr, /* packed_info */
// outputs
num_steps.data_ptr<int>(),
nullptr, /* ray_indices */
nullptr, /* t_starts */
nullptr /* t_ends */);
torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
torch::Tensor packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
// output samples starts and ends
int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
torch::Tensor t_starts = torch::empty({total_steps, 1}, rays_o.options());
torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options());
torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options().dtype(torch::kLong));
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// occupancy grid & contraction
roi.data_ptr<float>(),
grid_nlvl,
grid_res,
grid_binary.data_ptr<bool>(),
type,
// sampling
step_size,
cone_angle,
packed_info.data_ptr<int>(),
// outputs
nullptr, /* num_steps */
ray_indices.data_ptr<int64_t>(),
t_starts.data_ptr<float>(),
t_ends.data_ptr<float>());
return {packed_info, ray_indices, t_starts, t_ends};
}
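/*
 * Descriptive note with a small CPU sketch (hedged, not part of the original
 * file): ray_marching above uses the two-pass "count, allocate, fill" pattern.
 * Pass 1 launches the kernel with packed_info == nullptr so it only writes
 * num_steps[i]; the per-ray offsets are then built on the host and pass 2
 * fills the sample buffers. The hypothetical helper below shows only the
 * offset construction, i.e. the equivalent of cum_steps - num_steps above.
 */
#include <utility>
#include <vector>

inline std::vector<std::pair<int, int>> build_packed_info_reference(
    const std::vector<int> &num_steps)
{
    std::vector<std::pair<int, int>> packed_info(num_steps.size());
    int base = 0;
    for (size_t i = 0; i < num_steps.size(); ++i) {
        packed_info[i] = {base, num_steps[i]};  // {first sample index, #samples}
        base += num_steps[i];                   // exclusive prefix sum of num_steps
    }
    return packed_info;
}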
// ----------------------------------------------------------------------------
// Query the occupancy grid
// ----------------------------------------------------------------------------
template <typename scalar_t>
__global__ void query_occ_kernel(
// rays info
const uint32_t n_samples,
const float *samples, // shape (n_samples, 3)
// occupancy grid & contraction
const float *roi,
const int grid_nlvl,
const int3 grid_res,
const scalar_t *grid_value, // shape (reso_x, reso_y, reso_z)
const ContractionType type,
// outputs
scalar_t *occs)
{
CUDA_GET_THREAD_ID(i, n_samples);
// locate
samples += i * 3;
occs += i;
const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
const float3 xyz = make_float3(samples[0], samples[1], samples[2]);
const int mip = mip_level(xyz, roi_min, roi_max, type);
*occs = grid_occupied_at(xyz, roi_min, roi_max, type, mip, grid_nlvl, grid_res, grid_value);
return;
}
torch::Tensor grid_query(
const torch::Tensor samples,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_value,
const ContractionType type)
{
DEVICE_GUARD(samples);
CHECK_INPUT(samples);
const int n_samples = samples.size(0);
const int grid_nlvl = grid_value.size(0);
const int3 grid_res = make_int3(
grid_value.size(1), grid_value.size(2), grid_value.size(3));
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads);
torch::Tensor occs = torch::empty({n_samples}, grid_value.options());
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Bool,
occs.scalar_type(),
"grid_query",
([&]
{ query_occ_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_samples,
samples.data_ptr<float>(),
// grid
roi.data_ptr<float>(),
grid_nlvl,
grid_res,
grid_value.data_ptr<scalar_t>(),
type,
// outputs
occs.data_ptr<scalar_t>()); }));
return occs;
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
__global__ void transmittance_from_sigma_forward_kernel(
const uint32_t n_rays,
// inputs
const int *packed_info,
const float *starts,
const float *ends,
const float *sigmas,
// outputs
float *transmittance)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
starts += base;
ends += base;
sigmas += base;
transmittance += base;
// accumulation
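// transmittance[j] = exp(-sum_{k < j} sigmas[k] * (ends[k] - starts[k])):
// an exclusive running sum of sigma * dt along the ray, passed through exp(-x).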
float cumsum = 0.0f;
for (int j = 0; j < steps; ++j)
{
transmittance[j] = __expf(-cumsum);
cumsum += sigmas[j] * (ends[j] - starts[j]);
}
// // another way to impl:
// float T = 1.f;
// for (int j = 0; j < steps; ++j)
// {
// const float delta = ends[j] - starts[j];
// const float alpha = 1.f - __expf(-sigmas[j] * delta);
// transmittance[j] = T;
// T *= (1.f - alpha);
// }
return;
}
__global__ void transmittance_from_sigma_backward_kernel(
const uint32_t n_rays,
// inputs
const int *packed_info,
const float *starts,
const float *ends,
const float *transmittance,
const float *transmittance_grad,
// outputs
float *sigmas_grad)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
transmittance += base;
transmittance_grad += base;
starts += base;
ends += base;
sigmas_grad += base;
// accumulation
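// dL/dsigmas[j] = (ends[j] - starts[j]) * sum_{k > j} (-dL/dT[k] * T[k]),
// accumulated as a suffix sum by walking the ray back to front.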
float cumsum = 0.0f;
for (int j = steps - 1; j >= 0; --j)
{
sigmas_grad[j] = cumsum * (ends[j] - starts[j]);
cumsum += -transmittance_grad[j] * transmittance[j];
}
return;
}
__global__ void transmittance_from_alpha_forward_kernel(
const uint32_t n_rays,
// inputs
const int *packed_info,
const float *alphas,
// outputs
float *transmittance)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
alphas += base;
transmittance += base;
// accumulation
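// transmittance[j] = prod_{k < j} (1 - alphas[k]); T carries the running product.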
float T = 1.0f;
for (int j = 0; j < steps; ++j)
{
transmittance[j] = T;
T *= (1.0f - alphas[j]);
}
return;
}
__global__ void transmittance_from_alpha_backward_kernel(
const uint32_t n_rays,
// inputs
const int *packed_info,
const float *alphas,
const float *transmittance,
const float *transmittance_grad,
// outputs
float *alphas_grad)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
alphas += base;
transmittance += base;
transmittance_grad += base;
alphas_grad += base;
// accumulation
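// dL/dalphas[j] = sum_{k > j} (-dL/dT[k] * T[k]) / (1 - alphas[j]),
// since T[k] = prod_{m < k} (1 - alphas[m]) for every k after j.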
float cumsum = 0.0f;
for (int j = steps - 1; j >= 0; --j)
{
alphas_grad[j] = cumsum / fmax(1.0f - alphas[j], 1e-10f);
cumsum += -transmittance_grad[j] * transmittance[j];
}
return;
}
torch::Tensor transmittance_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1);
const uint32_t n_samples = sigmas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor transmittance = torch::empty_like(sigmas);
// parallel across rays
transmittance_from_sigma_forward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<float>(),
ends.data_ptr<float>(),
sigmas.data_ptr<float>(),
// outputs
transmittance.data_ptr<float>());
return transmittance;
}
torch::Tensor transmittance_from_sigma_backward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(transmittance);
CHECK_INPUT(transmittance_grad);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(transmittance.ndimension() == 2 & transmittance.size(1) == 1);
TORCH_CHECK(transmittance_grad.ndimension() == 2 & transmittance_grad.size(1) == 1);
const uint32_t n_samples = transmittance.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor sigmas_grad = torch::empty_like(transmittance);
// parallel across rays
transmittance_from_sigma_backward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<float>(),
ends.data_ptr<float>(),
transmittance.data_ptr<float>(),
transmittance_grad.data_ptr<float>(),
// outputs
sigmas_grad.data_ptr<float>());
return sigmas_grad;
}
torch::Tensor transmittance_from_alpha_forward_naive(
torch::Tensor packed_info, torch::Tensor alphas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(alphas);
TORCH_CHECK(alphas.ndimension() == 2 & alphas.size(1) == 1);
TORCH_CHECK(packed_info.ndimension() == 2);
const uint32_t n_samples = alphas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor transmittance = torch::empty_like(alphas);
// parallel across rays
transmittance_from_alpha_forward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
alphas.data_ptr<float>(),
// outputs
transmittance.data_ptr<float>());
return transmittance;
}
torch::Tensor transmittance_from_alpha_backward_naive(
torch::Tensor packed_info,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(transmittance);
CHECK_INPUT(transmittance_grad);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(transmittance.ndimension() == 2 & transmittance.size(1) == 1);
TORCH_CHECK(transmittance_grad.ndimension() == 2 & transmittance_grad.size(1) == 1);
const uint32_t n_samples = transmittance.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor alphas_grad = torch::empty_like(alphas);
// parallel across rays
transmittance_from_alpha_backward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
alphas.data_ptr<float>(),
transmittance.data_ptr<float>(),
transmittance_grad.data_ptr<float>(),
// outputs
alphas_grad.data_ptr<float>());
return alphas_grad;
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
// CUB is supported in CUDA >= 11.0
// ExclusiveScanByKey is supported in CUB >= 1.15.0 (CUDA >= 11.6)
// See: https://github.com/NVIDIA/cub/tree/main#releases
#include "include/helpers_cuda.h"
#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <cub/cub.cuh>
#endif
struct Product
{
template <typename T>
__host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const { return a * b; }
};
#if CUB_SUPPORTS_SCAN_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_sum_by_key(
KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
TORCH_CHECK(num_items <= std::numeric_limits<int64_t>::max(),
"cub ExclusiveSumByKey does not support more than LONG_MAX elements");
CUB_WRAPPER(cub::DeviceScan::ExclusiveSumByKey, keys, input, output,
num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_prod_by_key(
KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
TORCH_CHECK(num_items <= std::numeric_limits<int64_t>::max(),
"cub ExclusiveScanByKey does not support more than LONG_MAX elements");
CUB_WRAPPER(cub::DeviceScan::ExclusiveScanByKey, keys, input, output, Product(), 1.0f,
num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
#endif
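// The *_cub variants below parallelize over samples instead of rays: with `ray_indices`
// as the segment key, cub::DeviceScan::Exclusive{Sum,Scan}ByKey restarts the running
// sum/product at every ray boundary, so no per-ray loop is needed.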
torch::Tensor transmittance_from_sigma_forward_cub(
torch::Tensor ray_indices,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas)
{
DEVICE_GUARD(ray_indices);
CHECK_INPUT(ray_indices);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(ray_indices.ndimension() == 1);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1);
const uint32_t n_samples = sigmas.size(0);
// parallel across samples
torch::Tensor sigmas_dt = sigmas * (ends - starts);
torch::Tensor sigmas_dt_cumsum = torch::empty_like(sigmas);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
ray_indices.data_ptr<int64_t>(),
sigmas_dt.data_ptr<float>(),
sigmas_dt_cumsum.data_ptr<float>(),
n_samples);
#else
std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
torch::Tensor transmittance = (-sigmas_dt_cumsum).exp();
return transmittance;
}
torch::Tensor transmittance_from_sigma_backward_cub(
torch::Tensor ray_indices,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad)
{
DEVICE_GUARD(ray_indices);
CHECK_INPUT(ray_indices);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(transmittance);
CHECK_INPUT(transmittance_grad);
TORCH_CHECK(ray_indices.ndimension() == 1);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(transmittance.ndimension() == 2 & transmittance.size(1) == 1);
TORCH_CHECK(transmittance_grad.ndimension() == 2 & transmittance_grad.size(1) == 1);
const uint32_t n_samples = transmittance.size(0);
// parallel across samples
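// The gradient w.r.t. each sigma*dt term is a suffix sum of -dL/dT * T within the ray;
// it is computed by running the exclusive sum-by-key over the reversed arrays.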
torch::Tensor sigmas_dt_cumsum_grad = -transmittance_grad * transmittance;
torch::Tensor sigmas_dt_grad = torch::empty_like(transmittance_grad);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
thrust::make_reverse_iterator(ray_indices.data_ptr<int64_t>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_cumsum_grad.data_ptr<float>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_grad.data_ptr<float>() + n_samples),
n_samples);
#else
std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
torch::Tensor sigmas_grad = sigmas_dt_grad * (ends - starts);
return sigmas_grad;
}
torch::Tensor transmittance_from_alpha_forward_cub(
torch::Tensor ray_indices, torch::Tensor alphas)
{
DEVICE_GUARD(ray_indices);
CHECK_INPUT(ray_indices);
CHECK_INPUT(alphas);
TORCH_CHECK(alphas.ndimension() == 2 & alphas.size(1) == 1);
TORCH_CHECK(ray_indices.ndimension() == 1);
const uint32_t n_samples = alphas.size(0);
// parallel across samples
torch::Tensor transmittance = torch::empty_like(alphas);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_prod_by_key(
ray_indices.data_ptr<int64_t>(),
(1.0f - alphas).data_ptr<float>(),
transmittance.data_ptr<float>(),
n_samples);
#else
std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
return transmittance;
}
torch::Tensor transmittance_from_alpha_backward_cub(
torch::Tensor ray_indices,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad)
{
DEVICE_GUARD(ray_indices);
CHECK_INPUT(ray_indices);
CHECK_INPUT(transmittance);
CHECK_INPUT(transmittance_grad);
TORCH_CHECK(ray_indices.ndimension() == 1);
TORCH_CHECK(transmittance.ndimension() == 2 & transmittance.size(1) == 1);
TORCH_CHECK(transmittance_grad.ndimension() == 2 & transmittance_grad.size(1) == 1);
const uint32_t n_samples = transmittance.size(0);
// parallel across samples
torch::Tensor sigmas_dt_cumsum_grad = -transmittance_grad * transmittance;
torch::Tensor sigmas_dt_grad = torch::empty_like(transmittance_grad);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
thrust::make_reverse_iterator(ray_indices.data_ptr<int64_t>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_cumsum_grad.data_ptr<float>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_grad.data_ptr<float>() + n_samples),
n_samples);
#else
std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
torch::Tensor alphas_grad = sigmas_dt_grad / (1.0f - alphas).clamp_min(1e-10f);
return alphas_grad;
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
__global__ void weight_from_sigma_forward_kernel(
const uint32_t n_rays,
const int *packed_info,
const float *starts,
const float *ends,
const float *sigmas,
// outputs
float *weights)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
starts += base;
ends += base;
sigmas += base;
weights += base;
// accumulation
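// weights[j] = T_j * alpha_j with alpha_j = 1 - exp(-sigmas[j] * dt_j) and
// T_j = prod_{k < j} (1 - alpha_k): standard volume rendering quadrature.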
float T = 1.f;
for (int j = 0; j < steps; ++j)
{
const float delta = ends[j] - starts[j];
const float alpha = 1.f - __expf(-sigmas[j] * delta);
weights[j] = alpha * T;
T *= (1.f - alpha);
}
return;
}
__global__ void weight_from_sigma_backward_kernel(
const uint32_t n_rays,
const int *packed_info,
const float *starts,
const float *ends,
const float *sigmas,
const float *weights,
const float *grad_weights,
// outputs
float *grad_sigmas)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
starts += base;
ends += base;
sigmas += base;
weights += base;
grad_weights += base;
grad_sigmas += base;
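// dL/dsigmas[j] = dt_j * (grad_weights[j] * T_j - sum_{k >= j} grad_weights[k] * weights[k]).
// `accum` starts as the full sum of grad_weights * weights and is reduced to the
// suffix sum over k >= j as the loop walks forward along the ray.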
float accum = 0;
for (int j = 0; j < steps; ++j)
{
accum += grad_weights[j] * weights[j];
}
// accumulation
float T = 1.f;
for (int j = 0; j < steps; ++j)
{
const float delta = ends[j] - starts[j];
const float alpha = 1.f - __expf(-sigmas[j] * delta);
grad_sigmas[j] = (grad_weights[j] * T - accum) * delta;
accum -= grad_weights[j] * weights[j];
T *= (1.f - alpha);
}
return;
}
__global__ void weight_from_alpha_forward_kernel(
const uint32_t n_rays,
const int *packed_info,
const float *alphas,
// outputs
float *weights)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
alphas += base;
weights += base;
// accumulation
float T = 1.f;
for (int j = 0; j < steps; ++j)
{
const float alpha = alphas[j];
weights[j] = alpha * T;
T *= (1.f - alpha);
}
return;
}
__global__ void weight_from_alpha_backward_kernel(
const uint32_t n_rays,
const int *packed_info,
const float *alphas,
const float *weights,
const float *grad_weights,
// outputs
float *grad_alphas)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
alphas += base;
weights += base;
grad_weights += base;
grad_alphas += base;
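// Same structure as the sigma version: dL/dalphas[j] =
// (grad_weights[j] * T_j - sum_{k >= j} grad_weights[k] * weights[k]) / (1 - alphas[j]).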
float accum = 0;
for (int j = 0; j < steps; ++j)
{
accum += grad_weights[j] * weights[j];
}
// accumulation
float T = 1.f;
for (int j = 0; j < steps; ++j)
{
const float alpha = alphas[j];
grad_alphas[j] = (grad_weights[j] * T - accum) / fmaxf(1.f - alpha, 1e-10f);
accum -= grad_weights[j] * weights[j];
T *= (1.f - alpha);
}
return;
}
torch::Tensor weight_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1);
const uint32_t n_samples = sigmas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor weights = torch::empty_like(sigmas);
weight_from_sigma_forward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<float>(),
ends.data_ptr<float>(),
sigmas.data_ptr<float>(),
// outputs
weights.data_ptr<float>());
return weights;
}
torch::Tensor weight_from_sigma_backward_naive(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(weights);
CHECK_INPUT(grad_weights);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1);
TORCH_CHECK(weights.ndimension() == 2 & weights.size(1) == 1);
TORCH_CHECK(grad_weights.ndimension() == 2 & grad_weights.size(1) == 1);
const uint32_t n_samples = sigmas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor grad_sigmas = torch::empty_like(sigmas);
weight_from_sigma_backward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<float>(),
ends.data_ptr<float>(),
sigmas.data_ptr<float>(),
weights.data_ptr<float>(),
grad_weights.data_ptr<float>(),
// outputs
grad_sigmas.data_ptr<float>());
return grad_sigmas;
}
torch::Tensor weight_from_alpha_forward_naive(
torch::Tensor packed_info, torch::Tensor alphas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(alphas);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(alphas.ndimension() == 2 & alphas.size(1) == 1);
const uint32_t n_samples = alphas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor weights = torch::empty_like(alphas);
weight_from_alpha_forward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
alphas.data_ptr<float>(),
// outputs
weights.data_ptr<float>());
return weights;
}
torch::Tensor weight_from_alpha_backward_naive(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(alphas);
CHECK_INPUT(weights);
CHECK_INPUT(grad_weights);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(alphas.ndimension() == 2 & alphas.size(1) == 1);
TORCH_CHECK(weights.ndimension() == 2 & weights.size(1) == 1);
TORCH_CHECK(grad_weights.ndimension() == 2 & grad_weights.size(1) == 1);
const uint32_t n_samples = alphas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor grad_alphas = torch::empty_like(alphas);
weight_from_alpha_backward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
alphas.data_ptr<float>(),
weights.data_ptr<float>(),
grad_weights.data_ptr<float>(),
// outputs
grad_alphas.data_ptr<float>());
return grad_alphas;
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include <thrust/iterator/reverse_iterator.h>
#include "include/utils_scan.cuh"
#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <cub/cub.cuh>
#endif
namespace {
namespace device {
#if CUB_SUPPORTS_SCAN_BY_KEY()
struct Product
{
template <typename T>
__host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const { return a * b; }
};
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_sum_by_key(
KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
TORCH_CHECK(num_items <= std::numeric_limits<int64_t>::max(),
"cub ExclusiveSumByKey does not support more than LONG_MAX elements");
CUB_WRAPPER(cub::DeviceScan::ExclusiveSumByKey, keys, input, output,
num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_prod_by_key(
KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
TORCH_CHECK(num_items <= std::numeric_limits<int64_t>::max(),
"cub ExclusiveScanByKey does not support more than LONG_MAX elements");
CUB_WRAPPER(cub::DeviceScan::ExclusiveScanByKey, keys, input, output, Product(), 1.0f,
num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
#endif
} // namespace device
} // namespace
torch::Tensor exclusive_sum_by_key(
torch::Tensor indices,
torch::Tensor inputs,
bool backward)
{
DEVICE_GUARD(inputs);
torch::Tensor outputs = torch::empty_like(inputs);
int64_t n_items = inputs.size(0);
#if CUB_SUPPORTS_SCAN_BY_KEY()
if (backward)
device::exclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_items),
thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_items),
thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_items),
n_items);
else
device::exclusive_sum_by_key(
indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(),
outputs.data_ptr<float>(),
n_items);
#else
std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
cudaGetLastError();
return outputs;
}
torch::Tensor inclusive_sum(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
bool normalize,
bool backward)
{
DEVICE_GUARD(inputs);
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
CHECK_INPUT(inputs);
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(inputs.ndimension() == 1);
TORCH_CHECK(chunk_starts.size(0) == chunk_cnts.size(0));
if (backward)
TORCH_CHECK(!normalize); // backward does not support normalize yet.
uint32_t n_rays = chunk_cnts.size(0);
int64_t n_edges = inputs.size(0);
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_blocks = 65535;
dim3 threads = dim3(16, 32);
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));
torch::Tensor outputs = torch::empty_like(inputs);
if (backward) {
chunk_starts = n_edges - (chunk_starts + chunk_cnts);
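// The backward pass scans each chunk from its tail: all arrays are viewed through
// reverse iterators, and chunk_starts is re-expressed w.r.t. the reversed layout so
// the same kernel can be reused. The other backward scans below use the same trick.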
device::inclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
n_rays,
thrust::make_reverse_iterator(chunk_starts.data_ptr<int64_t>() + n_rays),
thrust::make_reverse_iterator(chunk_cnts.data_ptr<int64_t>() + n_rays),
0.f,
std::plus<float>(),
normalize);
} else {
device::inclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
outputs.data_ptr<float>(),
inputs.data_ptr<float>(),
n_rays,
chunk_starts.data_ptr<int64_t>(),
chunk_cnts.data_ptr<int64_t>(),
0.f,
std::plus<float>(),
normalize);
}
cudaGetLastError();
return outputs;
}
torch::Tensor exclusive_sum(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
bool normalize,
bool backward)
{
DEVICE_GUARD(inputs);
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
CHECK_INPUT(inputs);
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(inputs.ndimension() == 1);
TORCH_CHECK(chunk_starts.size(0) == chunk_cnts.size(0));
if (backward)
TORCH_CHECK(!normalize); // backward does not support normalize yet.
uint32_t n_rays = chunk_cnts.size(0);
int64_t n_edges = inputs.size(0);
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_blocks = 65535;
dim3 threads = dim3(16, 32);
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));
torch::Tensor outputs = torch::empty_like(inputs);
if (backward) {
chunk_starts = n_edges - (chunk_starts + chunk_cnts);
device::exclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
n_rays,
thrust::make_reverse_iterator(chunk_starts.data_ptr<int64_t>() + n_rays),
thrust::make_reverse_iterator(chunk_cnts.data_ptr<int64_t>() + n_rays),
0.f,
std::plus<float>(),
normalize);
} else {
device::exclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
outputs.data_ptr<float>(),
inputs.data_ptr<float>(),
n_rays,
chunk_starts.data_ptr<int64_t>(),
chunk_cnts.data_ptr<int64_t>(),
0.f,
std::plus<float>(),
normalize);
}
cudaGetLastError();
return outputs;
}
torch::Tensor inclusive_prod_forward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs)
{
DEVICE_GUARD(inputs);
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
CHECK_INPUT(inputs);
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(inputs.ndimension() == 1);
TORCH_CHECK(chunk_starts.size(0) == chunk_cnts.size(0));
uint32_t n_rays = chunk_cnts.size(0);
int64_t n_edges = inputs.size(0);
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_blocks = 65535;
dim3 threads = dim3(16, 32);
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));
torch::Tensor outputs = torch::empty_like(inputs);
device::inclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
outputs.data_ptr<float>(),
inputs.data_ptr<float>(),
n_rays,
chunk_starts.data_ptr<int64_t>(),
chunk_cnts.data_ptr<int64_t>(),
1.f,
std::multiplies<float>(),
false);
cudaGetLastError();
return outputs;
}
torch::Tensor inclusive_prod_backward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
torch::Tensor outputs,
torch::Tensor grad_outputs)
{
DEVICE_GUARD(grad_outputs);
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
CHECK_INPUT(grad_outputs);
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(inputs.ndimension() == 1);
TORCH_CHECK(chunk_starts.size(0) == chunk_cnts.size(0));
uint32_t n_rays = chunk_cnts.size(0);
int64_t n_edges = inputs.size(0);
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_blocks = 65535;
dim3 threads = dim3(16, 32);
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));
torch::Tensor grad_inputs = torch::empty_like(grad_outputs);
chunk_starts = n_edges - (chunk_starts + chunk_cnts);
device::inclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
n_rays,
thrust::make_reverse_iterator(chunk_starts.data_ptr<int64_t>() + n_rays),
thrust::make_reverse_iterator(chunk_cnts.data_ptr<int64_t>() + n_rays),
0.f,
std::plus<float>(),
false);
// FIXME: the grad is not correct when inputs are zero!!
grad_inputs = grad_inputs / inputs.clamp_min(1e-10f);
cudaGetLastError();
return grad_inputs;
}
torch::Tensor exclusive_prod_forward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs)
{
DEVICE_GUARD(inputs);
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
CHECK_INPUT(inputs);
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(inputs.ndimension() == 1);
TORCH_CHECK(chunk_starts.size(0) == chunk_cnts.size(0));
uint32_t n_rays = chunk_cnts.size(0);
int64_t n_edges = inputs.size(0);
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_blocks = 65535;
dim3 threads = dim3(16, 32);
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));
torch::Tensor outputs = torch::empty_like(inputs);
device::exclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
outputs.data_ptr<float>(),
inputs.data_ptr<float>(),
n_rays,
chunk_starts.data_ptr<int64_t>(),
chunk_cnts.data_ptr<int64_t>(),
1.f,
std::multiplies<float>(),
false);
cudaGetLastError();
return outputs;
}
torch::Tensor exclusive_prod_backward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
torch::Tensor outputs,
torch::Tensor grad_outputs)
{
DEVICE_GUARD(grad_outputs);
CHECK_INPUT(chunk_starts);
CHECK_INPUT(chunk_cnts);
CHECK_INPUT(grad_outputs);
TORCH_CHECK(chunk_starts.ndimension() == 1);
TORCH_CHECK(chunk_cnts.ndimension() == 1);
TORCH_CHECK(inputs.ndimension() == 1);
TORCH_CHECK(chunk_starts.size(0) == chunk_cnts.size(0));
uint32_t n_rays = chunk_cnts.size(0);
int64_t n_edges = inputs.size(0);
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_blocks = 65535;
dim3 threads = dim3(16, 32);
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));
torch::Tensor grad_inputs = torch::empty_like(grad_outputs);
chunk_starts = n_edges - (chunk_starts + chunk_cnts);
device::exclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
n_rays,
thrust::make_reverse_iterator(chunk_starts.data_ptr<int64_t>() + n_rays),
thrust::make_reverse_iterator(chunk_cnts.data_ptr<int64_t>() + n_rays),
0.f,
std::plus<float>(),
false);
// FIXME: the grad is not correct when inputs are zero!!
grad_inputs = grad_inputs / inputs.clamp_min(1e-10f);
cudaGetLastError();
return grad_inputs;
}
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from dataclasses import dataclass
from typing import Optional
import torch
from . import cuda as _C
@dataclass
class RaySamples:
"""Ray samples that supports batched and flattened data.
Note:
When `vals` is flattened, either `packed_info` or `ray_indices` must
be provided.
Args:
vals: Batched data with shape (n_rays, n_samples) or flattened data
with shape (all_samples,)
packed_info: Optional. A tensor of shape (n_rays, 2) that specifies
the start and count of each chunk in the flattened `vals`, with
n_rays chunks in total. Only needed when `vals` is flattened.
ray_indices: Optional. A tensor of shape (all_samples,) that specifies
the ray index of each sample. Only needed when `vals` is flattened.
Examples:
.. code-block:: python
>>> # Batched data
>>> ray_samples = RaySamples(torch.rand(10, 100))
>>> # Flattened data
>>> ray_samples = RaySamples(
>>> torch.rand(1000),
>>> packed_info=torch.tensor([[0, 100], [100, 200], [300, 700]]),
>>> )
"""
vals: torch.Tensor
packed_info: Optional[torch.Tensor] = None
ray_indices: Optional[torch.Tensor] = None
def _to_cpp(self):
"""
Generate object to pass to C++
"""
spec = _C.RaySegmentsSpec()
spec.vals = self.vals.contiguous()
if self.packed_info is not None:
spec.chunk_starts = self.packed_info[:, 0].contiguous()
if self.packed_info is not None:
spec.chunk_cnts = self.packed_info[:, 1].contiguous()
if self.ray_indices is not None:
spec.ray_indices = self.ray_indices.contiguous()
return spec
@classmethod
def _from_cpp(cls, spec):
"""
Generate object from C++
"""
if spec.chunk_starts is not None and spec.chunk_cnts is not None:
packed_info = torch.stack([spec.chunk_starts, spec.chunk_cnts], -1)
else:
packed_info = None
ray_indices = spec.ray_indices
return cls(
vals=spec.vals, packed_info=packed_info, ray_indices=ray_indices
)
@property
def device(self) -> torch.device:
return self.vals.device
@dataclass
class RayIntervals:
"""Ray intervals that supports batched and flattened data.
Each interval is defined by two edges (left and right). The attribute `vals`
stores the edges of all intervals along the rays. The attributes `is_left`
and `is_right` indicate whether each edge is a left or a right edge.
This class unifies the representation of both continuous and non-continuous
ray intervals.
Note:
When `vals` is flattened, either `packed_info` or `ray_indices` must
be provided. Also both `is_left` and `is_right` must be provided.
Args:
vals: Batched data with shape (n_rays, n_edges) or flattened data
with shape (all_edges,)
packed_info: Optional. A tensor of shape (n_rays, 2) that specifies
the start and count of each chunk in the flattened `vals`, with
n_rays chunks in total. Only needed when `vals` is flattened.
ray_indices: Optional. A tensor of shape (all_edges,) that specifies
the ray index of each edge. Only needed when `vals` is flattened.
is_left: Optional. A boolean tensor of shape (all_edges,) that specifies
whether each edge is a left edge. Only needed when `vals` is flattened.
is_right: Optional. A boolean tensor of shape (all_edges,) that specifies
whether each edge is a right edge. Only needed when `vals` is flattened.
Examples:
.. code-block:: python
>>> # Batched data
>>> ray_intervals = RayIntervals(torch.rand(10, 100))
>>> # Flattened data
>>> ray_intervals = RayIntervals(
>>> torch.rand(6),
>>> packed_info=torch.tensor([[0, 2], [2, 0], [2, 4]]),
>>> is_left=torch.tensor([True, False, True, True, True, False]),
>>> is_right=torch.tensor([False, True, False, True, True, True]),
>>> )
"""
vals: torch.Tensor
packed_info: Optional[torch.Tensor] = None
ray_indices: Optional[torch.Tensor] = None
is_left: Optional[torch.Tensor] = None
is_right: Optional[torch.Tensor] = None
def _to_cpp(self):
"""
Generate object to pass to C++
"""
spec = _C.RaySegmentsSpec()
spec.vals = self.vals.contiguous()
if self.packed_info is not None:
spec.chunk_starts = self.packed_info[:, 0].contiguous()
if self.packed_info is not None:
spec.chunk_cnts = self.packed_info[:, 1].contiguous()
if self.ray_indices is not None:
spec.ray_indices = self.ray_indices.contiguous()
if self.is_left is not None:
spec.is_left = self.is_left.contiguous()
if self.is_right is not None:
spec.is_right = self.is_right.contiguous()
return spec
@classmethod
def _from_cpp(cls, spec):
"""
Generate object from C++
"""
if spec.chunk_starts is not None and spec.chunk_cnts is not None:
packed_info = torch.stack([spec.chunk_starts, spec.chunk_cnts], -1)
else:
packed_info = None
ray_indices = spec.ray_indices
is_left = spec.is_left
is_right = spec.is_right
return cls(
vals=spec.vals,
packed_info=packed_info,
ray_indices=ray_indices,
is_left=is_left,
is_right=is_right,
)
@property
def device(self) -> torch.device:
return self.vals.device
from typing import Any
import torch
import torch.nn as nn
class AbstractEstimator(nn.Module):
"""An abstract Transmittance Estimator class for Sampling."""
def __init__(self) -> None:
super().__init__()
self.register_buffer("_dummy", torch.empty(0), persistent=False)
@property
def device(self) -> torch.device:
return self._dummy.device
def sampling(self, *args, **kwargs) -> Any:
raise NotImplementedError
def update_every_n_steps(self, *args, **kwargs) -> None:
raise NotImplementedError
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor
from ..grid import _enlarge_aabb, traverse_grids
from ..volrend import (
render_visibility_from_alpha,
render_visibility_from_density,
)
from .base import AbstractEstimator
class OccGridEstimator(AbstractEstimator):
"""Occupancy grid transmittance estimator for spatial skipping.
References: "Instant Neural Graphics Primitives."
Args:
roi_aabb: The axis-aligned bounding box of the region of interest. Useful for mapping
the 3D space to the grid.
resolution: The resolution of the grid. If an integer is given, the grid is assumed to
be a cube. Otherwise, a list or a tensor of shape (3,) is expected. Default: 128.
levels: The number of levels of the grid. Default: 1.
"""
DIM: int = 3
def __init__(
self,
roi_aabb: Union[List[int], Tensor],
resolution: Union[int, List[int], Tensor] = 128,
levels: int = 1,
**kwargs,
) -> None:
super().__init__()
if "contraction_type" in kwargs:
raise ValueError(
"`contraction_type` is not supported anymore for nerfacc >= 0.4.0."
)
# check the resolution is legal
if isinstance(resolution, int):
resolution = [resolution] * self.DIM
if isinstance(resolution, (list, tuple)):
resolution = torch.tensor(resolution, dtype=torch.int32)
assert isinstance(resolution, Tensor), f"Invalid type: {resolution}!"
assert resolution.shape[0] == self.DIM, f"Invalid shape: {resolution}!"
# check the roi_aabb is legal
if isinstance(roi_aabb, (list, tuple)):
roi_aabb = torch.tensor(roi_aabb, dtype=torch.float32)
assert isinstance(roi_aabb, Tensor), f"Invalid type: {roi_aabb}!"
assert roi_aabb.shape[0] == self.DIM * 2, f"Invalid shape: {roi_aabb}!"
# multiple levels of aabbs
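# Each level i enlarges the base roi_aabb by a factor of 2**i, so coarser levels
# cover progressively larger regions (multi-resolution occupancy, as in Instant-NGP).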
aabbs = torch.stack(
[_enlarge_aabb(roi_aabb, 2**i) for i in range(levels)], dim=0
)
# total number of voxels
self.cells_per_lvl = int(resolution.prod().item())
self.levels = levels
# Buffers
self.register_buffer("resolution", resolution) # [3]
self.register_buffer("aabbs", aabbs) # [n_aabbs, 6]
self.register_buffer(
"occs", torch.zeros(self.levels * self.cells_per_lvl)
)
self.register_buffer(
"binaries",
torch.zeros([levels] + resolution.tolist(), dtype=torch.bool),
)
# Grid coords & indices
grid_coords = _meshgrid3d(resolution).reshape(
self.cells_per_lvl, self.DIM
)
self.register_buffer("grid_coords", grid_coords, persistent=False)
grid_indices = torch.arange(self.cells_per_lvl)
self.register_buffer("grid_indices", grid_indices, persistent=False)
@torch.no_grad()
def sampling(
self,
# rays
rays_o: Tensor, # [n_rays, 3]
rays_d: Tensor, # [n_rays, 3]
# sigma/alpha function for skipping invisible space
sigma_fn: Optional[Callable] = None,
alpha_fn: Optional[Callable] = None,
near_plane: float = 0.0,
far_plane: float = 1e10,
# rendering options
render_step_size: float = 1e-3,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
stratified: bool = False,
cone_angle: float = 0.0,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Sampling with spatial skipping.
Note:
This function is not differentiable to any inputs.
Args:
rays_o: Ray origins of shape (n_rays, 3).
rays_d: Normalized ray directions of shape (n_rays, 3).
sigma_fn: Optional. If provided, the marching will skip the invisible space
by evaluating the density along the ray with `sigma_fn`. It should be a
function that takes in samples {t_starts (N,), t_ends (N,),
ray indices (N,)} and returns the post-activation density values (N,).
You should only provide either `sigma_fn` or `alpha_fn`.
alpha_fn: Optional. If provided, the marching will skip the invisible space
by evaluating the opacity along the ray with `alpha_fn`. It should be a
function that takes in samples {t_starts (N,), t_ends (N,),
ray indices (N,)} and returns the post-activation opacity values (N,).
You should only provide either `sigma_fn` or `alpha_fn`.
near_plane: Optional. Near plane distance. Default: 0.0.
far_plane: Optional. Far plane distance. Default: 1e10.
render_step_size: Step size for marching. Default: 1e-3.
early_stop_eps: Early stop threshold for skipping invisible space. Default: 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
stratified: Whether to use stratified sampling. Default: False.
cone_angle: Cone angle for linearly-increased step size. 0. means
constant step size. Default: 0.0.
Returns:
A tuple of {LongTensor, Tensor, Tensor}:
- **ray_indices**: Ray index of each sample. LongTensor with shape (n_samples,).
- **t_starts**: Per-sample start distance. Tensor with shape (n_samples,).
- **t_ends**: Per-sample end distance. Tensor with shape (n_samples,).
Examples:
.. code-block:: python
>>> ray_indices, t_starts, t_ends = grid.sampling(
>>> rays_o, rays_d, render_step_size=1e-3)
>>> t_mid = (t_starts + t_ends) / 2.0
>>> sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
"""
near_planes = torch.full_like(rays_o[..., 0], fill_value=near_plane)
far_planes = torch.full_like(rays_o[..., 0], fill_value=far_plane)
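# When stratified, jitter the near planes by up to one step so that sample
# positions differ across training iterations (stratified sampling along the ray).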
if stratified:
near_planes += torch.rand_like(near_planes) * render_step_size
intervals, samples = traverse_grids(
rays_o,
rays_d,
self.binaries,
self.aabbs,
near_planes=near_planes,
far_planes=far_planes,
step_size=render_step_size,
cone_angle=cone_angle,
)
t_starts = intervals.vals[intervals.is_left]
t_ends = intervals.vals[intervals.is_right]
ray_indices = samples.ray_indices
packed_info = samples.packed_info
# skip invisible space
if (alpha_thre > 0.0 or early_stop_eps > 0.0) and (
sigma_fn is not None or alpha_fn is not None
):
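# Never use an alpha threshold larger than the mean occupancy of the grid,
# which avoids over-pruning when the scene is mostly empty.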
alpha_thre = min(alpha_thre, self.occs.mean().item())
# Compute visibility of the samples, and filter out invisible samples
if sigma_fn is not None:
sigmas = sigma_fn(t_starts, t_ends, ray_indices)
assert (
sigmas.shape == t_starts.shape
), "sigmas must have shape of (N,)! Got {}".format(sigmas.shape)
masks = render_visibility_from_density(
t_starts=t_starts,
t_ends=t_ends,
sigmas=sigmas,
packed_info=packed_info,
early_stop_eps=early_stop_eps,
alpha_thre=alpha_thre,
)
elif alpha_fn is not None:
alphas = alpha_fn(t_starts, t_ends, ray_indices)
assert (
alphas.shape == t_starts.shape
), "alphas must have shape of (N,)! Got {}".format(alphas.shape)
masks = render_visibility_from_alpha(
alphas=alphas,
packed_info=packed_info,
early_stop_eps=early_stop_eps,
alpha_thre=alpha_thre,
)
ray_indices, t_starts, t_ends = (
ray_indices[masks],
t_starts[masks],
t_ends[masks],
)
return ray_indices, t_starts, t_ends
@torch.no_grad()
def update_every_n_steps(
self,
step: int,
occ_eval_fn: Callable,
occ_thre: float = 1e-2,
ema_decay: float = 0.95,
warmup_steps: int = 256,
n: int = 16,
) -> None:
"""Update the estimator every n steps during training.
Args:
step: Current training step.
occ_eval_fn: A function that takes in sample locations :math:`(N, 3)` and
returns the occupancy values :math:`(N, 1)` at those locations.
occ_thre: Threshold used to binarize the occupancy grid. Default: 1e-2.
ema_decay: The decay rate for EMA updates. Default: 0.95.
warmup_steps: Sample all cells during the warmup stage. After the warmup
stage we change the sampling strategy to 1/4 uniformly sampled cells
together with 1/4 occupied cells. Default: 256.
n: Update the grid every n steps. Default: 16.
"""
if not self.training:
raise RuntimeError(
"You should only call this function only during training. "
"Please call _update() directly if you want to update the "
"field during inference."
)
if step % n == 0 and self.training:
self._update(
step=step,
occ_eval_fn=occ_eval_fn,
occ_thre=occ_thre,
ema_decay=ema_decay,
warmup_steps=warmup_steps,
)
@torch.no_grad()
def _get_all_cells(self) -> List[Tensor]:
"""Returns all cells of the grid."""
return [self.grid_indices] * self.levels
@torch.no_grad()
def _sample_uniform_and_occupied_cells(self, n: int) -> List[Tensor]:
"""Samples both n uniform and occupied cells."""
lvl_indices = []
for lvl in range(self.levels):
uniform_indices = torch.randint(
self.cells_per_lvl, (n,), device=self.device
)
occupied_indices = torch.nonzero(self.binaries[lvl].flatten())[:, 0]
if n < len(occupied_indices):
selector = torch.randint(
len(occupied_indices), (n,), device=self.device
)
occupied_indices = occupied_indices[selector]
indices = torch.cat([uniform_indices, occupied_indices], dim=0)
lvl_indices.append(indices)
return lvl_indices
@torch.no_grad()
def _update(
self,
step: int,
occ_eval_fn: Callable,
occ_thre: float = 0.01,
ema_decay: float = 0.95,
warmup_steps: int = 256,
) -> None:
"""Update the occ field in the EMA way."""
# sample cells
if step < warmup_steps:
lvl_indices = self._get_all_cells()
else:
N = self.cells_per_lvl // 4
lvl_indices = self._sample_uniform_and_occupied_cells(N)
for lvl, indices in enumerate(lvl_indices):
# infer occupancy: density * step_size
grid_coords = self.grid_coords[indices]
x = (
grid_coords + torch.rand_like(grid_coords, dtype=torch.float32)
) / self.resolution
# voxel coordinates [0, 1]^3 -> world
x = self.aabbs[lvl, :3] + x * (
self.aabbs[lvl, 3:] - self.aabbs[lvl, :3]
)
occ = occ_eval_fn(x).squeeze(-1)
# ema update
cell_ids = lvl * self.cells_per_lvl + indices
self.occs[cell_ids] = torch.maximum(
self.occs[cell_ids] * ema_decay, occ
)
# supposed to use scatter_max, but empirically it is almost the same.
# self.occs, _ = scatter_max(
# occ, indices, dim=0, out=self.occs * ema_decay
# )
self.binaries = (
self.occs > torch.clamp(self.occs.mean(), max=occ_thre)
).view(self.binaries.shape)
def _meshgrid3d(
res: Tensor, device: Union[torch.device, str] = "cpu"
) -> Tensor:
"""Create 3D grid coordinates."""
assert len(res) == 3
res = res.tolist()
return torch.stack(
torch.meshgrid(
[
torch.arange(res[0], dtype=torch.long),
torch.arange(res[1], dtype=torch.long),
torch.arange(res[2], dtype=torch.long),
],
indexing="ij",
),
dim=-1,
).to(device)
from typing import Callable, List, Literal, Optional, Tuple
import torch
from torch import Tensor
from ..data_specs import RayIntervals
from ..pdf import importance_sampling, searchsorted
from ..volrend import render_transmittance_from_density
from .base import AbstractEstimator
class PropNetEstimator(AbstractEstimator):
"""Proposal network transmittance estimator.
References: "Mip-NeRF 360: Unbounded Anti-Aliased Neural Radiance Fields."
Args:
optimizer: The optimizer to use for the proposal networks.
scheduler: The learning rate scheduler to use for the proposal networks.
"""
def __init__(
self,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
super().__init__()
self.optimizer = optimizer
self.scheduler = scheduler
self.prop_cache: List = []
@torch.no_grad()
def sampling(
self,
prop_sigma_fns: List[Callable],
prop_samples: List[int],
num_samples: int,
# rendering options
n_rays: int,
near_plane: float,
far_plane: float,
sampling_type: Literal["uniform", "lindisp"] = "lindisp",
# training options
stratified: bool = False,
requires_grad: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Sampling with CDFs from proposal networks.
Note:
When `requires_grad` is `True`, the gradients are allowed to flow
through the proposal networks, and the outputs of the proposal
networks are cached to update them later when calling `update_every_n_steps()`.
Args:
prop_sigma_fns: Proposal network evaluate functions. It should be a list
of functions that take in samples {t_starts (n_rays, n_samples),
t_ends (n_rays, n_samples)} and returns the post-activation densities
(n_rays, n_samples).
prop_samples: Number of samples to draw from each proposal network. Should
be the same length as `prop_sigma_fns`.
num_samples: Number of samples to draw in the end.
n_rays: Number of rays.
near_plane: Near plane.
far_plane: Far plane.
sampling_type: Sampling type. Either "uniform" or "lindisp". Default to
"lindisp".
stratified: Whether to use stratified sampling. Default to `False`.
requires_grad: Whether to allow gradients to flow through the proposal
networks. Default to `False`.
Returns:
A tuple of {Tensor, Tensor}:
- **t_starts**: The starts of the samples. Shape (n_rays, num_samples).
- **t_ends**: The ends of the samples. Shape (n_rays, num_samples).
"""
assert len(prop_sigma_fns) == len(prop_samples), (
"The number of proposal networks and the number of samples "
"should be the same."
)
cdfs = torch.cat(
[
torch.zeros((n_rays, 1), device=self.device),
torch.ones((n_rays, 1), device=self.device),
],
dim=-1,
)
intervals = RayIntervals(vals=cdfs)
for level_fn, level_samples in zip(prop_sigma_fns, prop_samples):
intervals, _ = importance_sampling(
intervals, cdfs, level_samples, stratified
)
t_vals = _transform_stot(
sampling_type, intervals.vals, near_plane, far_plane
)
t_starts = t_vals[..., :-1]
t_ends = t_vals[..., 1:]
with torch.set_grad_enabled(requires_grad):
sigmas = level_fn(t_starts, t_ends)
assert sigmas.shape == t_starts.shape
trans, _ = render_transmittance_from_density(
t_starts, t_ends, sigmas
)
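# The CDF along the ray is 1 - transmittance; appending a zero transmittance at
# the far end makes the CDF reach 1 at the last interval edge.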
cdfs = 1.0 - torch.cat(
[trans, torch.zeros_like(trans[:, :1])], dim=-1
)
if requires_grad:
self.prop_cache.append((intervals, cdfs))
intervals, _ = importance_sampling(
intervals, cdfs, num_samples, stratified
)
t_vals = _transform_stot(
sampling_type, intervals.vals, near_plane, far_plane
)
t_starts = t_vals[..., :-1]
t_ends = t_vals[..., 1:]
if requires_grad:
self.prop_cache.append((intervals, None))
return t_starts, t_ends
@torch.enable_grad()
def compute_loss(self, trans: Tensor, loss_scaler: float = 1.0) -> Tensor:
"""Compute the loss for the proposal networks.
Args:
trans: The transmittance of all samples. Shape (n_rays, num_samples).
loss_scaler: The loss scaler. Default to 1.0.
Returns:
The loss for the proposal networks.
"""
if len(self.prop_cache) == 0:
return torch.zeros((), device=self.device)
intervals, _ = self.prop_cache.pop()
# get cdfs at all edges of intervals
cdfs = 1.0 - torch.cat([trans, torch.zeros_like(trans[:, :1])], dim=-1)
cdfs = cdfs.detach()
loss = 0.0
while self.prop_cache:
prop_intervals, prop_cdfs = self.prop_cache.pop()
loss += _pdf_loss(intervals, cdfs, prop_intervals, prop_cdfs).mean()
return loss * loss_scaler
@torch.enable_grad()
def update_every_n_steps(
self,
trans: Tensor,
requires_grad: bool = False,
loss_scaler: float = 1.0,
) -> float:
"""Update the estimator every n steps during training.
Args:
trans: The transmittance of all samples. Shape (n_rays, num_samples).
requires_grad: Whether to allow gradients to flow through the proposal
networks. Default to `False`.
loss_scaler: The loss scaler to use. Default to 1.0.
Returns:
The loss of the proposal networks for logging (a float scalar).
"""
if requires_grad:
return self._update(trans=trans, loss_scaler=loss_scaler)
else:
if self.scheduler is not None:
self.scheduler.step()
return 0.0
@torch.enable_grad()
def _update(self, trans: Tensor, loss_scaler: float = 1.0) -> float:
assert len(self.prop_cache) > 0
assert self.optimizer is not None, "No optimizer is provided."
loss = self.compute_loss(trans, loss_scaler)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
if self.scheduler is not None:
self.scheduler.step()
return loss.item()
def get_proposal_requires_grad_fn(
target: float = 5.0, num_steps: int = 1000
) -> Callable:
schedule = lambda s: min(s / num_steps, 1.0) * target
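# The threshold ramps from 0 to `target` over `num_steps` iterations: early on the
# proposal networks receive gradients nearly every step, while near the end they are
# updated roughly once every `target` + 1 steps.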
steps_since_last_grad = 0
def proposal_requires_grad_fn(step: int) -> bool:
nonlocal steps_since_last_grad
target_steps_since_last_grad = schedule(step)
requires_grad = steps_since_last_grad > target_steps_since_last_grad
if requires_grad:
steps_since_last_grad = 0
steps_since_last_grad += 1
return requires_grad
return proposal_requires_grad_fn
def _transform_stot(
transform_type: Literal["uniform", "lindisp"],
s_vals: torch.Tensor,
t_min: torch.Tensor,
t_max: torch.Tensor,
) -> torch.Tensor:
if transform_type == "uniform":
_contract_fn, _icontract_fn = lambda x: x, lambda x: x
elif transform_type == "lindisp":
_contract_fn, _icontract_fn = lambda x: 1 / x, lambda x: 1 / x
else:
raise ValueError(f"Unknown transform_type: {transform_type}")
s_min, s_max = _contract_fn(t_min), _contract_fn(t_max)
icontract_fn = lambda s: _icontract_fn(s * s_max + (1 - s) * s_min)
return icontract_fn(s_vals)
def _pdf_loss(
segments_query: RayIntervals,
cdfs_query: torch.Tensor,
segments_key: RayIntervals,
cdfs_key: torch.Tensor,
eps: float = 1e-7,
) -> torch.Tensor:
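# Proposal (interlevel) loss in the spirit of Mip-NeRF 360: penalize probability mass
# `w` of the query histogram that is not covered by the enveloping mass `w_outer`
# gathered from the key (proposal) histogram.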
ids_left, ids_right = searchsorted(segments_key, segments_query)
if segments_query.vals.dim() > 1:
w = cdfs_query[..., 1:] - cdfs_query[..., :-1]
ids_left = ids_left[..., :-1]
ids_right = ids_right[..., 1:]
else:
# TODO: not tested for this branch.
assert segments_query.is_left is not None
assert segments_query.is_right is not None
w = (
cdfs_query[segments_query.is_right]
- cdfs_query[segments_query.is_left]
)
ids_left = ids_left[segments_query.is_left]
ids_right = ids_right[segments_query.is_right]
w_outer = cdfs_key.gather(-1, ids_right) - cdfs_key.gather(-1, ids_left)
return torch.clip(w - w_outer, min=0) ** 2 / (w + eps)
def _outer(
t0_starts: torch.Tensor,
t0_ends: torch.Tensor,
t1_starts: torch.Tensor,
t1_ends: torch.Tensor,
y1: torch.Tensor,
) -> torch.Tensor:
"""
Args:
t0_starts: (..., S0).
t0_ends: (..., S0).
t1_starts: (..., S1).
t1_ends: (..., S1).
y1: (..., S1).
"""
cy1 = torch.cat(
[torch.zeros_like(y1[..., :1]), torch.cumsum(y1, dim=-1)], dim=-1
)
idx_lo = (
torch.searchsorted(
t1_starts.contiguous(), t0_starts.contiguous(), side="right"
)
- 1
)
idx_lo = torch.clamp(idx_lo, min=0, max=y1.shape[-1] - 1)
idx_hi = torch.searchsorted(
t1_ends.contiguous(), t0_ends.contiguous(), side="right"
)
idx_hi = torch.clamp(idx_hi, min=0, max=y1.shape[-1] - 1)
cy1_lo = torch.take_along_dim(cy1[..., :-1], idx_lo, dim=-1)
cy1_hi = torch.take_along_dim(cy1[..., 1:], idx_hi, dim=-1)
y0_outer = cy1_hi - cy1_lo
return y0_outer
def _lossfun_outer(
t: torch.Tensor,
w: torch.Tensor,
t_env: torch.Tensor,
w_env: torch.Tensor,
):
"""
Args:
t: interval edges, (..., S + 1).
w: weights, (..., S).
t_env: interval edges of the upper-bound enveloping histogram, (..., S + 1).
w_env: weights that should upper bound the inner (t,w) histogram, (..., S).
"""
eps = torch.finfo(t.dtype).eps
w_outer = _outer(
t[..., :-1], t[..., 1:], t_env[..., :-1], t_env[..., 1:], w_env
)
return torch.clip(w - w_outer, min=0) ** 2 / (w + eps)
"""
Copyright (c) 2022 Ruilong Li @ UC Berkeley
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Callable, List, Union
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from . import cuda as _C
from .data_specs import RayIntervals, RaySamples
import nerfacc.cuda as _C
from .contraction import ContractionType, contract_inv
@torch.no_grad()
def ray_aabb_intersect(
rays_o: Tensor,
rays_d: Tensor,
aabbs: Tensor,
near_plane: float = -float("inf"),
far_plane: float = float("inf"),
miss_value: float = float("inf"),
) -> Tuple[Tensor, Tensor, Tensor]:
"""Ray-AABB intersection.
Args:
rays_o: (n_rays, 3) Ray origins.
rays_d: (n_rays, 3) Normalized ray directions.
aabbs: (m, 6) Axis-aligned bounding boxes {xmin, ymin, zmin, xmax, ymax, zmax}.
near_plane: Optional. Near plane. Default to -infinity.
far_plane: Optional. Far plane. Default to infinity.
miss_value: Optional. Value to use for tmin and tmax when there is no intersection.
Default to infinity.
Returns:
t_mins: (n_rays, m) tmin for each ray-AABB pair.
t_maxs: (n_rays, m) tmax for each ray-AABB pair.
hits: (n_rays, m) whether each ray-AABB pair intersects.
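Examples:
.. code-block:: python
>>> # A minimal sketch (not from the original docs; assumes a CUDA device):
>>> # a unit box and a single ray marching along +x.
>>> aabbs = torch.tensor([[0., 0., 0., 1., 1., 1.]], device="cuda")
>>> rays_o = torch.tensor([[-1., 0.5, 0.5]], device="cuda")
>>> rays_d = torch.tensor([[1., 0., 0.]], device="cuda")
>>> t_mins, t_maxs, hits = ray_aabb_intersect(rays_o, rays_d, aabbs)
>>> # expected: t_mins is about [[1.0]], t_maxs about [[2.0]], hits is [[True]]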
"""
assert rays_o.ndim == 2 and rays_o.shape[-1] == 3
assert rays_d.ndim == 2 and rays_d.shape[-1] == 3
assert aabbs.ndim == 2 and aabbs.shape[-1] == 6
t_mins, t_maxs, hits = _C.ray_aabb_intersect(
rays_o.contiguous(),
rays_d.contiguous(),
aabbs.contiguous(),
near_plane,
far_plane,
miss_value,
)
return t_mins, t_maxs, hits
def _ray_aabb_intersect(
rays_o: Tensor,
rays_d: Tensor,
aabbs: Tensor,
near_plane: float = -float("inf"),
far_plane: float = float("inf"),
miss_value: float = float("inf"),
) -> Tuple[Tensor, Tensor, Tensor]:
"""Ray-AABB intersection.
Functionally the same as `ray_aabb_intersect()`, but slower since it is implemented in pure PyTorch.
"""
# TODO: check torch.scatter_reduce_
# from torch_scatter import scatter_max
# Compute the minimum and maximum bounds of the AABBs
aabb_min = aabbs[:, :3]
aabb_max = aabbs[:, 3:]
# Compute the intersection distances between the ray and each of the six AABB planes
t1 = (aabb_min[None, :, :] - rays_o[:, None, :]) / rays_d[:, None, :]
t2 = (aabb_max[None, :, :] - rays_o[:, None, :]) / rays_d[:, None, :]
# Compute the maximum tmin and minimum tmax for each AABB
t_mins = torch.max(torch.min(t1, t2), dim=-1)[0]
t_maxs = torch.min(torch.max(t1, t2), dim=-1)[0]
# Compute whether each ray-AABB pair intersects
hits = (t_maxs > t_mins) & (t_maxs > 0)
# Clip the tmin and tmax values to the near and far planes
t_mins = torch.clamp(t_mins, min=near_plane, max=far_plane)
t_maxs = torch.clamp(t_maxs, min=near_plane, max=far_plane)
# Set the tmin and tmax values to miss_value if there is no intersection
t_mins = torch.where(hits, t_mins, miss_value)
t_maxs = torch.where(hits, t_maxs, miss_value)
return t_mins, t_maxs, hits
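# Hedged sanity-check sketch (illustrative): on a CUDA device the pure-PyTorch
# `_ray_aabb_intersect` above should closely match the CUDA `ray_aabb_intersect`
# for rays that hit; minor boundary-case differences are possible.
rays_o = torch.rand((64, 3), device="cuda:0")
rays_d = torch.randn((64, 3), device="cuda:0")
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
aabbs = torch.tensor([[0.0, 0.0, 0.0, 1.0, 1.0, 1.0]], device="cuda:0")
t_mins, t_maxs, hits = ray_aabb_intersect(rays_o, rays_d, aabbs)
_t_mins, _t_maxs, _hits = _ray_aabb_intersect(rays_o, rays_d, aabbs)
both = hits & _hits
print((hits == _hits).float().mean().item())              # expect ~1.0
print((t_mins[both] - _t_mins[both]).abs().max().item())  # expect ~0.0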
@torch.no_grad()
def query_grid(
samples: torch.Tensor,
grid_roi: torch.Tensor,
grid_values: torch.Tensor,
grid_type: ContractionType,
):
"""Query grid values given coordinates.
Args:
samples: (n_samples, 3) tensor of coordinates.
grid_roi: (6,) region of interest of the grid. Usually it should be
acquired from the grid itself using `grid.roi_aabb`.
grid_values: A 4D tensor of grid values in the shape of (nlvl, resx, resy, resz).
grid_type: Contraction type of the grid. Usually it should be
acquired from the grid itself using `grid.contraction_type`.
Returns:
(n_samples) values for those samples queried from the grid.
"""
assert samples.dim() == 2 and samples.size(-1) == 3
assert grid_roi.dim() == 1 and grid_roi.size(0) == 6
assert grid_values.dim() == 4
assert isinstance(grid_type, ContractionType)
return _C.grid_query(
samples.contiguous(),
grid_roi.contiguous(),
grid_values.contiguous(),
grid_type.to_cpp_version(),
)
def traverse_grids(
# rays
rays_o: Tensor, # [n_rays, 3]
rays_d: Tensor, # [n_rays, 3]
# grids
binaries: Tensor, # [m, resx, resy, resz]
aabbs: Tensor, # [m, 6]
# options
near_planes: Optional[Tensor] = None, # [n_rays]
far_planes: Optional[Tensor] = None, # [n_rays]
step_size: Optional[float] = 1e-3,
cone_angle: Optional[float] = 0.0,
) -> Tuple[RayIntervals, RaySamples]:
"""Ray Traversal within Multiple Grids.
Note:
This function is not differentiable to any inputs.
Args:
rays_o: (n_rays, 3) Ray origins.
rays_d: (n_rays, 3) Normalized ray directions.
binaries: (m, resx, resy, resz) Multiple binary grids with the same resolution.
aabbs: (m, 6) Axis-aligned bounding boxes {xmin, ymin, zmin, xmax, ymax, zmax}.
near_planes: Optional. (n_rays,) Near planes for the traversal to start. Default to 0.
far_planes: Optional. (n_rays,) Far planes for the traversal to end. Default to infinity.
step_size: Optional. Step size for ray traversal. Default to 1e-3.
cone_angle: Optional. Cone angle for linearly-increased step size. 0. means
constant step size. Default: 0.0.
Returns:
A :class:`RayIntervals` object containing the intervals of the ray traversal, and
a :class:`RaySamples` object containing the samples within each interval.
"""
# Compute ray aabb intersection for all levels of grid. [n_rays, m]
t_mins, t_maxs, hits = ray_aabb_intersect(rays_o, rays_d, aabbs)
if near_planes is None:
near_planes = torch.zeros_like(rays_o[:, 0])
if far_planes is None:
far_planes = torch.full_like(rays_o[:, 0], float("inf"))
intervals, samples = _C.traverse_grids(
# rays
rays_o.contiguous(), # [n_rays, 3]
rays_d.contiguous(), # [n_rays, 3]
# grids
binaries.contiguous(), # [m, resx, resy, resz]
aabbs.contiguous(), # [m, 6]
# intersections
t_mins.contiguous(), # [n_rays, m]
t_maxs.contiguous(), # [n_rays, m]
hits.contiguous(), # [n_rays, m]
# options
near_planes.contiguous(), # [n_rays]
far_planes.contiguous(), # [n_rays]
step_size,
cone_angle,
True,
True,
)
return RayIntervals._from_cpp(intervals), RaySamples._from_cpp(samples)
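# Hedged usage sketch for `traverse_grids` above: march rays through a single,
# fully-occupied 32^3 binary grid. Requires a CUDA build; the values below are
# illustrative assumptions.
rays_o = torch.rand((128, 3), device="cuda:0") * 0.5 + 0.25
rays_d = torch.randn((128, 3), device="cuda:0")
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
binaries = torch.ones((1, 32, 32, 32), dtype=torch.bool, device="cuda:0")
aabbs = torch.tensor([[0.0, 0.0, 0.0, 1.0, 1.0, 1.0]], device="cuda:0")
intervals, samples = traverse_grids(rays_o, rays_d, binaries, aabbs, step_size=1e-2)
print(samples.vals.shape, samples.ray_indices.shape)  # flattened over all rays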
def _enlarge_aabb(aabb, factor: float) -> Tensor:
center = (aabb[:3] + aabb[3:]) / 2
extent = (aabb[3:] - aabb[:3]) / 2
return torch.cat([center - extent * factor, center + extent * factor])
def _query(x: Tensor, data: Tensor, base_aabb: Tensor) -> Tensor:
"""
Query the grid values at the given points.
This function assumes the aabbs of multiple grids are 2x scaled.
Args:
x: (N, 3) tensor of points to query.
data: (m, resx, resy, resz) tensor of grid values
base_aabb: (6,) aabb of base level grid.
"""
# normalize so that the base_aabb is [0, 1]^3
aabb_min, aabb_max = torch.split(base_aabb, 3, dim=0)
x_norm = (x - aabb_min) / (aabb_max - aabb_min)
# if maxval is almost zero, it will trigger frexpf to output 0
# for exponent, which is not what we want.
maxval = (x_norm - 0.5).abs().max(dim=-1).values
maxval = torch.clamp(maxval, min=0.1)
# compute the mip level
exponent = torch.frexp(maxval)[1].long()
mip = torch.clamp(exponent + 1, min=0)
selector = mip < data.shape[0]
# use the mip to re-normalize all points to [0, 1].
scale = 2**mip
x_unit = (x_norm - 0.5) / scale[:, None] + 0.5
# map to the grid index
resolution = torch.tensor(data.shape[1:], device=x.device)
ix = (x_unit * resolution).long()
ix = torch.clamp(ix, max=resolution - 1)
mip = torch.clamp(mip, max=data.shape[0] - 1)
return data[mip, ix[:, 0], ix[:, 1], ix[:, 2]] * selector, selector
class Grid(nn.Module):
"""An abstract Grid class.
The grid is used as a cache of the 3D space to indicate whether each voxel
area is important or not for the differentiable rendering process. The
ray marching function (see :func:`nerfacc.ray_marching`) would use the
grid to skip the unimportant voxel areas.
To work with :func:`nerfacc.ray_marching`, three attributes must exist:
- :attr:`roi_aabb`: The axis-aligned bounding box of the region of interest.
- :attr:`binary`: A 4D binarized tensor of shape {nlvl, resx, resy, resz}, \
with torch.bool data type.
- :attr:`contraction_type`: The contraction type of the grid, indicating how \
the 3D space is mapped to the grid.
"""
def __init__(self, *args, **kwargs):
super().__init__()
self.register_buffer("_dummy", torch.empty(0), persistent=False)
@property
def device(self) -> torch.device:
return self._dummy.device
@property
def roi_aabb(self) -> torch.Tensor:
"""The axis-aligned bounding box of the region of interest.
It is a shape (6,) tensor in the format of {minx, miny, minz, maxx, maxy, maxz}.
"""
if hasattr(self, "_roi_aabb"):
return getattr(self, "_roi_aabb")
else:
raise NotImplementedError("please set an attribute named _roi_aabb")
@property
def binary(self) -> torch.Tensor:
"""A 4-dim binarized tensor with torch.bool data type.
The tensor is of shape (nlvl, resx, resy, resz), in which each boolean value
represents whether the corresponding voxel should be kept or not.
"""
if hasattr(self, "_binary"):
return getattr(self, "_binary")
else:
raise NotImplementedError("please set an attribute named _binary")
@property
def contraction_type(self) -> ContractionType:
"""The contraction type of the grid.
The contraction type is an indicator of how the 3D space is contracted
to this voxel grid. See :class:`nerfacc.ContractionType` for more details.
"""
if hasattr(self, "_contraction_type"):
return getattr(self, "_contraction_type")
else:
raise NotImplementedError(
"please set an attribute named _contraction_type"
)
class OccupancyGrid(Grid):
"""Occupancy grid: whether each voxel area is occupied or not.
Args:
roi_aabb: The axis-aligned bounding box of the region of interest. Useful for mapping
the 3D space to the grid.
resolution: The resolution of the grid. If an integer is given, the grid is assumed to
be a cube. Otherwise, a list or a tensor of shape (3,) is expected. Default: 128.
contraction_type: The contraction type of the grid. See :class:`nerfacc.ContractionType`
for more details. Default: :attr:`nerfacc.ContractionType.AABB`.
levels: The number of levels of the grid. Default: 1.
"""
NUM_DIM: int = 3
def __init__(
self,
roi_aabb: Union[List[int], torch.Tensor],
resolution: Union[int, List[int], torch.Tensor] = 128,
contraction_type: ContractionType = ContractionType.AABB,
levels: int = 1,
) -> None:
super().__init__()
if isinstance(resolution, int):
resolution = [resolution] * self.NUM_DIM
if isinstance(resolution, (list, tuple)):
resolution = torch.tensor(resolution, dtype=torch.int32)
assert isinstance(
resolution, torch.Tensor
), f"Invalid type: {type(resolution)}"
assert resolution.shape == (
self.NUM_DIM,
), f"Invalid shape: {resolution.shape}"
if isinstance(roi_aabb, (list, tuple)):
roi_aabb = torch.tensor(roi_aabb, dtype=torch.float32)
assert isinstance(
roi_aabb, torch.Tensor
), f"Invalid type: {type(roi_aabb)}"
assert roi_aabb.shape == torch.Size(
[self.NUM_DIM * 2]
), f"Invalid shape: {roi_aabb.shape}"
if levels > 1:
assert (
contraction_type == ContractionType.AABB
), "For multi-res occupancy grid, contraction is not supported yet."
# total number of voxels
self.num_cells_per_lvl = int(resolution.prod().item())
self.levels = levels
# required attributes
self.register_buffer("_roi_aabb", roi_aabb)
self.register_buffer(
"_binary",
torch.zeros([levels] + resolution.tolist(), dtype=torch.bool),
)
self._contraction_type = contraction_type
# helper attributes
self.register_buffer("resolution", resolution)
self.register_buffer(
"occs", torch.zeros(self.levels * self.num_cells_per_lvl)
)
# Grid coords & indices
grid_coords = _meshgrid3d(resolution).reshape(
self.num_cells_per_lvl, self.NUM_DIM
)
self.register_buffer("grid_coords", grid_coords, persistent=False)
grid_indices = torch.arange(self.num_cells_per_lvl)
self.register_buffer("grid_indices", grid_indices, persistent=False)
@torch.no_grad()
def _get_all_cells(self) -> List[torch.Tensor]:
"""Returns all cells of the grid."""
return [self.grid_indices] * self.levels
@torch.no_grad()
def _sample_uniform_and_occupied_cells(self, n: int) -> List[torch.Tensor]:
"""Samples both n uniform and occupied cells."""
lvl_indices = []
for lvl in range(self.levels):
uniform_indices = torch.randint(
self.num_cells_per_lvl, (n,), device=self.device
)
occupied_indices = torch.nonzero(self._binary[lvl].flatten())[:, 0]
if n < len(occupied_indices):
selector = torch.randint(
len(occupied_indices), (n,), device=self.device
)
occupied_indices = occupied_indices[selector]
indices = torch.cat([uniform_indices, occupied_indices], dim=0)
lvl_indices.append(indices)
return lvl_indices
@torch.no_grad()
def _update(
self,
step: int,
occ_eval_fn: Callable,
occ_thre: float = 0.01,
ema_decay: float = 0.95,
warmup_steps: int = 256,
) -> None:
"""Update the occ field in the EMA way."""
# sample cells
if step < warmup_steps:
lvl_indices = self._get_all_cells()
else:
N = self.num_cells_per_lvl // 4
lvl_indices = self._sample_uniform_and_occupied_cells(N)
for lvl, indices in enumerate(lvl_indices):
# infer occupancy: density * step_size
grid_coords = self.grid_coords[indices]
x = (
grid_coords + torch.rand_like(grid_coords, dtype=torch.float32)
) / self.resolution
if self._contraction_type == ContractionType.UN_BOUNDED_SPHERE:
# only the points inside the sphere are valid
mask = (x - 0.5).norm(dim=1) < 0.5
x = x[mask]
indices = indices[mask]
# voxel coordinates [0, 1]^3 -> world
x = contract_inv(
(x - 0.5) * (2**lvl) + 0.5,
roi=self._roi_aabb,
type=self._contraction_type,
)
occ = occ_eval_fn(x).squeeze(-1)
# ema update
cell_ids = lvl * self.num_cells_per_lvl + indices
self.occs[cell_ids] = torch.maximum(
self.occs[cell_ids] * ema_decay, occ
)
# supposed to use scatter_max but empirically it is almost the same.
# self.occs, _ = scatter_max(
# occ, indices, dim=0, out=self.occs * ema_decay
# )
self._binary = (
self.occs > torch.clamp(self.occs.mean(), max=occ_thre)
).view(self._binary.shape)
@torch.no_grad()
def every_n_step(
self,
step: int,
occ_eval_fn: Callable,
occ_thre: float = 1e-2,
ema_decay: float = 0.95,
warmup_steps: int = 256,
n: int = 16,
) -> None:
"""Update the grid every n steps during training.
Args:
step: Current training step.
occ_eval_fn: A function that takes in sample locations :math:`(N, 3)` and
returns the occupancy values :math:`(N, 1)` at those locations.
occ_thre: Threshold used to binarize the occupancy grid. Default: 1e-2.
ema_decay: The decay rate for EMA updates. Default: 0.95.
warmup_steps: Sample all cells during the warmup stage. After the warmup
stage we change the sampling strategy to 1/4 uniformly sampled cells
together with 1/4 occupied cells. Default: 256.
n: Update the grid every n steps. Default: 16.
"""
if not self.training:
raise RuntimeError(
"You should only call this function only during training. "
"Please call _update() directly if you want to update the "
"field during inference."
)
if step % n == 0 and self.training:
self._update(
step=step,
occ_eval_fn=occ_eval_fn,
occ_thre=occ_thre,
ema_decay=ema_decay,
warmup_steps=warmup_steps,
)
@torch.no_grad()
def query_occ(self, samples: torch.Tensor) -> torch.Tensor:
"""Query the occupancy field at the given samples.
Args:
samples: Samples in the world coordinates. (n_samples, 3)
Returns:
Occupancy values at the given samples. (n_samples,)
"""
return query_grid(
samples,
self._roi_aabb,
self.binary,
self.contraction_type,
)
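# Hedged usage sketch for `OccupancyGrid` above: periodically update the grid
# during a toy training loop. `occ_eval_fn` and the aabb below are illustrative
# assumptions, not library defaults.
occ_grid = OccupancyGrid(roi_aabb=[0.0, 0.0, 0.0, 1.0, 1.0, 1.0], resolution=64).to("cuda:0")

def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
    # stand-in for density * step_size: higher near the center of the aabb
    return torch.exp(-((x - 0.5) ** 2).sum(-1, keepdim=True) * 10.0)

for step in range(100):
    occ_grid.every_n_step(step=step, occ_eval_fn=occ_eval_fn)
print(occ_grid.binary.float().mean().item())  # fraction of occupied voxels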
def _meshgrid3d(
res: torch.Tensor, device: Union[torch.device, str] = "cpu"
) -> torch.Tensor:
"""Create 3D grid coordinates."""
assert len(res) == 3
res = res.tolist()
return torch.stack(
torch.meshgrid(
[
torch.arange(res[0], dtype=torch.long),
torch.arange(res[1], dtype=torch.long),
torch.arange(res[2], dtype=torch.long),
],
indexing="ij",
),
dim=-1,
).to(device)
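# Small sketch: `_meshgrid3d` enumerates integer voxel coordinates in "ij"
# (row-major) order, matching how `grid_coords` is built in OccupancyGrid.
coords = _meshgrid3d(torch.tensor([2, 2, 2]))
print(coords.shape)               # torch.Size([2, 2, 2, 3])
print(coords.reshape(-1, 3)[:3])  # [[0, 0, 0], [0, 0, 1], [0, 1, 0]]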
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Tuple
import torch
from torch import Tensor
import nerfacc.cuda as _C
@torch.no_grad()
def ray_aabb_intersect(
rays_o: Tensor, rays_d: Tensor, aabb: Tensor
) -> Tuple[Tensor, Tensor]:
"""Ray AABB Test.
Note:
this function is not differentiable to any inputs.
Args:
rays_o: Ray origins of shape (n_rays, 3).
rays_d: Normalized ray directions of shape (n_rays, 3).
aabb: Scene bounding box {xmin, ymin, zmin, xmax, ymax, zmax}. \
Tensor with shape (6)
Returns:
Ray AABB intersection {t_min, t_max} with shape (n_rays) respectively. \
Note the t_min is clipped to minimum zero. 1e10 means no intersection.
Examples:
.. code-block:: python
aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], device="cuda:0")
rays_o = torch.rand((128, 3), device="cuda:0")
rays_d = torch.randn((128, 3), device="cuda:0")
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
t_min, t_max = ray_aabb_intersect(rays_o, rays_d, aabb)
"""
if rays_o.is_cuda and rays_d.is_cuda and aabb.is_cuda:
rays_o = rays_o.contiguous()
rays_d = rays_d.contiguous()
aabb = aabb.contiguous()
t_min, t_max = _C.ray_aabb_intersect(rays_o, rays_d, aabb)
else:
raise NotImplementedError("Only support cuda inputs.")
return t_min, t_max
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Optional, Tuple
import torch
from torch import Tensor
import nerfacc.cuda as _C
def pack_data(data: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
"""Pack per-ray data (n_rays, n_samples, D) to (all_samples, D) based on mask.
Note:
this function is not differentiable to any inputs.
Args:
data: Tensor with shape (n_rays, n_samples, D).
mask: Boolean tensor with shape (n_rays, n_samples).
Returns:
Tuple of Tensors including the packed data (all_samples, D), \
and packed_info (n_rays, 2) which stores the start index and the number \
of samples kept for each ray.
Examples:
.. code-block:: python
data = torch.rand((10, 3, 4), device="cuda:0")
mask = torch.rand((10, 3), device="cuda:0") > 0.5
packed_data, packed_info = pack_data(data, mask)
print(packed_data.shape, packed_info.shape)
"""
assert data.dim() == 3, "data must be with shape of (n_rays, n_samples, D)."
assert (
mask.shape == data.shape[:2]
), "mask must be with shape of (n_rays, n_samples)."
assert mask.dtype == torch.bool, "mask must be a boolean tensor."
packed_data = data[mask]
num_steps = mask.sum(dim=-1, dtype=torch.int32)
cum_steps = num_steps.cumsum(dim=0, dtype=torch.int32)
packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1)
return packed_data, packed_info
@torch.no_grad()
def pack_info(ray_indices: Tensor, n_rays: Optional[int] = None) -> Tensor:
"""Pack `ray_indices` to `packed_info`. Useful for converting per sample data to per ray data.
Note:
this function is not differentiable to any inputs.
Args:
ray_indices: Ray indices of the samples. LongTensor with shape (n_sample).
n_rays: Number of rays. If None, it is inferred from `ray_indices`. Default is None.
Returns:
A LongTensor of shape (n_rays, 2) that specifies the start and count
of each chunk in the flattened input tensor, with in total n_rays chunks.
Example:
.. code-block:: python
>>> ray_indices = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2], device="cuda")
>>> packed_info = pack_info(ray_indices, n_rays=3)
>>> packed_info
tensor([[0, 2], [2, 3], [5, 4]], device='cuda:0')
"""
assert (
ray_indices.dim() == 1
), "ray_indices must be a 1D tensor with shape (n_samples)."
if ray_indices.is_cuda:
device = ray_indices.device
dtype = ray_indices.dtype
if n_rays is None:
n_rays = int(ray_indices.max()) + 1
# else:
# assert n_rays > ray_indices.max()
chunk_cnts = torch.zeros((n_rays,), device=device, dtype=dtype)
chunk_cnts.index_add_(0, ray_indices, torch.ones_like(ray_indices))
chunk_starts = chunk_cnts.cumsum(dim=0, dtype=dtype) - chunk_cnts
packed_info = torch.stack([chunk_starts, chunk_cnts], dim=-1)
# Alternative using scatter_add_ (equivalent):
# src = torch.ones_like(ray_indices, dtype=torch.int)
# num_steps = torch.zeros((n_rays,), device=device, dtype=torch.int)
# num_steps.scatter_add_(0, ray_indices, src)
# cum_steps = num_steps.cumsum(dim=0, dtype=torch.int)
# packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1)
else:
raise NotImplementedError("Only support cuda inputs.")
return packed_info
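# Hedged sketch (illustrative): the `packed_info` produced by `pack_info` can
# drive per-ray reductions over flattened per-sample data, e.g. summing sample
# weights ray by ray.
ray_indices = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2], device="cuda")
weights = torch.rand((ray_indices.numel(),), device="cuda")
packed_info = pack_info(ray_indices, n_rays=3)
per_ray_sum = torch.stack(
    [weights[start : start + cnt].sum() for start, cnt in packed_info.tolist()]
)
print(per_ray_sum.shape)  # torch.Size([3])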
@torch.no_grad()
def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor:
"""Unpack `packed_info` to `ray_indices`. Useful for converting per ray data to per sample data.
Note:
this function is not differentiable to any inputs.
Args:
packed_info: Stores information on which samples belong to the same ray. \
See :func:`nerfacc.ray_marching` for details. IntTensor with shape (n_rays, 2).
n_samples: Total number of samples.
Returns:
Ray index of each sample. LongTensor with shape (n_sample).
Examples:
.. code-block:: python
rays_o = torch.rand((128, 3), device="cuda:0")
rays_d = torch.randn((128, 3), device="cuda:0")
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
# Ray marching with near far plane.
packed_info, t_starts, t_ends = ray_marching(
rays_o, rays_d, near_plane=0.1, far_plane=1.0, render_step_size=1e-3
)
# torch.Size([128, 2]) torch.Size([115200, 1]) torch.Size([115200, 1])
print(packed_info.shape, t_starts.shape, t_ends.shape)
# Unpack per-ray info to per-sample info.
ray_indices = unpack_info(packed_info, t_starts.shape[0])
# torch.Size([115200]) torch.int64
print(ray_indices.shape, ray_indices.dtype)
"""
assert (
packed_info.dim() == 2 and packed_info.shape[-1] == 2
), "packed_info must be a 2D tensor with shape (n_rays, 2)."
if packed_info.is_cuda:
ray_indices = _C.unpack_info(packed_info.contiguous(), n_samples)
else:
raise NotImplementedError("Only support cuda inputs.")
return ray_indices
def unpack_data(
packed_info: Tensor,
data: Tensor,
n_samples: Optional[int] = None,
pad_value: float = 0.0,
) -> Tensor:
"""Unpack packed data (all_samples, D) to per-ray data (n_rays, n_samples, D).
Args:
packed_info (Tensor): Stores information on which samples belong to the same ray. \
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
data: Packed data to unpack. Tensor with shape (n_samples, D).
n_samples (int): Optional. Number of samples per ray. If not provided, it \
will be inferred from the packed_info.
pad_value: Value to pad the unpacked data.
Returns:
Unpacked data (n_rays, n_samples, D).
Examples:
.. code-block:: python
rays_o = torch.rand((128, 3), device="cuda:0")
rays_d = torch.randn((128, 3), device="cuda:0")
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
# Ray marching with aabb.
scene_aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], device="cuda:0")
packed_info, t_starts, t_ends = ray_marching(
rays_o, rays_d, scene_aabb=scene_aabb, render_step_size=1e-2
)
print(t_starts.shape) # torch.Size([all_samples, 1])
t_starts = unpack_data(packed_info, t_starts, n_samples=1024)
print(t_starts.shape) # torch.Size([128, 1024, 1])
"""
assert (
packed_info.dim() == 2 and packed_info.shape[-1] == 2
), "packed_info must be a 2D tensor with shape (n_rays, 2)."
assert (
data.dim() == 2
), "data must be a 2D tensor with shape (n_samples, D)."
if n_samples is None:
n_samples = packed_info[:, 1].max().item()
return _UnpackData.apply(packed_info, data, n_samples, pad_value)
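# Hedged round-trip sketch (illustrative): `pack_data` keeps the masked samples
# of each ray in order, and `unpack_data` lays them back out at the front of
# each ray, padding the rest with `pad_value`.
data = torch.rand((10, 3, 4), device="cuda:0")
mask = torch.rand((10, 3), device="cuda:0") > 0.5
packed_data, packed_info = pack_data(data, mask)
unpacked = unpack_data(packed_info, packed_data, n_samples=3)
num = mask.sum(dim=-1)
prefix_mask = torch.arange(3, device="cuda:0")[None, :] < num[:, None]
print(torch.allclose(unpacked[prefix_mask], data[mask]))  # expected: True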
class _UnpackData(torch.autograd.Function):
"""Unpack packed data (all_samples, D) to per-ray data (n_rays, n_samples, D)."""
@staticmethod
def forward(
ctx,
packed_info: Tensor,
data: Tensor,
n_samples: int,
pad_value: float = 0.0,
) -> Tensor:
# shape of the data should be (all_samples, D)
packed_info = packed_info.contiguous()
data = data.contiguous()
if ctx.needs_input_grad[1]:
ctx.save_for_backward(packed_info)
ctx.n_samples = n_samples
return _C.unpack_data(packed_info, data, n_samples, pad_value)
@staticmethod
def backward(ctx, grad: Tensor):
# shape of the grad should be (n_rays, n_samples, D)
packed_info = ctx.saved_tensors[0]
n_samples = ctx.n_samples
mask = _C.unpack_info_to_mask(packed_info, n_samples)
packed_grad = grad[mask].contiguous()
return None, packed_grad, None, None
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
from typing import Optional, Tuple, Union
import torch
from torch import Tensor
from . import cuda as _C
from .data_specs import RayIntervals, RaySamples
class PDFOuter(torch.autograd.Function):
@staticmethod
def forward(
ctx,
ts: Tensor,
weights: Tensor,
masks: Optional[Tensor],
ts_query: Tensor,
masks_query: Optional[Tensor],
):
assert ts.dim() == weights.dim() == ts_query.dim() == 2
assert ts.shape[0] == weights.shape[0] == ts_query.shape[0]
assert ts.shape[1] == weights.shape[1] + 1
ts = ts.contiguous()
weights = weights.contiguous()
ts_query = ts_query.contiguous()
masks = masks.contiguous() if masks is not None else None
masks_query = (
masks_query.contiguous() if masks_query is not None else None
)
weights_query = _C.pdf_readout(
ts, weights, masks, ts_query, masks_query
)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(ts, masks, ts_query, masks_query)
return weights_query
@staticmethod
def backward(ctx, weights_query_grads: Tensor):
weights_query_grads = weights_query_grads.contiguous()
ts, masks, ts_query, masks_query = ctx.saved_tensors
weights_grads = _C.pdf_readout(
ts_query, weights_query_grads, masks_query, ts, masks
)
return None, weights_grads, None, None, None
pdf_outer = PDFOuter.apply
def searchsorted(
sorted_sequence: Union[RayIntervals, RaySamples],
values: Union[RayIntervals, RaySamples],
) -> Tuple[Tensor, Tensor]:
"""Searchsorted that supports flattened tensor.
This function returns {`ids_left`, `ids_right`} such that:
pdf_outer = PDFOuter.apply
`sorted_sequence.vals.gather(-1, ids_left) <= values.vals < sorted_sequence.vals.gather(-1, ids_right)`
Note:
When values is out of range of sorted_sequence, we return the
corresponding ids as if the values is clipped to the range of
sorted_sequence. See the example below.
@torch.no_grad()
def pdf_sampling(
t: torch.Tensor,
weights: torch.Tensor,
n_samples: int,
padding: float = 0.01,
Args:
sorted_sequence: A :class:`RayIntervals` or :class:`RaySamples` object. We assume
the `sorted_sequence.vals` is acendingly sorted for each ray.
values: A :class:`RayIntervals` or :class:`RaySamples` object.
Returns:
A tuple of LongTensor:
- **ids_left**: A LongTensor with the same shape as `values.vals`.
- **ids_right**: A LongTensor with the same shape as `values.vals`.
Example:
>>> sorted_sequence = RayIntervals(
... vals=torch.tensor([0.0, 1.0, 0.0, 1.0, 2.0], device="cuda"),
... packed_info=torch.tensor([[0, 2], [2, 3]], device="cuda"),
... )
>>> values = RayIntervals(
... vals=torch.tensor([0.5, 1.5, 2.5], device="cuda"),
... packed_info=torch.tensor([[0, 1], [1, 2]], device="cuda"),
... )
>>> ids_left, ids_right = searchsorted(sorted_sequence, values)
>>> ids_left
tensor([0, 3, 3], device='cuda:0')
>>> ids_right
tensor([1, 4, 4], device='cuda:0')
>>> sorted_sequence.vals.gather(-1, ids_left)
tensor([0., 1., 1.], device='cuda:0')
>>> sorted_sequence.vals.gather(-1, ids_right)
tensor([1., 2., 2.], device='cuda:0')
"""
ids_left, ids_right = _C.searchsorted(
values._to_cpp(), sorted_sequence._to_cpp()
)
return ids_left, ids_right
@torch.no_grad()
def pdf_sampling(
t: torch.Tensor,
weights: torch.Tensor,
n_samples: int,
padding: float = 0.01,
stratified: bool = False,
single_jitter: bool = False,
masks: Optional[torch.Tensor] = None,
):
assert t.shape[0] == weights.shape[0]
assert t.shape[1] == weights.shape[1] + 1
if masks is not None:
assert t.shape[0] == masks.shape[0]
t_new = _C.pdf_sampling(
t.contiguous(),
weights.contiguous(),
n_samples + 1, # be careful here!
padding,
stratified,
single_jitter,
masks.contiguous() if masks is not None else None,
)
return t_new # [n_ray, n_samples+1]
def importance_sampling(
intervals: RayIntervals,
cdfs: Tensor,
n_intervals_per_ray: Union[Tensor, int],
stratified: bool = False,
) -> Tuple[RayIntervals, RaySamples]:
"""Importance sampling that supports flattened tensor.
Given a set of intervals and the corresponding CDFs at the interval edges,
this function performs inverse transform sampling to create a new set of
intervals and samples. Stratified sampling is also supported.
Args:
intervals: A :class:`RayIntervals` object that specifies the edges of the
intervals along the rays.
cdfs: The CDFs at the interval edges. It has the same shape as
`intervals.vals`.
n_intervals_per_ray: Resample each ray to have this many intervals.
If it is a tensor, it must be of shape (n_rays,). If it is an int,
it is broadcasted to all rays.
stratified: If True, perform stratified sampling.
Returns:
A tuple of {:class:`RayIntervals`, :class:`RaySamples`}:
- **intervals**: A :class:`RayIntervals` object. If `n_intervals_per_ray` is an int, \
`intervals.vals` will have the shape of (n_rays, n_intervals_per_ray + 1). \
If `n_intervals_per_ray` is a tensor, we assume each ray results \
in a different number of intervals. In this case, `intervals.vals` \
will have the shape of (all_edges,), and the attributes `packed_info`, \
`ray_indices`, `is_left` and `is_right` will be accessible.
- **samples**: A :class:`RaySamples` object. If `n_intervals_per_ray` is an int, \
`samples.vals` will have the shape of (n_rays, n_intervals_per_ray). \
If `n_intervals_per_ray` is a tensor, we assume each ray results \
in a different number of intervals. In this case, `samples.vals` \
will have the shape of (all_samples,), and the attributes `packed_info` and \
`ray_indices` will be accessible.
Example:
.. code-block:: python
>>> intervals = RayIntervals(
... vals=torch.tensor([0.0, 1.0, 0.0, 1.0, 2.0], device="cuda"),
... packed_info=torch.tensor([[0, 2], [2, 3]], device="cuda"),
... )
>>> cdfs = torch.tensor([0.0, 0.5, 0.0, 0.5, 1.0], device="cuda")
>>> n_intervals_per_ray = 2
>>> intervals, samples = importance_sampling(intervals, cdfs, n_intervals_per_ray)
>>> intervals.vals
tensor([[0.0000, 0.5000, 1.0000],
[0.0000, 1.0000, 2.0000]], device='cuda:0')
>>> samples.vals
tensor([[0.2500, 0.7500],
[0.5000, 1.5000]], device='cuda:0')
"""
if isinstance(n_intervals_per_ray, Tensor):
n_intervals_per_ray = n_intervals_per_ray.contiguous()
intervals, samples = _C.importance_sampling(
intervals._to_cpp(),
cdfs.contiguous(),
n_intervals_per_ray,
stratified,
)
return RayIntervals._from_cpp(intervals), RaySamples._from_cpp(samples)
def _sample_from_weighted(
bins: Tensor,
weights: Tensor,
num_samples: int,
stratified: bool = False,
vmin: float = -torch.inf,
vmax: float = torch.inf,
) -> Tuple[Tensor, Tensor]:
"""
Args:
bins: (..., B + 1).
weights: (..., B).
Returns:
samples: (..., S + 1).
centers: (..., S).
"""
import torch.nn.functional as F
B = weights.shape[-1]
S = num_samples
assert bins.shape[-1] == B + 1
dtype, device = bins.dtype, bins.device
eps = torch.finfo(weights.dtype).eps
# (..., B).
pdf = F.normalize(weights, p=1, dim=-1)
# (..., B + 1).
cdf = torch.cat(
[
torch.zeros_like(pdf[..., :1]),
torch.cumsum(pdf[..., :-1], dim=-1),
torch.ones_like(pdf[..., :1]),
],
dim=-1,
)
# (..., S). Sample positions between [0, 1).
if not stratified:
pad = 1 / (2 * S)
# Get the center of each pdf bins.
u = torch.linspace(pad, 1 - pad - eps, S, dtype=dtype, device=device)
u = u.broadcast_to(bins.shape[:-1] + (S,))
else:
# `u` is in [0, 1) --- it can be zero, but it can never be 1.
u_max = eps + (1 - eps) / S
max_jitter = (1 - u_max) / (S - 1) - eps
# Only perform one jittering per ray (`single_jitter` in the original
# implementation.)
u = (
torch.linspace(0, 1 - u_max, S, dtype=dtype, device=device)
+ torch.rand(
*bins.shape[:-1],
1,
dtype=dtype,
device=device,
)
* max_jitter
)
# (..., S).
ceil = torch.searchsorted(cdf.contiguous(), u.contiguous(), side="right")
floor = ceil - 1
# (..., S * 2).
inds = torch.cat([floor, ceil], dim=-1)
# (..., S).
cdf0, cdf1 = cdf.gather(-1, inds).split(S, dim=-1)
b0, b1 = bins.gather(-1, inds).split(S, dim=-1)
# (..., S). Linear interpolation in 1D.
t = (u - cdf0) / torch.clamp(cdf1 - cdf0, min=eps)
# Sample centers.
centers = b0 + t * (b1 - b0)
samples = (centers[..., 1:] + centers[..., :-1]) / 2
samples = torch.cat(
[
(2 * centers[..., :1] - samples[..., :1]).clamp_min(vmin),
samples,
(2 * centers[..., -1:] - samples[..., -1:]).clamp_max(vmax),
],
dim=-1,
)
return samples, centers
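# Hedged sketch (illustrative) for `_sample_from_weighted` above: with uniform
# weights the resampled interval edges stay roughly evenly spaced in [vmin, vmax].
bins = torch.linspace(0.0, 1.0, 9)[None, :]  # (1, B + 1) with B = 8
weights = torch.ones(1, 8)                   # (1, B)
samples, centers = _sample_from_weighted(
    bins, weights, num_samples=4, stratified=False, vmin=0.0, vmax=1.0
)
print(samples.shape, centers.shape)  # torch.Size([1, 5]) torch.Size([1, 4])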