Unverified commit 8340e19d authored by Ruilong Li (李瑞龙), committed by GitHub

0.5.0: Rewrite all the underlying CUDA. Speedup and Benchmarking. (#182)

* importance_sampling with test

* package importance_sampling

* compute_intervals tested and packaged

* compute_intervals_v2

* bicycle is failing

* fix cut in compute_intervals_v2; rendering tests pass

* hacky way to get opaque_bkgd to work

* reorg (in progress)

* PackedRaySegmentsSpec

* chunk_ids -> ray_ids

* binary -> occupied

* test_traverse_grid_basic checked

* fix traverse_grid with step size, checked

* support max_step_size, not verified

* _cuda and cuda; upgrade ray_marching

* inclusive scan

* test_exclusive_sum, but it seems to have a numeric error

* inclusive_sum_backward verified

* exclusive sum backward

* merge fwd and bwd for scan

* inclusive & exclusive prod verified

* support normal scan with torch funcs

* rendering and tests

* a bit clean up

* importance_sampling verified

* stratified for importance_sampling

* importance_sampling in pdf.py

* RaySegmentsSpec in data_specs; fix various bugs

* verified with _proposal_packed.py

* importance sampling supports batch input/output; still needs verification

* prop script with batch samples

* try to use cumsum instead of cumprod

* searchsorted

* benchmarking prop

* ray_aabb_intersect untested

* update prop benchmark numbers

* minor fixes

* batched ray_aabb_intersect

* ray_aabb_intersect and traverse with grid(s)

* tiny optimization for traverse_grids kernels

* traverse_grids return intervals and samples

* cub not verified

* cleanup

* propnet and occgrid as estimators

* training print iters 10k

* prop is good now

* benchmark in google sheet.

* really cleanup: scan.py and test

* pack.py and test

* rendering and test

* data_specs.py and pdf.py docs

* data_specs.py and pdf.py docs

* init and headers

* grid.py and test for it

* occ grid docs

* generated docs

* example docs for the pack and scan functions

* doc fix for volrend.py

* doc fix for pdf.py

* fix doc for rendering function

* docs

* propnet docs

* update scripts

* docs: index.rst

* methodology docs

* docs for examples

* mlp nerf script

* update t-nerf script

* rename dnerf to tnerf

* misc update

* bug fix: pdf_loss with test

* minor fix

* update readme with submodules

* fix format

* update gitignore file

* fix doc failure; convert teaser png to jpg

* docs in examples/
parent e547490c
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*
* Modified from
* https://github.com/pytorch/pytorch/blob/06a64f7eaa47ce430a3fa61016010075b59b18a7/aten/src/ATen/native/cuda/ScanUtils.cuh
*/
#include "utils_cuda.cuh"
// CUB support for scan-by-key was added in CUB 1.15
// (https://github.com/NVIDIA/cub/pull/376).
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif
// https://github.com/pytorch/pytorch/blob/233305a852e1cd7f319b15b5137074c9eac455f6/aten/src/ATen/cuda/cub.cuh#L38-L46
#define CUB_WRAPPER(func, ...) do { \
size_t temp_storage_bytes = 0; \
func(nullptr, temp_storage_bytes, __VA_ARGS__); \
auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \
auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \
func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \
AT_CUDA_CHECK(cudaGetLastError()); \
} while (false)
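/*
 * Illustration (not part of the original diff): CUB_WRAPPER hides CUB's
 * two-phase calling convention -- the first call with a nullptr buffer only
 * queries the required temp-storage size, the second call does the work, and
 * the scratch buffer comes from PyTorch's caching allocator. A minimal sketch
 * of how it would be used, assuming <cub/cub.cuh> and the c10 CUDA headers
 * are reachable through utils_cuda.cuh (an assumption, not verified here):
 */
static void cub_inclusive_sum_sketch(
    const float *d_in, float *d_out, int num_items, cudaStream_t stream)
{
    // Expands to two calls of cub::DeviceScan::InclusiveSum: one to size the
    // temp storage, one to perform the scan on `stream`.
    CUB_WRAPPER(cub::DeviceScan::InclusiveSum, d_in, d_out, num_items, stream);
}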
namespace {
namespace device {
/* Perform an inclusive scan for a flattened tensor.
*
* - num_rows is the size of the outer dimensions;
* - {chunk_starts, chunk_cnts} defines the regions of the flattened tensor to be scanned.
*
* Each thread block processes one or more sets of contiguous rows (processing multiple rows
* per thread block is quicker than processing a single row, especially for short rows).
*/
template<
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__device__ void inclusive_scan_impl(
T* row_buf, DataIteratorT tgt_, DataIteratorT src_,
const uint32_t num_rows,
// const uint32_t row_size,
IdxIteratorT chunk_starts, IdxIteratorT chunk_cnts,
T init, BinaryFunction binary_op,
bool normalize = false){
for (uint32_t block_row = blockIdx.x * blockDim.y;
block_row < num_rows;
block_row += blockDim.y * gridDim.x) {
uint32_t row = block_row + threadIdx.y;
T block_total = init;
if (row >= num_rows) continue;
DataIteratorT row_src = src_ + chunk_starts[row];
DataIteratorT row_tgt = tgt_ + chunk_starts[row];
uint32_t row_size = chunk_cnts[row];
if (row_size == 0) continue;
// Perform scan on one block at a time, keeping track of the total value of
// all blocks processed so far.
for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
// Load data into shared memory (two values per thread).
uint32_t col1 = block_col + threadIdx.x;
uint32_t col2 = block_col + num_threads_x + threadIdx.x;
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = row_src[col1];
} else {
row_buf[threadIdx.x] = init;
}
if (col2 < row_size) {
row_buf[num_threads_x + threadIdx.x] = row_src[col2];
} else {
row_buf[num_threads_x + threadIdx.x] = init;
}
// Add the total value of all previous blocks to the first value of this block.
if (threadIdx.x == 0) {
row_buf[0] = binary_op(row_buf[0], block_total);
}
}
__syncthreads();
// Parallel reduction (up-sweep).
for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < num_rows && threadIdx.x < s) {
uint32_t offset = (2 * threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Down-sweep.
for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < num_rows && threadIdx.x < s - 1) {
uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Write back to output.
if (row < num_rows) {
if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];
if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];
}
block_total = row_buf[2 * num_threads_x - 1];
__syncthreads();
}
// Normalize with the last value: should only be used by scan_sum
if (normalize) {
for (uint32_t block_col = 0; block_col < row_size; block_col += num_threads_x)
{
uint32_t col = block_col + threadIdx.x;
if (col < row_size) {
row_tgt[col] /= fmaxf(block_total, 1e-10f);
}
}
}
}
}
template <
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__global__ void
inclusive_scan_kernel(
DataIteratorT tgt_,
DataIteratorT src_,
const uint32_t num_rows,
IdxIteratorT chunk_starts,
IdxIteratorT chunk_cnts,
T init,
BinaryFunction binary_op,
bool normalize = false) {
__shared__ T sbuf[num_threads_y][2 * num_threads_x];
T* row_buf = sbuf[threadIdx.y];
inclusive_scan_impl<T, num_threads_x, num_threads_y>(
row_buf, tgt_, src_, num_rows, chunk_starts, chunk_cnts, init, binary_op, normalize);
}
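/*
 * Illustration (not part of the original diff): a hypothetical launcher for
 * the kernel above. The block shape (16 x 32), the grid size, the int64
 * chunk indices and the AddOp functor are assumptions chosen to show the
 * intended mapping -- each threadIdx.y lane scans one row, and every pass
 * consumes 2 * num_threads_x values of that row.
 */
struct AddOp
{
    __host__ __device__ float operator()(float a, float b) const { return a + b; }
};

static void launch_inclusive_sum_sketch(
    float *tgt, float *src, uint32_t num_rows,
    const int64_t *chunk_starts, const int64_t *chunk_cnts,
    cudaStream_t stream)
{
    constexpr int kTx = 16, kTy = 32;
    dim3 block(kTx, kTy);
    dim3 grid((num_rows + kTy - 1) / kTy);   // one threadIdx.y lane per row
    inclusive_scan_kernel<float, kTx, kTy><<<grid, block, 0, stream>>>(
        tgt, src, num_rows, chunk_starts, chunk_cnts,
        /*init=*/0.f, AddOp(), /*normalize=*/false);
}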
/* Perform an exclusive scan for a flattened tensor.
*
* - num_rows is the size of the outer dimensions;
* - {chunk_starts, chunk_cnts} defines the regions of the flattened tensor to be scanned.
*
* Each thread block processes one or more sets of contiguous rows (processing multiple rows
* per thread block is quicker than processing a single row, especially for short rows).
*/
template<
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__device__ void exclusive_scan_impl(
T* row_buf, DataIteratorT tgt_, DataIteratorT src_,
const uint32_t num_rows,
// const uint32_t row_size,
IdxIteratorT chunk_starts, IdxIteratorT chunk_cnts,
T init, BinaryFunction binary_op,
bool normalize = false){
for (uint32_t block_row = blockIdx.x * blockDim.y;
block_row < num_rows;
block_row += blockDim.y * gridDim.x) {
uint32_t row = block_row + threadIdx.y;
T block_total = init;
if (row >= num_rows) continue;
DataIteratorT row_src = src_ + chunk_starts[row];
DataIteratorT row_tgt = tgt_ + chunk_starts[row];
uint32_t row_size = chunk_cnts[row];
if (row_size == 0) continue;
row_tgt[0] = init;
// Perform scan on one block at a time, keeping track of the total value of
// all blocks processed so far.
for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
// Load data into shared memory (two values per thread).
uint32_t col1 = block_col + threadIdx.x;
uint32_t col2 = block_col + num_threads_x + threadIdx.x;
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = row_src[col1];
} else {
row_buf[threadIdx.x] = init;
}
if (col2 < row_size) {
row_buf[num_threads_x + threadIdx.x] = row_src[col2];
} else {
row_buf[num_threads_x + threadIdx.x] = init;
}
// Add the total value of all previous blocks to the first value of this block.
if (threadIdx.x == 0) {
row_buf[0] = binary_op(row_buf[0], block_total);
}
}
__syncthreads();
// Parallel reduction (up-sweep).
for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < num_rows && threadIdx.x < s) {
uint32_t offset = (2 * threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Down-sweep.
for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < num_rows && threadIdx.x < s - 1) {
uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Write back to output.
if (row < num_rows) {
if (col1 < row_size - 1) row_tgt[col1 + 1] = row_buf[threadIdx.x];
if (col2 < row_size - 1) row_tgt[col2 + 1] = row_buf[num_threads_x + threadIdx.x];
}
block_total = row_buf[2 * num_threads_x - 1];
__syncthreads();
}
// Normalize with the last value: should only be used by scan_sum
if (normalize) {
for (uint32_t block_col = 0; block_col < row_size; block_col += num_threads_x)
{
uint32_t col = block_col + threadIdx.x;
if (col < row_size - 1) {
row_tgt[col + 1] /= fmaxf(block_total, 1e-10f);
}
}
}
}
}
template <
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__global__ void
exclusive_scan_kernel(
DataIteratorT tgt_,
DataIteratorT src_,
const uint32_t num_rows,
IdxIteratorT chunk_starts,
IdxIteratorT chunk_cnts,
T init,
BinaryFunction binary_op,
bool normalize = false) {
__shared__ T sbuf[num_threads_y][2 * num_threads_x];
T* row_buf = sbuf[threadIdx.y];
exclusive_scan_impl<T, num_threads_x, num_threads_y>(
row_buf, tgt_, src_, num_rows, chunk_starts, chunk_cnts, init, binary_op, normalize);
}
} // namespace device
} // namespace
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
template <typename scalar_t>
inline __host__ __device__ void _swap(scalar_t &a, scalar_t &b)
{
scalar_t c = a;
a = b;
b = c;
}
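/*
 * Illustration (not part of the original diff): the helper below is the
 * standard slab test. For each axis the ray is intersected with the pair of
 * bounding planes, the per-axis [tmin, tmax] intervals are intersected, and
 * an empty intersection (tmin > tmax at any point) means the ray misses the
 * box, which this implementation signals by returning 1e10 for both near
 * and far.
 */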
template <typename scalar_t>
inline __host__ __device__ void _ray_aabb_intersect(
const scalar_t *rays_o,
const scalar_t *rays_d,
const scalar_t *aabb,
scalar_t *near,
scalar_t *far)
{
// aabb is [xmin, ymin, zmin, xmax, ymax, zmax]
scalar_t tmin = (aabb[0] - rays_o[0]) / rays_d[0];
scalar_t tmax = (aabb[3] - rays_o[0]) / rays_d[0];
if (tmin > tmax)
_swap(tmin, tmax);
scalar_t tymin = (aabb[1] - rays_o[1]) / rays_d[1];
scalar_t tymax = (aabb[4] - rays_o[1]) / rays_d[1];
if (tymin > tymax)
_swap(tymin, tymax);
if (tmin > tymax || tymin > tmax)
{
*near = 1e10;
*far = 1e10;
return;
}
if (tymin > tmin)
tmin = tymin;
if (tymax < tmax)
tmax = tymax;
scalar_t tzmin = (aabb[2] - rays_o[2]) / rays_d[2];
scalar_t tzmax = (aabb[5] - rays_o[2]) / rays_d[2];
if (tzmin > tzmax)
_swap(tzmin, tzmax);
if (tmin > tzmax || tzmin > tmax)
{
*near = 1e10;
*far = 1e10;
return;
}
if (tzmin > tmin)
tmin = tzmin;
if (tzmax < tmax)
tmax = tzmax;
*near = tmin;
*far = tmax;
return;
}
template <typename scalar_t>
__global__ void ray_aabb_intersect_kernel(
const int N,
const scalar_t *rays_o,
const scalar_t *rays_d,
const scalar_t *aabb,
scalar_t *t_min,
scalar_t *t_max)
{
// aabb is [xmin, ymin, zmin, xmax, ymax, zmax]
CUDA_GET_THREAD_ID(thread_id, N);
// locate
rays_o += thread_id * 3;
rays_d += thread_id * 3;
t_min += thread_id;
t_max += thread_id;
_ray_aabb_intersect<scalar_t>(rays_o, rays_d, aabb, t_min, t_max);
scalar_t zero = static_cast<scalar_t>(0.f);
*t_min = *t_min > zero ? *t_min : zero;
return;
}
/**
* @brief Ray AABB Test
*
* @param rays_o Ray origins. Tensor with shape [N, 3].
* @param rays_d Normalized ray directions. Tensor with shape [N, 3].
* @param aabb Scene AABB [xmin, ymin, zmin, xmax, ymax, zmax]. Tensor with shape [6].
* @return std::vector<torch::Tensor>
* Ray-AABB intersection {t_min, t_max}, each with shape [N]. Note that t_min is
* clamped to a minimum of zero, and 1e10 is returned when there is no intersection.
*/
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o, const torch::Tensor rays_d, const torch::Tensor aabb)
{
DEVICE_GUARD(rays_o);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(aabb);
    TORCH_CHECK(rays_o.ndimension() == 2 && rays_o.size(1) == 3);
    TORCH_CHECK(rays_d.ndimension() == 2 && rays_d.size(1) == 3);
    TORCH_CHECK(aabb.ndimension() == 1 && aabb.size(0) == 6);
const int N = rays_o.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(N, threads);
torch::Tensor t_min = torch::empty({N}, rays_o.options());
torch::Tensor t_max = torch::empty({N}, rays_o.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rays_o.scalar_type(), "ray_aabb_intersect",
([&]
{ ray_aabb_intersect_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
N,
rays_o.data_ptr<scalar_t>(),
rays_d.data_ptr<scalar_t>(),
aabb.data_ptr<scalar_t>(),
t_min.data_ptr<scalar_t>(),
t_max.data_ptr<scalar_t>()); }));
return {t_min, t_max};
}
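/*
 * Illustration (not part of the original diff): a hypothetical host-side call,
 * following the shapes in the docstring above. The tensor values and the unit
 * cube AABB are made up; the function itself is the one defined right above.
 */
static std::vector<torch::Tensor> ray_aabb_intersect_example()
{
    auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    torch::Tensor rays_o = torch::zeros({1024, 3}, opts);
    torch::Tensor rays_d = torch::randn({1024, 3}, opts);
    rays_d = rays_d / rays_d.norm(2, {-1}, true);   // directions must be normalized
    torch::Tensor aabb = torch::tensor({-1.f, -1.f, -1.f, 1.f, 1.f, 1.f}, opts);
    // Returns {t_min, t_max}, each of shape [1024]; t_min is clamped to >= 0
    // and both are 1e10 where a ray misses the box.
    return ray_aabb_intersect(rays_o, rays_d, aabb);
}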
// This file contains only Python bindings
#include "include/data_spec.hpp"
#include <torch/extension.h>
bool is_cub_available() {
// FIXME: why return false?
return (bool) CUB_SUPPORTS_SCAN_BY_KEY();
}
// scan
torch::Tensor exclusive_sum_by_key(
torch::Tensor indices,
torch::Tensor inputs,
bool backward);
torch::Tensor inclusive_sum(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
bool normalize,
bool backward);
torch::Tensor exclusive_sum(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
bool normalize,
bool backward);
torch::Tensor inclusive_prod_forward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs);
torch::Tensor inclusive_prod_backward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
torch::Tensor outputs,
torch::Tensor grad_outputs);
torch::Tensor exclusive_prod_forward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs);
torch::Tensor exclusive_prod_backward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
torch::Tensor outputs,
torch::Tensor grad_outputs);
// grid
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
const torch::Tensor aabbs, // [n_aabbs, 6]
const float near_plane,
const float far_plane,
const float miss_value);
std::vector<RaySegmentsSpec> traverse_grids(
// rays
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
// grids
const torch::Tensor binaries, // [n_grids, resx, resy, resz]
const torch::Tensor aabbs, // [n_grids, 6]
// intersections
const torch::Tensor t_mins, // [n_rays, n_grids]
const torch::Tensor t_maxs, // [n_rays, n_grids]
const torch::Tensor hits, // [n_rays, n_grids]
// options
const torch::Tensor near_planes,
const torch::Tensor far_planes,
const float step_size,
const float cone_angle,
const bool compute_intervals,
const bool compute_samples);
// pdf
std::vector<RaySegmentsSpec> importance_sampling(
RaySegmentsSpec ray_segments,
torch::Tensor cdfs,
torch::Tensor n_intervels_per_ray,
bool stratified);
std::vector<RaySegmentsSpec> importance_sampling(
RaySegmentsSpec ray_segments,
torch::Tensor cdfs,
int64_t n_intervels_per_ray,
bool stratified);
std::vector<torch::Tensor> searchsorted(
RaySegmentsSpec query,
RaySegmentsSpec key);
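// Illustration (not part of the original diff): the packed layout behind
// RaySegmentsSpec, inferred from the fields bound below (an assumption, not
// verified against data_spec.hpp). With two rays holding 3 and 2 values:
//   vals         = [v00, v01, v02, v10, v11]  // flattened over all rays
//   chunk_starts = [0, 3]                     // offset of each ray's slice in vals
//   chunk_cnts   = [3, 2]                     // number of values per ray
//   ray_indices  = [0, 0, 0, 1, 1]            // per-value ray id
// is_left / is_right mark whether a value opens / closes an interval.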
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#define _REG_FUNC(funname) m.def(#funname, &funname)
_REG_FUNC(is_cub_available); // TODO: check this function
_REG_FUNC(exclusive_sum_by_key);
_REG_FUNC(inclusive_sum);
_REG_FUNC(exclusive_sum);
_REG_FUNC(inclusive_prod_forward);
_REG_FUNC(inclusive_prod_backward);
_REG_FUNC(exclusive_prod_forward);
_REG_FUNC(exclusive_prod_backward);
_REG_FUNC(ray_aabb_intersect);
_REG_FUNC(traverse_grids);
_REG_FUNC(searchsorted);
#undef _REG_FUNC
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, torch::Tensor, bool>(&importance_sampling));
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, int64_t, bool>(&importance_sampling));
py::class_<MultiScaleGridSpec>(m, "MultiScaleGridSpec")
.def(py::init<>())
.def_readwrite("data", &MultiScaleGridSpec::data)
.def_readwrite("occupied", &MultiScaleGridSpec::occupied)
.def_readwrite("base_aabb", &MultiScaleGridSpec::base_aabb);
py::class_<RaysSpec>(m, "RaysSpec")
.def(py::init<>())
.def_readwrite("origins", &RaysSpec::origins)
.def_readwrite("dirs", &RaysSpec::dirs);
py::class_<RaySegmentsSpec>(m, "RaySegmentsSpec")
.def(py::init<>())
.def_readwrite("vals", &RaySegmentsSpec::vals)
.def_readwrite("is_left", &RaySegmentsSpec::is_left)
.def_readwrite("is_right", &RaySegmentsSpec::is_right)
.def_readwrite("chunk_starts", &RaySegmentsSpec::chunk_starts)
.def_readwrite("chunk_cnts", &RaySegmentsSpec::chunk_cnts)
.def_readwrite("ray_indices", &RaySegmentsSpec::ray_indices);
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
__global__ void unpack_info_kernel(
// input
const int n_rays,
const int *packed_info,
// output
int64_t *ray_indices)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
    const int base = packed_info[i * 2 + 0];  // index of this ray's first sample.
    const int steps = packed_info[i * 2 + 1]; // number of samples for this ray.
if (steps == 0)
return;
ray_indices += base;
for (int j = 0; j < steps; ++j)
{
ray_indices[j] = i;
}
}
__global__ void unpack_info_to_mask_kernel(
// input
const int n_rays,
const int *packed_info,
const int n_samples,
// output
bool *masks) // [n_rays, n_samples]
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
    const int base = packed_info[i * 2 + 0];  // index of this ray's first sample.
    const int steps = packed_info[i * 2 + 1]; // number of samples for this ray.
if (steps == 0)
return;
masks += i * n_samples;
for (int j = 0; j < steps; ++j)
{
masks[j] = true;
}
}
template <typename scalar_t>
__global__ void unpack_data_kernel(
const uint32_t n_rays,
const int *packed_info, // input ray & point indices.
const int data_dim,
const scalar_t *data,
const int n_sampler_per_ray,
scalar_t *unpacked_data) // (n_rays, n_sampler_per_ray, data_dim)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
    const int base = packed_info[i * 2 + 0];  // index of this ray's first sample.
    const int steps = packed_info[i * 2 + 1]; // number of samples for this ray.
if (steps == 0)
return;
data += base * data_dim;
unpacked_data += i * n_sampler_per_ray * data_dim;
for (int j = 0; j < steps; j++)
{
for (int k = 0; k < data_dim; k++)
{
unpacked_data[j * data_dim + k] = data[j * data_dim + k];
}
}
return;
}
torch::Tensor unpack_info(const torch::Tensor packed_info, const int n_samples)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
const int n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
torch::Tensor ray_indices = torch::empty(
{n_samples}, packed_info.options().dtype(torch::kLong));
unpack_info_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
ray_indices.data_ptr<int64_t>());
return ray_indices;
}
torch::Tensor unpack_info_to_mask(
const torch::Tensor packed_info, const int n_samples)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
const int n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor masks = torch::zeros(
{n_rays, n_samples}, packed_info.options().dtype(torch::kBool));
unpack_info_to_mask_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
n_samples,
masks.data_ptr<bool>());
return masks;
}
torch::Tensor unpack_data(
torch::Tensor packed_info,
torch::Tensor data,
int n_samples_per_ray,
float pad_value)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(data);
    TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
TORCH_CHECK(data.ndimension() == 2);
const int n_rays = packed_info.size(0);
const int n_samples = data.size(0);
const int data_dim = data.size(1);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor unpacked_data = torch::full(
{n_rays, n_samples_per_ray, data_dim}, pad_value, data.options());
AT_DISPATCH_ALL_TYPES(
data.scalar_type(),
"unpack_data",
([&]
{ unpack_data_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
data_dim,
data.data_ptr<scalar_t>(),
n_samples_per_ray,
// outputs
unpacked_data.data_ptr<scalar_t>()); }));
return unpacked_data;
}
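/*
 * Illustration (not part of the original diff): the packed_info convention
 * consumed by the kernels above, shown on a made-up example of 3 rays with
 * 2, 0 and 3 samples:
 *   packed_info = [[0, 2], [2, 0], [2, 3]]        // [first sample index, sample count]
 *   unpack_info(packed_info, 5)         -> ray_indices = [0, 0, 2, 2, 2]
 *   unpack_info_to_mask(packed_info, 3) -> [[1, 1, 0], [0, 0, 0], [1, 1, 1]]
 * and unpack_data scatters each ray's rows into a dense padded tensor of
 * shape [n_rays, n_samples_per_ray, data_dim] filled with pad_value.
 */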
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
#include "include/helpers_math.h"
#include "include/helpers_contraction.h"
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor aabb);
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_binary,
const ContractionType type,
// sampling
const float step_size,
const float cone_angle);
torch::Tensor unpack_info(
const torch::Tensor packed_info, const int n_samples);
torch::Tensor unpack_info_to_mask(
const torch::Tensor packed_info, const int n_samples);
torch::Tensor grid_query(
const torch::Tensor samples,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_value,
const ContractionType type);
torch::Tensor contract(
const torch::Tensor samples,
// contraction
const torch::Tensor roi,
const ContractionType type);
torch::Tensor contract_inv(
const torch::Tensor samples,
// contraction
const torch::Tensor roi,
const ContractionType type);
torch::Tensor unpack_data(
torch::Tensor packed_info,
torch::Tensor data,
int n_samples_per_ray,
float pad_value);
// cub implementations: parallel across samples
bool is_cub_available() {
return (bool) CUB_SUPPORTS_SCAN_BY_KEY();
}
torch::Tensor transmittance_from_sigma_forward_cub(
torch::Tensor ray_indices,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor transmittance_from_sigma_backward_cub(
torch::Tensor ray_indices,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
torch::Tensor transmittance_from_alpha_forward_cub(
torch::Tensor ray_indices, torch::Tensor alphas);
torch::Tensor transmittance_from_alpha_backward_cub(
torch::Tensor ray_indices,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
// naive implementations: parallel across rays
torch::Tensor transmittance_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor transmittance_from_sigma_backward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
torch::Tensor transmittance_from_alpha_forward_naive(
torch::Tensor packed_info,
torch::Tensor alphas);
torch::Tensor transmittance_from_alpha_backward_naive(
torch::Tensor packed_info,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
torch::Tensor weight_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor weight_from_sigma_backward_naive(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor weight_from_alpha_forward_naive(
torch::Tensor packed_info,
torch::Tensor alphas);
torch::Tensor weight_from_alpha_backward_naive(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas);
// pdf
torch::Tensor pdf_sampling(
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_samples_in - 1]
int64_t n_samples, // n_samples_out
float padding,
bool stratified,
bool single_jitter,
c10::optional<torch::Tensor> masks_opt); // [n_rays]
torch::Tensor pdf_readout(
// keys
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_bins_in]
c10::optional<torch::Tensor> masks_opt, // [n_rays]
// query
torch::Tensor ts_out,
c10::optional<torch::Tensor> masks_out_opt); // [n_rays, n_samples_out]
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// contraction
py::enum_<ContractionType>(m, "ContractionType")
.value("AABB", ContractionType::AABB)
.value("UN_BOUNDED_TANH", ContractionType::UN_BOUNDED_TANH)
.value("UN_BOUNDED_SPHERE", ContractionType::UN_BOUNDED_SPHERE);
m.def("contract", &contract);
m.def("contract_inv", &contract_inv);
// grid
m.def("grid_query", &grid_query);
// marching
m.def("ray_aabb_intersect", &ray_aabb_intersect);
m.def("ray_marching", &ray_marching);
// rendering
m.def("is_cub_available", is_cub_available);
m.def("transmittance_from_sigma_forward_cub", transmittance_from_sigma_forward_cub);
m.def("transmittance_from_sigma_backward_cub", transmittance_from_sigma_backward_cub);
m.def("transmittance_from_alpha_forward_cub", transmittance_from_alpha_forward_cub);
m.def("transmittance_from_alpha_backward_cub", transmittance_from_alpha_backward_cub);
m.def("transmittance_from_sigma_forward_naive", transmittance_from_sigma_forward_naive);
m.def("transmittance_from_sigma_backward_naive", transmittance_from_sigma_backward_naive);
m.def("transmittance_from_alpha_forward_naive", transmittance_from_alpha_forward_naive);
m.def("transmittance_from_alpha_backward_naive", transmittance_from_alpha_backward_naive);
m.def("weight_from_sigma_forward_naive", weight_from_sigma_forward_naive);
m.def("weight_from_sigma_backward_naive", weight_from_sigma_backward_naive);
m.def("weight_from_alpha_forward_naive", weight_from_alpha_forward_naive);
m.def("weight_from_alpha_backward_naive", weight_from_alpha_backward_naive);
// pack & unpack
m.def("unpack_data", &unpack_data);
m.def("unpack_info", &unpack_info);
m.def("unpack_info_to_mask", &unpack_info_to_mask);
// pdf
m.def("pdf_sampling", &pdf_sampling);
m.def("pdf_readout", &pdf_readout);
}
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
#include "include/helpers_math.h"
#include "include/helpers_contraction.h"
inline __device__ __host__ float calc_dt(
const float t, const float cone_angle,
const float dt_min, const float dt_max)
{
return clamp(t * cone_angle, dt_min, dt_max);
}
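// Illustration (not part of the original diff): with, e.g., cone_angle = 0.004
// and dt_min = 1e-3, calc_dt returns the constant 1e-3 for t < 0.25 and then
// grows the step linearly as 0.004 * t until it saturates at dt_max, so
// samples get sparser farther from the ray origin.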
inline __device__ __host__ int mip_level(
const float3 xyz,
const float3 roi_min, const float3 roi_max,
const ContractionType type)
{
if (type != ContractionType::AABB)
{
        // mip level should always be zero when not using AABB
return 0;
}
float3 xyz_unit = apply_contraction(
xyz, roi_min, roi_max, ContractionType::AABB);
float3 scale = fabs(xyz_unit - 0.5);
float maxval = fmaxf(fmaxf(scale.x, scale.y), scale.z);
// if maxval is almost zero, it will trigger frexpf to output 0
// for exponent, which is not what we want.
maxval = fmaxf(maxval, 0.1);
int exponent;
frexpf(maxval, &exponent);
int mip = max(0, exponent + 1);
return mip;
}
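// Illustration (not part of the original diff): frexpf(maxval, &e) yields
// maxval = m * 2^e with m in [0.5, 1). Points inside the base roi have
// maxval = max |xyz_unit - 0.5| < 0.5, hence e <= -1 and mip = 0; maxval in
// [0.5, 1) (up to 2x the base roi) gives e = 0 -> mip 1, [1, 2) gives mip 2,
// and so on, with each level covering twice the extent of the previous one.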
inline __device__ __host__ int grid_idx_at(
const float3 xyz_unit, const int3 grid_res)
{
    // xyz_unit should always be in [0, 1]^3.
int3 ixyz = make_int3(xyz_unit * make_float3(grid_res));
ixyz = clamp(ixyz, make_int3(0, 0, 0), grid_res - 1);
int3 grid_offset = make_int3(grid_res.y * grid_res.z, grid_res.z, 1);
int idx = dot(ixyz, grid_offset);
return idx;
}
template <typename scalar_t>
inline __device__ __host__ scalar_t grid_occupied_at(
const float3 xyz,
const float3 roi_min, const float3 roi_max,
ContractionType type, int mip,
const int grid_nlvl, const int3 grid_res, const scalar_t *grid_value)
{
if (type == ContractionType::AABB && mip >= grid_nlvl)
{
return false;
}
float3 xyz_unit = apply_contraction(
xyz, roi_min, roi_max, type);
xyz_unit = (xyz_unit - 0.5) * scalbnf(1.0f, -mip) + 0.5;
int idx = grid_idx_at(xyz_unit, grid_res) + mip * grid_res.x * grid_res.y * grid_res.z;
return grid_value[idx];
}
// DDA-like step
inline __device__ __host__ float distance_to_next_voxel(
const float3 xyz, const float3 dir, const float3 inv_dir, int mip,
const float3 roi_min, const float3 roi_max, const int3 grid_res)
{
float scaling = scalbnf(1.0f, mip);
float3 _roi_mid = (roi_min + roi_max) * 0.5;
float3 _roi_rad = (roi_max - roi_min) * 0.5;
float3 _roi_min = _roi_mid - _roi_rad * scaling;
float3 _roi_max = _roi_mid + _roi_rad * scaling;
float3 _occ_res = make_float3(grid_res);
float3 _xyz = roi_to_unit(xyz, _roi_min, _roi_max) * _occ_res;
float3 txyz = ((floorf(_xyz + 0.5f + 0.5f * sign(dir)) - _xyz) * inv_dir) / _occ_res * (_roi_max - _roi_min);
float t = min(min(txyz.x, txyz.y), txyz.z);
return fmaxf(t, 0.0f);
}
inline __device__ __host__ float advance_to_next_voxel(
const float t, const float dt_min,
const float3 xyz, const float3 dir, const float3 inv_dir, int mip,
const float3 roi_min, const float3 roi_max, const int3 grid_res, const float far)
{
// Regular stepping (may be slower but matches non-empty space)
float t_target = t + distance_to_next_voxel(
xyz, dir, inv_dir, mip, roi_min, roi_max, grid_res);
t_target = min(t_target, far);
float _t = t;
do
{
_t += dt_min;
} while (_t < t_target);
return _t;
}
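// Illustration (not part of the original diff): distance_to_next_voxel returns
// how far along the ray the next cell boundary of the mip-scaled grid lies,
// and advance_to_next_voxel walks there in increments of dt_min (with the
// target capped at `far`) instead of taking one large jump -- the "regular
// stepping" mentioned in the comment above.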
// -------------------------------------------------------------------------------
// Raymarching
// -------------------------------------------------------------------------------
__global__ void ray_marching_kernel(
// rays info
const uint32_t n_rays,
const float *rays_o, // shape (n_rays, 3)
const float *rays_d, // shape (n_rays, 3)
const float *t_min, // shape (n_rays,)
const float *t_max, // shape (n_rays,)
// occupancy grid & contraction
const float *roi,
const int grid_nlvl,
const int3 grid_res,
    const bool *grid_binary, // shape (grid_nlvl, reso_x, reso_y, reso_z)
const ContractionType type,
// sampling
const float step_size,
const float cone_angle,
const int *packed_info,
// first round outputs
int *num_steps,
// second round outputs
int64_t *ray_indices,
float *t_starts,
float *t_ends)
{
CUDA_GET_THREAD_ID(i, n_rays);
bool is_first_round = (packed_info == nullptr);
// locate
rays_o += i * 3;
rays_d += i * 3;
t_min += i;
t_max += i;
if (is_first_round)
{
num_steps += i;
}
else
{
int base = packed_info[i * 2 + 0];
int steps = packed_info[i * 2 + 1];
t_starts += base;
t_ends += base;
ray_indices += base;
}
const float3 origin = make_float3(rays_o[0], rays_o[1], rays_o[2]);
const float3 dir = make_float3(rays_d[0], rays_d[1], rays_d[2]);
const float3 inv_dir = 1.0f / dir;
const float near = t_min[0], far = t_max[0];
const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
const float grid_cell_scale = fmaxf(fmaxf(
(roi_max.x - roi_min.x) / grid_res.x * scalbnf(1.732f, grid_nlvl - 1),
(roi_max.y - roi_min.y) / grid_res.y * scalbnf(1.732f, grid_nlvl - 1)),
(roi_max.z - roi_min.z) / grid_res.z * scalbnf(1.732f, grid_nlvl - 1));
// TODO: compute dt_max from occ resolution.
float dt_min = step_size;
float dt_max;
if (type == ContractionType::AABB) {
// compute dt_max from occ grid resolution.
dt_max = grid_cell_scale;
} else {
dt_max = 1e10f;
}
int j = 0;
float t0 = near;
float dt = calc_dt(t0, cone_angle, dt_min, dt_max);
float t1 = t0 + dt;
float t_mid = (t0 + t1) * 0.5f;
while (t_mid < far)
{
// current center
const float3 xyz = origin + t_mid * dir;
// current mip level
const int mip = mip_level(xyz, roi_min, roi_max, type);
if (mip >= grid_nlvl) {
// out of grid
break;
}
if (grid_occupied_at(xyz, roi_min, roi_max, type, mip, grid_nlvl, grid_res, grid_binary))
{
if (!is_first_round)
{
t_starts[j] = t0;
t_ends[j] = t1;
ray_indices[j] = i;
}
++j;
// march to next sample
t0 = t1;
t1 = t0 + calc_dt(t0, cone_angle, dt_min, dt_max);
t_mid = (t0 + t1) * 0.5f;
}
else
{
// march to next sample
switch (type)
{
case ContractionType::AABB:
// no contraction
t_mid = advance_to_next_voxel(
t_mid, dt_min, xyz, dir, inv_dir, mip, roi_min, roi_max, grid_res, far);
dt = calc_dt(t_mid, cone_angle, dt_min, dt_max);
t0 = t_mid - dt * 0.5f;
t1 = t_mid + dt * 0.5f;
break;
default:
// any type of scene contraction does not work with DDA.
t0 = t1;
t1 = t0 + calc_dt(t0, cone_angle, dt_min, dt_max);
t_mid = (t0 + t1) * 0.5f;
break;
}
}
}
if (is_first_round)
{
*num_steps = j;
}
return;
}
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_binary,
const ContractionType type,
// sampling
const float step_size,
const float cone_angle)
{
DEVICE_GUARD(rays_o);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(t_min);
CHECK_INPUT(t_max);
CHECK_INPUT(roi);
CHECK_INPUT(grid_binary);
    TORCH_CHECK(rays_o.ndimension() == 2 && rays_o.size(1) == 3);
    TORCH_CHECK(rays_d.ndimension() == 2 && rays_d.size(1) == 3);
    TORCH_CHECK(t_min.ndimension() == 1);
    TORCH_CHECK(t_max.ndimension() == 1);
    TORCH_CHECK(roi.ndimension() == 1 && roi.size(0) == 6);
    TORCH_CHECK(grid_binary.ndimension() == 4);
const int n_rays = rays_o.size(0);
const int grid_nlvl = grid_binary.size(0);
const int3 grid_res = make_int3(
grid_binary.size(1), grid_binary.size(2), grid_binary.size(3));
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// helper counter
torch::Tensor num_steps = torch::empty(
{n_rays}, rays_o.options().dtype(torch::kInt32));
// count number of samples per ray
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// occupancy grid & contraction
roi.data_ptr<float>(),
grid_nlvl,
grid_res,
grid_binary.data_ptr<bool>(),
type,
// sampling
step_size,
cone_angle,
nullptr, /* packed_info */
// outputs
num_steps.data_ptr<int>(),
nullptr, /* ray_indices */
nullptr, /* t_starts */
nullptr /* t_ends */);
torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
torch::Tensor packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
// output samples starts and ends
int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
torch::Tensor t_starts = torch::empty({total_steps, 1}, rays_o.options());
torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options());
torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options().dtype(torch::kLong));
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// occupancy grid & contraction
roi.data_ptr<float>(),
grid_nlvl,
grid_res,
grid_binary.data_ptr<bool>(),
type,
// sampling
step_size,
cone_angle,
packed_info.data_ptr<int>(),
// outputs
nullptr, /* num_steps */
ray_indices.data_ptr<int64_t>(),
t_starts.data_ptr<float>(),
t_ends.data_ptr<float>());
return {packed_info, ray_indices, t_starts, t_ends};
}
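/*
 * Illustration (not part of the original diff): ray_marching launches the
 * kernel above twice. The first pass runs with packed_info == nullptr and only
 * counts the samples each ray would produce (num_steps); a cumulative sum of
 * those counts then forms packed_info = [start, count] per ray and sizes the
 * flat output buffers; the second pass re-traverses the grid and writes
 * t_starts, t_ends and ray_indices into them. The returned tuple is
 * {packed_info, ray_indices, t_starts, t_ends}, i.e. the packed sample format
 * consumed by unpack_info / unpack_data.
 */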
// ----------------------------------------------------------------------------
// Query the occupancy grid
// ----------------------------------------------------------------------------
template <typename scalar_t>
__global__ void query_occ_kernel(
// rays info
const uint32_t n_samples,
const float *samples, // shape (n_samples, 3)
// occupancy grid & contraction
const float *roi,
const int grid_nlvl,
const int3 grid_res,
    const scalar_t *grid_value, // shape (grid_nlvl, reso_x, reso_y, reso_z)
const ContractionType type,
// outputs
scalar_t *occs)
{
CUDA_GET_THREAD_ID(i, n_samples);
// locate
samples += i * 3;
occs += i;
const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
const float3 xyz = make_float3(samples[0], samples[1], samples[2]);
const int mip = mip_level(xyz, roi_min, roi_max, type);
*occs = grid_occupied_at(xyz, roi_min, roi_max, type, mip, grid_nlvl, grid_res, grid_value);
return;
}
torch::Tensor grid_query(
const torch::Tensor samples,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_value,
const ContractionType type)
{
DEVICE_GUARD(samples);
CHECK_INPUT(samples);
const int n_samples = samples.size(0);
const int grid_nlvl = grid_value.size(0);
const int3 grid_res = make_int3(
grid_value.size(1), grid_value.size(2), grid_value.size(3));
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads);
torch::Tensor occs = torch::empty({n_samples}, grid_value.options());
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Bool,
occs.scalar_type(),
"grid_query",
([&]
{ query_occ_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_samples,
samples.data_ptr<float>(),
// grid
roi.data_ptr<float>(),
grid_nlvl,
grid_res,
grid_value.data_ptr<scalar_t>(),
type,
// outputs
occs.data_ptr<scalar_t>()); }));
return occs;
}