Unverified Commit 8340e19d authored by Ruilong Li(李瑞龙)'s avatar Ruilong Li(李瑞龙) Committed by GitHub
Browse files

0.5.0: Rewrite all the underlying CUDA. Speedup and Benchmarking. (#182)

* importance_sampling with test

* package importance_sampling

* compute_intervals tested and packaged

* compute_intervals_v2

* bicycle is failing

* fix cut in compute_intervals_v2, test pass for rendering

* hacky way to get opaque_bkgd work

* reorg ING

* PackedRaySegmentsSpec

* chunk_ids -> ray_ids

* binary -> occupied

* test_traverse_grid_basic checked

* fix traverse_grid with step size, checked

* support max_step_size, not verified

* _cuda and cuda; upgrade ray_marching

* inclusive scan

* test_exclusive_sum but seems to have numeric error

* inclusive_sum_backward verified

* exclusive sum backward

* merge fwd and bwd for scan

* inclusive & exclusive prod verified

* support normal scan with torch funcs

* rendering and tests

* a bit clean up

* importance_sampling verified

* stratified for importance_sampling

* importance_sampling in pdf.py

* RaySegmentsSpec in data_specs; fix various bugs

* verified with _proposal_packed.py

* importance sampling support batch input/output. need to verify

* prop script with batch samples

* try to use cumsum  instead of cumprod

* searchsorted

* benchmarking prop

* ray_aabb_intersect untested

* update prop benchmark numbers

* minor fixes

* batched ray_aabb_intersect

* ray_aabb_intersect and traverse with grid(s)

* tiny optimize for traverse_grids kernels

* traverse_grids return intervals and samples

* cub not verified

* cleanup

* propnet and occgrid as estimators

* training print iters 10k

* prop is good now

* benchmark in google sheet.

* really cleanup: scan.py and test

* pack.py and test

* rendering and test

* data_specs.py and pdf.py docs

* data_specs.py and pdf.py docs

* init and headers

* grid.py and test for it

* occ grid docs

* generated docs

* example docs for pack and scan function.

* doc fix for volrend.py

* doc fix for pdf.py

* fix doc for rendering function

* docs

* propnet docs

* update scripts

* docs: index.rst

* methodology docs

* docs for examples

* mlp nerf script

* update t-nerf script

* rename dnerf to tnerf

* misc update

* bug fix: pdf_loss with test

* minor fix

* update readme with submodules

* fix format

* update gitingore file

* fix doc failure. teaser png to jpg

* docs in examples/
parent e547490c
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*
* Modified from
* https://github.com/pytorch/pytorch/blob/06a64f7eaa47ce430a3fa61016010075b59b18a7/aten/src/ATen/native/cuda/ScanUtils.cuh
*/
#include "utils_cuda.cuh"
// CUB support for scan by key is added to cub 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif
// https://github.com/pytorch/pytorch/blob/233305a852e1cd7f319b15b5137074c9eac455f6/aten/src/ATen/cuda/cub.cuh#L38-L46
#define CUB_WRAPPER(func, ...) do { \
size_t temp_storage_bytes = 0; \
func(nullptr, temp_storage_bytes, __VA_ARGS__); \
auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \
auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \
func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \
AT_CUDA_CHECK(cudaGetLastError()); \
} while (false)
namespace {
namespace device {
/* Perform an inclusive scan for a flattened tensor.
*
* - num_rows is the size of the outer dimensions;
* - {chunk_starts, chunk_cnts} defines the regions of the flattened tensor to be scanned.
*
* Each thread block processes one or more sets of contiguous rows (processing multiple rows
* per thread block is quicker than processing a single row, especially for short rows).
*/
template<
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__device__ void inclusive_scan_impl(
T* row_buf, DataIteratorT tgt_, DataIteratorT src_,
const uint32_t num_rows,
// const uint32_t row_size,
IdxIteratorT chunk_starts, IdxIteratorT chunk_cnts,
T init, BinaryFunction binary_op,
bool normalize = false){
for (uint32_t block_row = blockIdx.x * blockDim.y;
block_row < num_rows;
block_row += blockDim.y * gridDim.x) {
uint32_t row = block_row + threadIdx.y;
T block_total = init;
if (row >= num_rows) continue;
DataIteratorT row_src = src_ + chunk_starts[row];
DataIteratorT row_tgt = tgt_ + chunk_starts[row];
uint32_t row_size = chunk_cnts[row];
if (row_size == 0) continue;
// Perform scan on one block at a time, keeping track of the total value of
// all blocks processed so far.
for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
// Load data into shared memory (two values per thread).
uint32_t col1 = block_col + threadIdx.x;
uint32_t col2 = block_col + num_threads_x + threadIdx.x;
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = row_src[col1];
} else {
row_buf[threadIdx.x] = init;
}
if (col2 < row_size) {
row_buf[num_threads_x + threadIdx.x] = row_src[col2];
} else {
row_buf[num_threads_x + threadIdx.x] = init;
}
// Add the total value of all previous blocks to the first value of this block.
if (threadIdx.x == 0) {
row_buf[0] = binary_op(row_buf[0], block_total);
}
}
__syncthreads();
// Parallel reduction (up-sweep).
for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < num_rows && threadIdx.x < s) {
uint32_t offset = (2 * threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Down-sweep.
for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < num_rows && threadIdx.x < s - 1) {
uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Write back to output.
if (row < num_rows) {
if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];
if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];
}
block_total = row_buf[2 * num_threads_x - 1];
__syncthreads();
}
// Normalize with the last value: should only be used by scan_sum
if (normalize) {
for (uint32_t block_col = 0; block_col < row_size; block_col += num_threads_x)
{
uint32_t col = block_col + threadIdx.x;
if (col < row_size) {
row_tgt[col] /= fmaxf(block_total, 1e-10f);
}
}
}
}
}
template <
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__global__ void
inclusive_scan_kernel(
DataIteratorT tgt_,
DataIteratorT src_,
const uint32_t num_rows,
IdxIteratorT chunk_starts,
IdxIteratorT chunk_cnts,
T init,
BinaryFunction binary_op,
bool normalize = false) {
__shared__ T sbuf[num_threads_y][2 * num_threads_x];
T* row_buf = sbuf[threadIdx.y];
inclusive_scan_impl<T, num_threads_x, num_threads_y>(
row_buf, tgt_, src_, num_rows, chunk_starts, chunk_cnts, init, binary_op, normalize);
}
/* Perform an exclusive scan for a flattened tensor.
*
* - num_rows is the size of the outer dimensions;
* - {chunk_starts, chunk_cnts} defines the regions of the flattened tensor to be scanned.
*
* Each thread block processes one or more sets of contiguous rows (processing multiple rows
* per thread block is quicker than processing a single row, especially for short rows).
*/
template<
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__device__ void exclusive_scan_impl(
T* row_buf, DataIteratorT tgt_, DataIteratorT src_,
const uint32_t num_rows,
// const uint32_t row_size,
IdxIteratorT chunk_starts, IdxIteratorT chunk_cnts,
T init, BinaryFunction binary_op,
bool normalize = false){
for (uint32_t block_row = blockIdx.x * blockDim.y;
block_row < num_rows;
block_row += blockDim.y * gridDim.x) {
uint32_t row = block_row + threadIdx.y;
T block_total = init;
if (row >= num_rows) continue;
DataIteratorT row_src = src_ + chunk_starts[row];
DataIteratorT row_tgt = tgt_ + chunk_starts[row];
uint32_t row_size = chunk_cnts[row];
if (row_size == 0) continue;
row_tgt[0] = init;
// Perform scan on one block at a time, keeping track of the total value of
// all blocks processed so far.
for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
// Load data into shared memory (two values per thread).
uint32_t col1 = block_col + threadIdx.x;
uint32_t col2 = block_col + num_threads_x + threadIdx.x;
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = row_src[col1];
} else {
row_buf[threadIdx.x] = init;
}
if (col2 < row_size) {
row_buf[num_threads_x + threadIdx.x] = row_src[col2];
} else {
row_buf[num_threads_x + threadIdx.x] = init;
}
// Add the total value of all previous blocks to the first value of this block.
if (threadIdx.x == 0) {
row_buf[0] = binary_op(row_buf[0], block_total);
}
}
__syncthreads();
// Parallel reduction (up-sweep).
for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < num_rows && threadIdx.x < s) {
uint32_t offset = (2 * threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Down-sweep.
for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < num_rows && threadIdx.x < s - 1) {
uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
// Write back to output.
if (row < num_rows) {
if (col1 < row_size - 1) row_tgt[col1 + 1] = row_buf[threadIdx.x];
if (col2 < row_size - 1) row_tgt[col2 + 1] = row_buf[num_threads_x + threadIdx.x];
}
block_total = row_buf[2 * num_threads_x - 1];
__syncthreads();
}
// Normalize with the last value: should only be used by scan_sum
if (normalize) {
for (uint32_t block_col = 0; block_col < row_size; block_col += num_threads_x)
{
uint32_t col = block_col + threadIdx.x;
if (col < row_size - 1) {
row_tgt[col + 1] /= fmaxf(block_total, 1e-10f);
}
}
}
}
}
template <
typename T,
int num_threads_x,
int num_threads_y,
class BinaryFunction,
typename DataIteratorT,
typename IdxIteratorT>
__global__ void
exclusive_scan_kernel(
DataIteratorT tgt_,
DataIteratorT src_,
const uint32_t num_rows,
IdxIteratorT chunk_starts,
IdxIteratorT chunk_cnts,
T init,
BinaryFunction binary_op,
bool normalize = false) {
__shared__ T sbuf[num_threads_y][2 * num_threads_x];
T* row_buf = sbuf[threadIdx.y];
exclusive_scan_impl<T, num_threads_x, num_threads_y>(
row_buf, tgt_, src_, num_rows, chunk_starts, chunk_cnts, init, binary_op, normalize);
}
} // namespace device
} // namespace
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
template <typename scalar_t>
inline __host__ __device__ void _swap(scalar_t &a, scalar_t &b)
{
scalar_t c = a;
a = b;
b = c;
}
template <typename scalar_t>
inline __host__ __device__ void _ray_aabb_intersect(
const scalar_t *rays_o,
const scalar_t *rays_d,
const scalar_t *aabb,
scalar_t *near,
scalar_t *far)
{
// aabb is [xmin, ymin, zmin, xmax, ymax, zmax]
scalar_t tmin = (aabb[0] - rays_o[0]) / rays_d[0];
scalar_t tmax = (aabb[3] - rays_o[0]) / rays_d[0];
if (tmin > tmax)
_swap(tmin, tmax);
scalar_t tymin = (aabb[1] - rays_o[1]) / rays_d[1];
scalar_t tymax = (aabb[4] - rays_o[1]) / rays_d[1];
if (tymin > tymax)
_swap(tymin, tymax);
if (tmin > tymax || tymin > tmax)
{
*near = 1e10;
*far = 1e10;
return;
}
if (tymin > tmin)
tmin = tymin;
if (tymax < tmax)
tmax = tymax;
scalar_t tzmin = (aabb[2] - rays_o[2]) / rays_d[2];
scalar_t tzmax = (aabb[5] - rays_o[2]) / rays_d[2];
if (tzmin > tzmax)
_swap(tzmin, tzmax);
if (tmin > tzmax || tzmin > tmax)
{
*near = 1e10;
*far = 1e10;
return;
}
if (tzmin > tmin)
tmin = tzmin;
if (tzmax < tmax)
tmax = tzmax;
*near = tmin;
*far = tmax;
return;
}
template <typename scalar_t>
__global__ void ray_aabb_intersect_kernel(
const int N,
const scalar_t *rays_o,
const scalar_t *rays_d,
const scalar_t *aabb,
scalar_t *t_min,
scalar_t *t_max)
{
// aabb is [xmin, ymin, zmin, xmax, ymax, zmax]
CUDA_GET_THREAD_ID(thread_id, N);
// locate
rays_o += thread_id * 3;
rays_d += thread_id * 3;
t_min += thread_id;
t_max += thread_id;
_ray_aabb_intersect<scalar_t>(rays_o, rays_d, aabb, t_min, t_max);
scalar_t zero = static_cast<scalar_t>(0.f);
*t_min = *t_min > zero ? *t_min : zero;
return;
}
/**
* @brief Ray AABB Test
*
* @param rays_o Ray origins. Tensor with shape [N, 3].
* @param rays_d Normalized ray directions. Tensor with shape [N, 3].
* @param aabb Scene AABB [xmin, ymin, zmin, xmax, ymax, zmax]. Tensor with shape [6].
* @return std::vector<torch::Tensor>
* Ray AABB intersection {t_min, t_max} with shape [N] respectively. Note the t_min is
* clipped to minimum zero. 1e10 is returned if no intersection.
*/
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o, const torch::Tensor rays_d, const torch::Tensor aabb)
{
DEVICE_GUARD(rays_o);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(aabb);
TORCH_CHECK(rays_o.ndimension() == 2 & rays_o.size(1) == 3)
TORCH_CHECK(rays_d.ndimension() == 2 & rays_d.size(1) == 3)
TORCH_CHECK(aabb.ndimension() == 1 & aabb.size(0) == 6)
const int N = rays_o.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(N, threads);
torch::Tensor t_min = torch::empty({N}, rays_o.options());
torch::Tensor t_max = torch::empty({N}, rays_o.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rays_o.scalar_type(), "ray_aabb_intersect",
([&]
{ ray_aabb_intersect_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
N,
rays_o.data_ptr<scalar_t>(),
rays_d.data_ptr<scalar_t>(),
aabb.data_ptr<scalar_t>(),
t_min.data_ptr<scalar_t>(),
t_max.data_ptr<scalar_t>()); }));
return {t_min, t_max};
}
\ No newline at end of file
// This file contains only Python bindings
#include "include/data_spec.hpp"
#include <torch/extension.h>
bool is_cub_available() {
// FIXME: why return false?
return (bool) CUB_SUPPORTS_SCAN_BY_KEY();
}
// scan
torch::Tensor exclusive_sum_by_key(
torch::Tensor indices,
torch::Tensor inputs,
bool backward);
torch::Tensor inclusive_sum(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
bool normalize,
bool backward);
torch::Tensor exclusive_sum(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
bool normalize,
bool backward);
torch::Tensor inclusive_prod_forward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs);
torch::Tensor inclusive_prod_backward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
torch::Tensor outputs,
torch::Tensor grad_outputs);
torch::Tensor exclusive_prod_forward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs);
torch::Tensor exclusive_prod_backward(
torch::Tensor chunk_starts,
torch::Tensor chunk_cnts,
torch::Tensor inputs,
torch::Tensor outputs,
torch::Tensor grad_outputs);
// grid
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
const torch::Tensor aabbs, // [n_aabbs, 6]
const float near_plane,
const float far_plane,
const float miss_value);
std::vector<RaySegmentsSpec> traverse_grids(
// rays
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
// grids
const torch::Tensor binaries, // [n_grids, resx, resy, resz]
const torch::Tensor aabbs, // [n_grids, 6]
// intersections
const torch::Tensor t_mins, // [n_rays, n_grids]
const torch::Tensor t_maxs, // [n_rays, n_grids]
const torch::Tensor hits, // [n_rays, n_grids]
// options
const torch::Tensor near_planes,
const torch::Tensor far_planes,
const float step_size,
const float cone_angle,
const bool compute_intervals,
const bool compute_samples);
// pdf
std::vector<RaySegmentsSpec> importance_sampling(
RaySegmentsSpec ray_segments,
torch::Tensor cdfs,
torch::Tensor n_intervels_per_ray,
bool stratified);
std::vector<RaySegmentsSpec> importance_sampling(
RaySegmentsSpec ray_segments,
torch::Tensor cdfs,
int64_t n_intervels_per_ray,
bool stratified);
std::vector<torch::Tensor> searchsorted(
RaySegmentsSpec query,
RaySegmentsSpec key);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#define _REG_FUNC(funname) m.def(#funname, &funname)
_REG_FUNC(is_cub_available); // TODO: check this function
_REG_FUNC(exclusive_sum_by_key);
_REG_FUNC(inclusive_sum);
_REG_FUNC(exclusive_sum);
_REG_FUNC(inclusive_prod_forward);
_REG_FUNC(inclusive_prod_backward);
_REG_FUNC(exclusive_prod_forward);
_REG_FUNC(exclusive_prod_backward);
_REG_FUNC(ray_aabb_intersect);
_REG_FUNC(traverse_grids);
_REG_FUNC(searchsorted);
#undef _REG_FUNC
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, torch::Tensor, bool>(&importance_sampling));
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, int64_t, bool>(&importance_sampling));
py::class_<MultiScaleGridSpec>(m, "MultiScaleGridSpec")
.def(py::init<>())
.def_readwrite("data", &MultiScaleGridSpec::data)
.def_readwrite("occupied", &MultiScaleGridSpec::occupied)
.def_readwrite("base_aabb", &MultiScaleGridSpec::base_aabb);
py::class_<RaysSpec>(m, "RaysSpec")
.def(py::init<>())
.def_readwrite("origins", &RaysSpec::origins)
.def_readwrite("dirs", &RaysSpec::dirs);
py::class_<RaySegmentsSpec>(m, "RaySegmentsSpec")
.def(py::init<>())
.def_readwrite("vals", &RaySegmentsSpec::vals)
.def_readwrite("is_left", &RaySegmentsSpec::is_left)
.def_readwrite("is_right", &RaySegmentsSpec::is_right)
.def_readwrite("chunk_starts", &RaySegmentsSpec::chunk_starts)
.def_readwrite("chunk_cnts", &RaySegmentsSpec::chunk_cnts)
.def_readwrite("ray_indices", &RaySegmentsSpec::ray_indices);
}
\ No newline at end of file
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
__global__ void unpack_info_kernel(
// input
const int n_rays,
const int *packed_info,
// output
int64_t *ray_indices)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
ray_indices += base;
for (int j = 0; j < steps; ++j)
{
ray_indices[j] = i;
}
}
__global__ void unpack_info_to_mask_kernel(
// input
const int n_rays,
const int *packed_info,
const int n_samples,
// output
bool *masks) // [n_rays, n_samples]
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
masks += i * n_samples;
for (int j = 0; j < steps; ++j)
{
masks[j] = true;
}
}
template <typename scalar_t>
__global__ void unpack_data_kernel(
const uint32_t n_rays,
const int *packed_info, // input ray & point indices.
const int data_dim,
const scalar_t *data,
const int n_sampler_per_ray,
scalar_t *unpacked_data) // (n_rays, n_sampler_per_ray, data_dim)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
data += base * data_dim;
unpacked_data += i * n_sampler_per_ray * data_dim;
for (int j = 0; j < steps; j++)
{
for (int k = 0; k < data_dim; k++)
{
unpacked_data[j * data_dim + k] = data[j * data_dim + k];
}
}
return;
}
torch::Tensor unpack_info(const torch::Tensor packed_info, const int n_samples)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
const int n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
torch::Tensor ray_indices = torch::empty(
{n_samples}, packed_info.options().dtype(torch::kLong));
unpack_info_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
ray_indices.data_ptr<int64_t>());
return ray_indices;
}
torch::Tensor unpack_info_to_mask(
const torch::Tensor packed_info, const int n_samples)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
const int n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor masks = torch::zeros(
{n_rays, n_samples}, packed_info.options().dtype(torch::kBool));
unpack_info_to_mask_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
n_samples,
masks.data_ptr<bool>());
return masks;
}
torch::Tensor unpack_data(
torch::Tensor packed_info,
torch::Tensor data,
int n_samples_per_ray,
float pad_value)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(data);
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2);
TORCH_CHECK(data.ndimension() == 2);
const int n_rays = packed_info.size(0);
const int n_samples = data.size(0);
const int data_dim = data.size(1);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor unpacked_data = torch::full(
{n_rays, n_samples_per_ray, data_dim}, pad_value, data.options());
AT_DISPATCH_ALL_TYPES(
data.scalar_type(),
"unpack_data",
([&]
{ unpack_data_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
data_dim,
data.data_ptr<scalar_t>(),
n_samples_per_ray,
// outputs
unpacked_data.data_ptr<scalar_t>()); }));
return unpacked_data;
}
This diff is collapsed.
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
#include "include/helpers_math.h"
#include "include/helpers_contraction.h"
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor aabb);
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_binary,
const ContractionType type,
// sampling
const float step_size,
const float cone_angle);
torch::Tensor unpack_info(
const torch::Tensor packed_info, const int n_samples);
torch::Tensor unpack_info_to_mask(
const torch::Tensor packed_info, const int n_samples);
torch::Tensor grid_query(
const torch::Tensor samples,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_value,
const ContractionType type);
torch::Tensor contract(
const torch::Tensor samples,
// contraction
const torch::Tensor roi,
const ContractionType type);
torch::Tensor contract_inv(
const torch::Tensor samples,
// contraction
const torch::Tensor roi,
const ContractionType type);
torch::Tensor unpack_data(
torch::Tensor packed_info,
torch::Tensor data,
int n_samples_per_ray,
float pad_value);
// cub implementations: parallel across samples
bool is_cub_available() {
return (bool) CUB_SUPPORTS_SCAN_BY_KEY();
}
torch::Tensor transmittance_from_sigma_forward_cub(
torch::Tensor ray_indices,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor transmittance_from_sigma_backward_cub(
torch::Tensor ray_indices,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
torch::Tensor transmittance_from_alpha_forward_cub(
torch::Tensor ray_indices, torch::Tensor alphas);
torch::Tensor transmittance_from_alpha_backward_cub(
torch::Tensor ray_indices,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
// naive implementations: parallel across rays
torch::Tensor transmittance_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor transmittance_from_sigma_backward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
torch::Tensor transmittance_from_alpha_forward_naive(
torch::Tensor packed_info,
torch::Tensor alphas);
torch::Tensor transmittance_from_alpha_backward_naive(
torch::Tensor packed_info,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad);
torch::Tensor weight_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor weight_from_sigma_backward_naive(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas);
torch::Tensor weight_from_alpha_forward_naive(
torch::Tensor packed_info,
torch::Tensor alphas);
torch::Tensor weight_from_alpha_backward_naive(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas);
// pdf
torch::Tensor pdf_sampling(
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_samples_in - 1]
int64_t n_samples, // n_samples_out
float padding,
bool stratified,
bool single_jitter,
c10::optional<torch::Tensor> masks_opt); // [n_rays]
torch::Tensor pdf_readout(
// keys
torch::Tensor ts, // [n_rays, n_samples_in]
torch::Tensor weights, // [n_rays, n_bins_in]
c10::optional<torch::Tensor> masks_opt, // [n_rays]
// query
torch::Tensor ts_out,
c10::optional<torch::Tensor> masks_out_opt); // [n_rays, n_samples_out]
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// contraction
py::enum_<ContractionType>(m, "ContractionType")
.value("AABB", ContractionType::AABB)
.value("UN_BOUNDED_TANH", ContractionType::UN_BOUNDED_TANH)
.value("UN_BOUNDED_SPHERE", ContractionType::UN_BOUNDED_SPHERE);
m.def("contract", &contract);
m.def("contract_inv", &contract_inv);
// grid
m.def("grid_query", &grid_query);
// marching
m.def("ray_aabb_intersect", &ray_aabb_intersect);
m.def("ray_marching", &ray_marching);
// rendering
m.def("is_cub_available", is_cub_available);
m.def("transmittance_from_sigma_forward_cub", transmittance_from_sigma_forward_cub);
m.def("transmittance_from_sigma_backward_cub", transmittance_from_sigma_backward_cub);
m.def("transmittance_from_alpha_forward_cub", transmittance_from_alpha_forward_cub);
m.def("transmittance_from_alpha_backward_cub", transmittance_from_alpha_backward_cub);
m.def("transmittance_from_sigma_forward_naive", transmittance_from_sigma_forward_naive);
m.def("transmittance_from_sigma_backward_naive", transmittance_from_sigma_backward_naive);
m.def("transmittance_from_alpha_forward_naive", transmittance_from_alpha_forward_naive);
m.def("transmittance_from_alpha_backward_naive", transmittance_from_alpha_backward_naive);
m.def("weight_from_sigma_forward_naive", weight_from_sigma_forward_naive);
m.def("weight_from_sigma_backward_naive", weight_from_sigma_backward_naive);
m.def("weight_from_alpha_forward_naive", weight_from_alpha_forward_naive);
m.def("weight_from_alpha_backward_naive", weight_from_alpha_backward_naive);
// pack & unpack
m.def("unpack_data", &unpack_data);
m.def("unpack_info", &unpack_info);
m.def("unpack_info_to_mask", &unpack_info_to_mask);
// pdf
m.def("pdf_sampling", &pdf_sampling);
m.def("pdf_readout", &pdf_readout);
}
\ No newline at end of file
This diff is collapsed.
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include "include/helpers_cuda.h"
__global__ void transmittance_from_sigma_forward_kernel(
const uint32_t n_rays,
// inputs
const int *packed_info,
const float *starts,
const float *ends,
const float *sigmas,
// outputs
float *transmittance)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
starts += base;
ends += base;
sigmas += base;
transmittance += base;
// accumulation
float cumsum = 0.0f;
for (int j = 0; j < steps; ++j)
{
transmittance[j] = __expf(-cumsum);
cumsum += sigmas[j] * (ends[j] - starts[j]);
}
// // another way to impl:
// float T = 1.f;
// for (int j = 0; j < steps; ++j)
// {
// const float delta = ends[j] - starts[j];
// const float alpha = 1.f - __expf(-sigmas[j] * delta);
// transmittance[j] = T;
// T *= (1.f - alpha);
// }
return;
}
__global__ void transmittance_from_sigma_backward_kernel(
const uint32_t n_rays,
// inputs
const int *packed_info,
const float *starts,
const float *ends,
const float *transmittance,
const float *transmittance_grad,
// outputs
float *sigmas_grad)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
transmittance += base;
transmittance_grad += base;
starts += base;
ends += base;
sigmas_grad += base;
// accumulation
float cumsum = 0.0f;
for (int j = steps - 1; j >= 0; --j)
{
sigmas_grad[j] = cumsum * (ends[j] - starts[j]);
cumsum += -transmittance_grad[j] * transmittance[j];
}
return;
}
__global__ void transmittance_from_alpha_forward_kernel(
const uint32_t n_rays,
// inputs
const int *packed_info,
const float *alphas,
// outputs
float *transmittance)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
alphas += base;
transmittance += base;
// accumulation
float T = 1.0f;
for (int j = 0; j < steps; ++j)
{
transmittance[j] = T;
T *= (1.0f - alphas[j]);
}
return;
}
__global__ void transmittance_from_alpha_backward_kernel(
const uint32_t n_rays,
// inputs
const int *packed_info,
const float *alphas,
const float *transmittance,
const float *transmittance_grad,
// outputs
float *alphas_grad)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0];
const int steps = packed_info[i * 2 + 1];
if (steps == 0)
return;
alphas += base;
transmittance += base;
transmittance_grad += base;
alphas_grad += base;
// accumulation
float cumsum = 0.0f;
for (int j = steps - 1; j >= 0; --j)
{
alphas_grad[j] = cumsum / fmax(1.0f - alphas[j], 1e-10f);
cumsum += -transmittance_grad[j] * transmittance[j];
}
return;
}
torch::Tensor transmittance_from_sigma_forward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1);
const uint32_t n_samples = sigmas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor transmittance = torch::empty_like(sigmas);
// parallel across rays
transmittance_from_sigma_forward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<float>(),
ends.data_ptr<float>(),
sigmas.data_ptr<float>(),
// outputs
transmittance.data_ptr<float>());
return transmittance;
}
torch::Tensor transmittance_from_sigma_backward_naive(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor transmittance,
torch::Tensor transmittance_grad)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(transmittance);
CHECK_INPUT(transmittance_grad);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(transmittance.ndimension() == 2 & transmittance.size(1) == 1);
TORCH_CHECK(transmittance_grad.ndimension() == 2 & transmittance_grad.size(1) == 1);
const uint32_t n_samples = transmittance.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor sigmas_grad = torch::empty_like(transmittance);
// parallel across rays
transmittance_from_sigma_backward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<float>(),
ends.data_ptr<float>(),
transmittance.data_ptr<float>(),
transmittance_grad.data_ptr<float>(),
// outputs
sigmas_grad.data_ptr<float>());
return sigmas_grad;
}
torch::Tensor transmittance_from_alpha_forward_naive(
torch::Tensor packed_info, torch::Tensor alphas)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(alphas);
TORCH_CHECK(alphas.ndimension() == 2 & alphas.size(1) == 1);
TORCH_CHECK(packed_info.ndimension() == 2);
const uint32_t n_samples = alphas.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor transmittance = torch::empty_like(alphas);
// parallel across rays
transmittance_from_alpha_forward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
alphas.data_ptr<float>(),
// outputs
transmittance.data_ptr<float>());
return transmittance;
}
torch::Tensor transmittance_from_alpha_backward_naive(
torch::Tensor packed_info,
torch::Tensor alphas,
torch::Tensor transmittance,
torch::Tensor transmittance_grad)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(transmittance);
CHECK_INPUT(transmittance_grad);
TORCH_CHECK(packed_info.ndimension() == 2);
TORCH_CHECK(transmittance.ndimension() == 2 & transmittance.size(1) == 1);
TORCH_CHECK(transmittance_grad.ndimension() == 2 & transmittance_grad.size(1) == 1);
const uint32_t n_samples = transmittance.size(0);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor alphas_grad = torch::empty_like(alphas);
// parallel across rays
transmittance_from_alpha_backward_kernel<<<
blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
alphas.data_ptr<float>(),
transmittance.data_ptr<float>(),
transmittance_grad.data_ptr<float>(),
// outputs
alphas_grad.data_ptr<float>());
return alphas_grad;
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from typing import Any
import torch
import torch.nn as nn
class AbstractEstimator(nn.Module):
"""An abstract Transmittance Estimator class for Sampling."""
def __init__(self) -> None:
super().__init__()
self.register_buffer("_dummy", torch.empty(0), persistent=False)
@property
def device(self) -> torch.device:
return self._dummy.device
def sampling(self, *args, **kwargs) -> Any:
raise NotImplementedError
def update_every_n_steps(self, *args, **kwargs) -> None:
raise NotImplementedError
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment