Unverified Commit 8dcfbad9 authored by Ruilong Li (李瑞龙), committed by GitHub

Reformat (#31)



* seems working

* contraction func in cuda

* Update type

* More type updates

* disable DDA for contraction

* update contraction performance in readme

* 360 data: Garden

* eval at max_steps

* add perform of 360 to readme

* fix contraction scaling

* tiny hot fix

* new volrend

* cleanup ray_marching.cu

* cleanup backend

* tests

* cleaning up Grid

* fix doc for grid base class

* check and fix for contraction

* test grid

* rendering and marching

* transmittance_compress verified

* rendering is indeed faster

* pipeline is working

* lego example

* cleanup

* cuda folder is cleaned up! finally!

* cuda formatting

* contraction verify

* upgrade grid

* test for ray marching

* pipeline

* ngp with contraction

* train_ngp runs but slow

* transmittance separated into two. Now NGP is as fast as before

* verified faster than before

* bug fix for contraction

* ngp contraction fix

* tiny cleanup

* contraction works! yay!

* contraction with tanh seems working

* minor update

* support alpha rendering

* absorb visibility to ray marching

* tiny import update

* get rid of contraction temperature

* doc for ContractionType

* doc for Grid

* doc for grid.py is done

* doc for ray marching

* rendering function

* fix doc for rendering

* doc for vol rend

* autosummary for utils

* fix autosummary line break

* utils docs

* api doc is done

* starting work on examples

* contraction for ngp is in python now

* further clean up examples

* mlp nerf is running

* dnerf is in

* update readme command

* merge

* disable pylint error for now

* reformatting and skip tests without cuda

* fix the type issue for contractiontype

* fix cuda attribute issue

* bump to 0.1.0
Co-authored-by: Matt Tancik <tancik@berkeley.edu>
parent a7611603
#include "include/helpers_cuda.h"
#include "include/helpers_math.h"
#include "include/helpers_contraction.h"
inline __device__ __host__ float calc_dt(
const float t, const float cone_angle,
const float dt_min, const float dt_max)
{
return clamp(t * cone_angle, dt_min, dt_max);
}
inline __device__ __host__ int grid_idx_at(
const float3 xyz_unit, const int3 grid_res)
{
// xyz should always be in [0, 1]^3.
int3 ixyz = make_int3(xyz_unit * make_float3(grid_res));
ixyz = clamp(ixyz, make_int3(0, 0, 0), grid_res - 1);
int3 grid_offset = make_int3(grid_res.y * grid_res.z, grid_res.z, 1);
int idx = dot(ixyz, grid_offset);
return idx;
}
inline __device__ __host__ bool grid_occupied_at(
const float3 xyz,
const float3 roi_min, const float3 roi_max,
ContractionType type,
const int3 grid_res, const bool *grid_binary)
{
if (type == ContractionType::AABB &&
(xyz.x < roi_min.x || xyz.x > roi_max.x ||
xyz.y < roi_min.y || xyz.y > roi_max.y ||
xyz.z < roi_min.z || xyz.z > roi_max.z))
{
return false;
}
float3 xyz_unit = apply_contraction(
xyz, roi_min, roi_max, type);
int idx = grid_idx_at(xyz_unit, grid_res);
return grid_binary[idx];
}
// dda like step
inline __device__ __host__ float distance_to_next_voxel(
const float3 xyz, const float3 dir, const float3 inv_dir,
const float3 roi_min, const float3 roi_max, const int3 grid_res)
{
float3 _occ_res = make_float3(grid_res);
float3 _xyz = roi_to_unit(xyz, roi_min, roi_max) * _occ_res;
float3 txyz = ((floorf(_xyz + 0.5f + 0.5f * sign(dir)) - _xyz) * inv_dir) / _occ_res * (roi_max - roi_min);
float t = min(min(txyz.x, txyz.y), txyz.z);
return fmaxf(t, 0.0f);
}
inline __device__ __host__ float advance_to_next_voxel(
const float t, const float dt_min,
const float3 xyz, const float3 dir, const float3 inv_dir,
const float3 roi_min, const float3 roi_max, const int3 grid_res)
{
// Regular stepping (may be slower but matches non-empty space)
float t_target = t + distance_to_next_voxel(
xyz, dir, inv_dir, roi_min, roi_max, grid_res);
float _t = t;
do
{
_t += dt_min;
} while (_t < t_target);
return _t;
}
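For intuition, the two helpers above can be mirrored in a few lines of NumPy (a sketch assuming all direction components are non-zero; it is not part of the library): the distance to the next voxel boundary is computed per axis in grid coordinates, converted back to world units, and `advance_to_next_voxel` then rounds that distance up to a whole number of `dt_min` steps so samples stay on the regular marching lattice.

import numpy as np

def distance_to_next_voxel(xyz, dirs, roi_min, roi_max, res):
    inv_dir = 1.0 / dirs
    voxel_size = (roi_max - roi_min) / res  # world-space size of one voxel
    grid_xyz = (xyz - roi_min) / (roi_max - roi_min) * res
    # next voxel boundary along each axis: floor(x + 0.5 + 0.5 * sign(d))
    t_axis = (np.floor(grid_xyz + 0.5 + 0.5 * np.sign(dirs)) - grid_xyz) * inv_dir * voxel_size
    return max(t_axis.min(), 0.0)

def advance_to_next_voxel(t, dt_min, xyz, dirs, roi_min, roi_max, res):
    t_target = t + distance_to_next_voxel(xyz, dirs, roi_min, roi_max, res)
    # round up to a whole number of dt_min steps (at least one), like the do/while above
    n_steps = np.ceil(max(t_target - t, 1e-9) / dt_min)
    return t + n_steps * dt_min

print(advance_to_next_voxel(
    0.0, 1e-2,
    np.array([0.05, 0.32, 0.70]), np.array([0.6, 0.64, 0.48]),
    np.zeros(3), np.ones(3), np.array([128, 128, 128]),
))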
// -------------------------------------------------------------------------------
// Raymarching
// -------------------------------------------------------------------------------
__global__ void ray_marching_kernel(
// rays info
const uint32_t n_rays,
const float *rays_o, // shape (n_rays, 3)
const float *rays_d, // shape (n_rays, 3)
const float *t_min, // shape (n_rays,)
const float *t_max, // shape (n_rays,)
// occupancy grid & contraction
const float *roi,
const int3 grid_res,
const bool *grid_binary, // shape (reso_x, reso_y, reso_z)
const ContractionType type,
// sampling
const float step_size,
const float cone_angle,
const int *packed_info,
// first round outputs
int *num_steps,
// second round outputs
float *t_starts,
float *t_ends)
{
CUDA_GET_THREAD_ID(i, n_rays);
bool is_first_round = (packed_info == nullptr);
// locate
rays_o += i * 3;
rays_d += i * 3;
t_min += i;
t_max += i;
if (is_first_round)
{
num_steps += i;
}
else
{
int base = packed_info[i * 2 + 0];
int steps = packed_info[i * 2 + 1];
t_starts += base;
t_ends += base;
}
const float3 origin = make_float3(rays_o[0], rays_o[1], rays_o[2]);
const float3 dir = make_float3(rays_d[0], rays_d[1], rays_d[2]);
const float3 inv_dir = 1.0f / dir;
const float near = t_min[0], far = t_max[0];
const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
// TODO: compute dt_max from occ resolution.
float dt_min = step_size;
float dt_max = 1e10f;
int j = 0;
float t0 = near;
float dt = calc_dt(t0, cone_angle, dt_min, dt_max);
float t1 = t0 + dt;
float t_mid = (t0 + t1) * 0.5f;
while (t_mid < far)
{
// current center
const float3 xyz = origin + t_mid * dir;
if (grid_occupied_at(xyz, roi_min, roi_max, type, grid_res, grid_binary))
{
if (!is_first_round)
{
t_starts[j] = t0;
t_ends[j] = t1;
}
++j;
// march to next sample
t0 = t1;
t1 = t0 + calc_dt(t0, cone_angle, dt_min, dt_max);
t_mid = (t0 + t1) * 0.5f;
}
else
{
// march to next sample
switch (type)
{
case ContractionType::AABB:
// no contraction
t_mid = advance_to_next_voxel(
t_mid, dt_min, xyz, dir, inv_dir, roi_min, roi_max, grid_res);
dt = calc_dt(t_mid, cone_angle, dt_min, dt_max);
t0 = t_mid - dt * 0.5f;
t1 = t_mid + dt * 0.5f;
break;
default:
// DDA-style voxel skipping is not applicable under scene contraction.
t0 = t1;
t1 = t0 + calc_dt(t0, cone_angle, dt_min, dt_max);
t_mid = (t0 + t1) * 0.5f;
break;
}
}
}
if (is_first_round)
{
*num_steps = j;
}
return;
}
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_binary,
const ContractionType type,
// sampling
const float step_size,
const float cone_angle)
{
DEVICE_GUARD(rays_o);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(t_min);
CHECK_INPUT(t_max);
CHECK_INPUT(roi);
CHECK_INPUT(grid_binary);
TORCH_CHECK(rays_o.ndimension() == 2 && rays_o.size(1) == 3);
TORCH_CHECK(rays_d.ndimension() == 2 && rays_d.size(1) == 3);
TORCH_CHECK(t_min.ndimension() == 1);
TORCH_CHECK(t_max.ndimension() == 1);
TORCH_CHECK(roi.ndimension() == 1 && roi.size(0) == 6);
TORCH_CHECK(grid_binary.ndimension() == 3);
const int n_rays = rays_o.size(0);
const int3 grid_res = make_int3(
grid_binary.size(0), grid_binary.size(1), grid_binary.size(2));
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// helper counter
torch::Tensor num_steps = torch::zeros(
{n_rays}, rays_o.options().dtype(torch::kInt32));
// count number of samples per ray
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// occupancy grid & contraction
roi.data_ptr<float>(),
grid_res,
grid_binary.data_ptr<bool>(),
type,
// sampling
step_size,
cone_angle,
nullptr, /* packed_info */
// outputs
num_steps.data_ptr<int>(),
nullptr, /* t_starts */
nullptr /* t_ends */);
torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
torch::Tensor packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
// output samples starts and ends
int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
torch::Tensor t_starts = torch::zeros({total_steps, 1}, rays_o.options());
torch::Tensor t_ends = torch::zeros({total_steps, 1}, rays_o.options());
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// occupancy grid & contraction
roi.data_ptr<float>(),
grid_res,
grid_binary.data_ptr<bool>(),
type,
// sampling
step_size,
cone_angle,
packed_info.data_ptr<int>(),
// outputs
nullptr, /* num_steps */
t_starts.data_ptr<float>(),
t_ends.data_ptr<float>());
return {packed_info, t_starts, t_ends};
}
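To make the two-pass output layout above concrete, here is a small hedged Python illustration (the tensors below are made up, not produced by the kernels): row i of `packed_info` stores (base, steps) into the flat `t_starts` / `t_ends` buffers, so the samples of ray i live at `t_starts[base:base+steps]`.

import torch

packed_info = torch.tensor([[0, 3], [3, 0], [3, 2]])  # 3 rays with 3, 0 and 2 samples
t_starts = torch.tensor([[0.1], [0.2], [0.3], [0.5], [0.6]])
t_ends = t_starts + 0.1
for i, (base, steps) in enumerate(packed_info.tolist()):
    # samples belonging to ray i
    print(f"ray {i}:", t_starts[base : base + steps].squeeze(-1).tolist())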
// -----------------------------------------------------------------------------
// Ray index for each sample
// -----------------------------------------------------------------------------
__global__ void ray_indices_kernel(
// input
const int n_rays,
const int *packed_info,
// output
int *ray_indices)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
ray_indices += base;
for (int j = 0; j < steps; ++j)
{
ray_indices[j] = i;
}
}
torch::Tensor unpack_to_ray_indices(const torch::Tensor packed_info)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
const int n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
torch::Tensor ray_indices = torch::zeros(
{n_samples}, packed_info.options().dtype(torch::kInt32));
ray_indices_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
ray_indices.data_ptr<int>());
return ray_indices;
}
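The same unpacking can be written in a couple of lines of PyTorch, which may help to see what the kernel produces (a reference sketch, not the implementation used by the library): each ray index i is simply repeated `steps_i` times.

import torch

packed_info = torch.tensor([[0, 3], [3, 0], [3, 2]], dtype=torch.int32)
ray_indices = torch.repeat_interleave(
    torch.arange(packed_info.shape[0]), packed_info[:, 1].long()
)
print(ray_indices)  # tensor([0, 0, 0, 2, 2])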
// ----------------------------------------------------------------------------
// Query the occupancy grid
// ----------------------------------------------------------------------------
__global__ void query_occ_kernel(
// rays info
const uint32_t n_samples,
const float *samples, // shape (n_samples, 3)
// occupancy grid & contraction
const float *roi,
const int3 grid_res,
const bool *grid_binary, // shape (reso_x, reso_y, reso_z)
const ContractionType type,
// outputs
bool *occs)
{
CUDA_GET_THREAD_ID(i, n_samples);
// locate
samples += i * 3;
occs += i;
const float3 roi_min = make_float3(roi[0], roi[1], roi[2]);
const float3 roi_max = make_float3(roi[3], roi[4], roi[5]);
const float3 xyz = make_float3(samples[0], samples[1], samples[2]);
*occs = grid_occupied_at(xyz, roi_min, roi_max, type, grid_res, grid_binary);
return;
}
torch::Tensor query_occ(
const torch::Tensor samples,
// occupancy grid & contraction
const torch::Tensor roi,
const torch::Tensor grid_binary,
const ContractionType type)
{
DEVICE_GUARD(samples);
CHECK_INPUT(samples);
const int n_samples = samples.size(0);
const int3 grid_res = make_int3(
grid_binary.size(0), grid_binary.size(1), grid_binary.size(2));
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads);
torch::Tensor occs = torch::zeros(
{n_samples}, samples.options().dtype(torch::kBool));
query_occ_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_samples,
samples.data_ptr<float>(),
// grid
roi.data_ptr<float>(),
grid_res,
grid_binary.data_ptr<bool>(),
type,
// outputs
occs.data_ptr<bool>());
return occs;
}
#include "include/helpers_cuda.h"
template <typename scalar_t>
__global__ void rendering_forward_kernel(
const uint32_t n_rays,
const int *packed_info, // input ray & point indices.
const scalar_t *starts, // input start t
const scalar_t *ends, // input end t
const scalar_t *sigmas, // input density after activation
const scalar_t *alphas, // input alpha (opacity) values.
const scalar_t early_stop_eps, // transmittance threshold for early stop
// outputs: should be all-zero initialized
int *num_steps, // the number of valid steps for each ray
scalar_t *weights, // the rendering weight for each sample
bool *compact_selector // the samples for which we need to compute gradients
)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
if (alphas != nullptr)
{
// rendering with alpha
alphas += base;
}
else
{
// rendering with density
starts += base;
ends += base;
sigmas += base;
}
if (num_steps != nullptr)
{
num_steps += i;
}
if (weights != nullptr)
{
weights += base;
}
if (compact_selector != nullptr)
{
compact_selector += base;
}
// accumulated rendering
scalar_t T = 1.f;
int j = 0;
for (; j < steps; ++j)
{
if (T < early_stop_eps)
{
break;
}
scalar_t alpha;
if (alphas != nullptr)
{
// rendering with alpha
alpha = alphas[j];
}
else
{
// rendering with density
scalar_t delta = ends[j] - starts[j];
alpha = 1.f - __expf(-sigmas[j] * delta);
}
const scalar_t weight = alpha * T;
T *= (1.f - alpha);
if (weights != nullptr)
{
weights[j] = weight;
}
if (compact_selector != nullptr)
{
compact_selector[j] = true;
}
}
if (num_steps != nullptr)
{
*num_steps = j;
}
return;
}
template <typename scalar_t>
__global__ void rendering_backward_kernel(
const uint32_t n_rays,
const int *packed_info, // input ray & point indices.
const scalar_t *starts, // input start t
const scalar_t *ends, // input end t
const scalar_t *sigmas, // input density after activation
const scalar_t *alphas, // input alpha (opacity) values.
const scalar_t early_stop_eps, // transmittance threshold for early stop
const scalar_t *weights, // forward output
const scalar_t *grad_weights, // input gradients
// if alphas was given, we compute the gradients for alphas.
// otherwise, we compute the gradients for sigmas.
scalar_t *grad_sigmas, // output gradients
scalar_t *grad_alphas // output gradients
)
{
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0)
return;
if (alphas != nullptr)
{
// rendering with alpha
alphas += base;
grad_alphas += base;
}
else
{
// rendering with density
starts += base;
ends += base;
sigmas += base;
grad_sigmas += base;
}
weights += base;
grad_weights += base;
scalar_t accum = 0;
for (int j = 0; j < steps; ++j)
{
accum += grad_weights[j] * weights[j];
}
// backward of accumulated rendering
scalar_t T = 1.f;
for (int j = 0; j < steps; ++j)
{
if (T < early_stop_eps)
{
break;
}
scalar_t alpha;
if (alphas != nullptr)
{
// rendering with alpha
alpha = alphas[j];
grad_alphas[j] = (grad_weights[j] * T - accum) / fmaxf(1.f - alpha, 1e-10f);
}
else
{
// rendering with density
scalar_t delta = ends[j] - starts[j];
alpha = 1.f - __expf(-sigmas[j] * delta);
grad_sigmas[j] = (grad_weights[j] * T - accum) * delta;
}
accum -= grad_weights[j] * weights[j];
T *= (1.f - alpha);
}
}
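The backward kernel implements the closed-form gradient of the weights w_j = alpha_j * prod_{k<j}(1 - alpha_k) with respect to the densities, namely grad_sigma_j = (grad_w_j * T_j - sum_{k>=j} grad_w_k * w_k) * delta_j. Below is a small, self-contained PyTorch check against autograd for a single ray with made-up values (a sketch for intuition, not part of the library):

import torch

def render_weights(sigmas, starts, ends):
    # alpha_j = 1 - exp(-sigma_j * delta_j); T_j = prod_{k<j} (1 - alpha_k)
    deltas = ends - starts
    alphas = 1.0 - torch.exp(-sigmas * deltas)
    trans = torch.cumprod(
        torch.cat([torch.ones_like(alphas[:1]), 1.0 - alphas[:-1]]), dim=0
    )
    return alphas * trans

sigmas = torch.rand(8, dtype=torch.float64, requires_grad=True)
starts = torch.linspace(0.0, 0.7, 8, dtype=torch.float64)
ends = starts + 0.1
grad_w = torch.rand(8, dtype=torch.float64)
render_weights(sigmas, starts, ends).backward(grad_w)
print(sigmas.grad)  # compare against the analytic rule used by the kernel above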
std::vector<torch::Tensor> rendering_forward(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps,
bool compression)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 && sigmas.size(1) == 1);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = sigmas.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
if (compression)
{
// compress the samples to get rid of invisible ones.
torch::Tensor num_steps = torch::zeros({n_rays}, packed_info.options());
torch::Tensor compact_selector = torch::zeros(
{n_samples}, sigmas.options().dtype(torch::kBool));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"rendering_forward",
([&]
{ rendering_forward_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
nullptr, // alphas
early_stop_eps,
// outputs
num_steps.data_ptr<int>(),
nullptr,
compact_selector.data_ptr<bool>()); }));
torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
torch::Tensor compact_packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
return {compact_packed_info, compact_selector};
}
else
{
// just do the forward rendering.
torch::Tensor weights = torch::zeros({n_samples}, sigmas.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"rendering_forward",
([&]
{ rendering_forward_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
nullptr, // alphas
early_stop_eps,
// outputs
nullptr,
weights.data_ptr<scalar_t>(),
nullptr); }));
return {weights};
}
}
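When `compression` is true, the first pass only records which samples are still visible before transmittance drops below `early_stop_eps`, so the differentiable pass can skip the rest. A self-contained PyTorch sketch of that per-ray selection rule (it mirrors the kernel for intuition only; it is not the library's Python API):

import torch

def visible_prefix(sigmas, starts, ends, early_stop_eps=1e-4):
    deltas = ends - starts
    alphas = 1.0 - torch.exp(-sigmas * deltas)
    # exclusive cumulative transmittance T_j = prod_{k<j} (1 - alpha_k)
    trans = torch.cumprod(
        torch.cat([torch.ones_like(alphas[:1]), 1.0 - alphas[:-1]]), dim=0
    )
    return trans >= early_stop_eps  # boolean selector for one ray's samples

sigmas = torch.tensor([5.0, 50.0, 50.0, 5.0, 5.0])
starts = torch.arange(5, dtype=torch.float32) * 0.1
ends = starts + 0.1
print(visible_prefix(sigmas, starts, ends))  # trailing, invisible samples are dropped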
torch::Tensor rendering_backward(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps)
{
DEVICE_GUARD(packed_info);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = sigmas.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor grad_sigmas = torch::zeros(sigmas.sizes(), sigmas.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"rendering_backward",
([&]
{ rendering_backward_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
nullptr, // alphas
early_stop_eps,
weights.data_ptr<scalar_t>(),
grad_weights.data_ptr<scalar_t>(),
// outputs
grad_sigmas.data_ptr<scalar_t>(),
nullptr // alphas gradients
); }));
return grad_sigmas;
}
// -- rendering with alphas -- //
std::vector<torch::Tensor> rendering_alphas_forward(
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps,
bool compression)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(alphas);
TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
TORCH_CHECK(alphas.ndimension() == 2 && alphas.size(1) == 1);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = alphas.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
if (compression)
{
// compress the samples to get rid of invisible ones.
torch::Tensor num_steps = torch::zeros({n_rays}, packed_info.options());
torch::Tensor compact_selector = torch::zeros(
{n_samples}, alphas.options().dtype(torch::kBool));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
alphas.scalar_type(),
"rendering_alphas_forward",
([&]
{ rendering_forward_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
nullptr, // starts
nullptr, // ends
nullptr, // sigmas
alphas.data_ptr<scalar_t>(),
early_stop_eps,
// outputs
num_steps.data_ptr<int>(),
nullptr,
compact_selector.data_ptr<bool>()); }));
torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
torch::Tensor compact_packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
return {compact_selector, compact_packed_info};
}
else
{
// just do the forward rendering.
torch::Tensor weights = torch::zeros({n_samples}, alphas.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
alphas.scalar_type(),
"rendering_forward",
([&]
{ rendering_forward_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
nullptr, // starts
nullptr, // ends
nullptr, // sigmas
alphas.data_ptr<scalar_t>(),
early_stop_eps,
// outputs
nullptr,
weights.data_ptr<scalar_t>(),
nullptr); }));
return {weights};
}
}
torch::Tensor rendering_alphas_backward(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps)
{
DEVICE_GUARD(packed_info);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = alphas.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor grad_alphas = torch::zeros(alphas.sizes(), alphas.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
alphas.scalar_type(),
"rendering_alphas_backward",
([&]
{ rendering_backward_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
// inputs
packed_info.data_ptr<int>(),
nullptr, // starts
nullptr, // ends
nullptr, // sigmas
alphas.data_ptr<scalar_t>(),
early_stop_eps,
weights.data_ptr<scalar_t>(),
grad_weights.data_ptr<scalar_t>(),
// outputs
nullptr, // sigma gradients
grad_alphas.data_ptr<scalar_t>()); }));
return grad_alphas;
}
#include <pybind11/pybind11.h>
#include "include/helpers_cuda.h"
inline __device__ int cascaded_grid_idx_at(
const float x, const float y, const float z,
const int resx, const int resy, const int resz,
const float* aabb
) {
int ix = (int)(((x - aabb[0]) / (aabb[3] - aabb[0])) * resx);
int iy = (int)(((y - aabb[1]) / (aabb[4] - aabb[1])) * resy);
int iz = (int)(((z - aabb[2]) / (aabb[5] - aabb[2])) * resz);
ix = __clamp(ix, 0, resx-1);
iy = __clamp(iy, 0, resy-1);
iz = __clamp(iz, 0, resz-1);
int idx = ix * resy * resz + iy * resz + iz;
return idx;
}
inline __device__ bool grid_occupied_at(
const float x, const float y, const float z,
const int resx, const int resy, const int resz,
const float* aabb, const bool* occ_binary
) {
if (x <= aabb[0] || x >= aabb[3] || y <= aabb[1] || y >= aabb[4] || z <= aabb[2] || z >= aabb[5]) {
return false;
}
int idx = cascaded_grid_idx_at(x, y, z, resx, resy, resz, aabb);
return occ_binary[idx];
}
inline __device__ float distance_to_next_voxel(
float x, float y, float z,
float dir_x, float dir_y, float dir_z,
float idir_x, float idir_y, float idir_z,
const int resx, const int resy, const int resz,
const float* aabb
) { // dda like step
// TODO: this is ugly -- optimize this.
float _x = ((x - aabb[0]) / (aabb[3] - aabb[0])) * resx;
float _y = ((y - aabb[1]) / (aabb[4] - aabb[1])) * resy;
float _z = ((z - aabb[2]) / (aabb[5] - aabb[2])) * resz;
float tx = ((floorf(_x + 0.5f + 0.5f * __sign(dir_x)) - _x) * idir_x) / resx * (aabb[3] - aabb[0]);
float ty = ((floorf(_y + 0.5f + 0.5f * __sign(dir_y)) - _y) * idir_y) / resy * (aabb[4] - aabb[1]);
float tz = ((floorf(_z + 0.5f + 0.5f * __sign(dir_z)) - _z) * idir_z) / resz * (aabb[5] - aabb[2]);
float t = min(min(tx, ty), tz);
return fmaxf(t, 0.0f);
}
inline __device__ float advance_to_next_voxel(
float t,
float x, float y, float z,
float dir_x, float dir_y, float dir_z,
float idir_x, float idir_y, float idir_z,
const int resx, const int resy, const int resz, const float* aabb,
float dt_min) {
// Regular stepping (may be slower but matches non-empty space)
float t_target = t + distance_to_next_voxel(
x, y, z,
dir_x, dir_y, dir_z,
idir_x, idir_y, idir_z,
resx, resy, resz, aabb
);
do {
t += dt_min;
} while (t < t_target);
return t;
}
__global__ void marching_steps_kernel(
// rays info
const uint32_t n_rays,
const float* rays_o, // shape (n_rays, 3)
const float* rays_d, // shape (n_rays, 3)
const float* t_min, // shape (n_rays,)
const float* t_max, // shape (n_rays,)
// density grid
const float* aabb, // [min_x, min_y, min_z, max_x, max_y, max_z]
const int resx,
const int resy,
const int resz,
const bool* occ_binary, // shape (reso_x, reso_y, reso_z)
// sampling
const float dt,
// outputs
int* num_steps
) {
CUDA_GET_THREAD_ID(i, n_rays);
// locate
rays_o += i * 3;
rays_d += i * 3;
t_min += i;
t_max += i;
num_steps += i;
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const float near = t_min[0], far = t_max[0];
int j = 0;
float t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
float t1 = t0 + dt;
float t_mid = (t0 + t1) * 0.5f;
while (t_mid < far) {
// current center
const float x = ox + t_mid * dx;
const float y = oy + t_mid * dy;
const float z = oz + t_mid * dz;
if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
++j;
// march to next sample
t0 = t1;
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
}
else {
// march to next sample
t_mid = advance_to_next_voxel(
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, aabb, dt
);
t0 = t_mid - dt * 0.5f;
t1 = t_mid + dt * 0.5f;
}
}
if (j == 0) return;
num_steps[0] = j;
return;
}
__global__ void marching_forward_kernel(
// rays info
const uint32_t n_rays,
const float* rays_o, // shape (n_rays, 3)
const float* rays_d, // shape (n_rays, 3)
const float* t_min, // shape (n_rays,)
const float* t_max, // shape (n_rays,)
// density grid
const float* aabb, // [min_x, min_y, min_z, max_x, max_y, max_z]
const int resx,
const int resy,
const int resz,
const bool* occ_binary, // shape (reso_x, reso_y, reso_z)
// sampling
const float dt,
const int* packed_info,
// frustum outputs
float* frustum_starts,
float* frustum_ends
) {
CUDA_GET_THREAD_ID(i, n_rays);
// locate
rays_o += i * 3;
rays_d += i * 3;
t_min += i;
t_max += i;
int base = packed_info[i * 2 + 0];
int steps = packed_info[i * 2 + 1];
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const float near = t_min[0], far = t_max[0];
// locate
frustum_starts += base;
frustum_ends += base;
int j = 0;
float t0 = near;
float t1 = t0 + dt;
float t_mid = (t0 + t1) * 0.5f;
while (t_mid < far) {
// current center
const float x = ox + t_mid * dx;
const float y = oy + t_mid * dy;
const float z = oz + t_mid * dz;
if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
frustum_starts[j] = t0;
frustum_ends[j] = t1;
++j;
// march to next sample
t0 = t1;
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
}
else {
// march to next sample
t_mid = advance_to_next_voxel(
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, aabb, dt
);
t0 = t_mid - dt * 0.5f;
t1 = t_mid + dt * 0.5f;
}
}
if (j != steps) {
printf("WTF %d v.s. %d\n", j, steps);
}
return;
}
__global__ void ray_indices_kernel(
// input
const int n_rays,
const int* packed_info,
// output
int* ray_indices
) {
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0) return;
ray_indices += base;
for (int j = 0; j < steps; ++j) {
ray_indices[j] = i;
}
}
__global__ void occ_query_kernel(
// rays info
const uint32_t n_samples,
const float* samples, // shape (n_samples, 3)
// density grid
const float* aabb, // [min_x, min_y, min_z, max_x, max_y, max_z]
const int resx,
const int resy,
const int resz,
const bool* occ_binary, // shape (reso_x, reso_y, reso_z)
// outputs
bool* occs
) {
CUDA_GET_THREAD_ID(i, n_samples);
// locate
samples += i * 3;
occs += i;
occs[0] = grid_occupied_at(
samples[0], samples[1], samples[2],
resx, resy, resz, aabb, occ_binary
);
return;
}
std::vector<torch::Tensor> volumetric_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// density grid
const torch::Tensor aabb,
const pybind11::list resolution,
const torch::Tensor occ_binary,
// sampling
const float dt
) {
DEVICE_GUARD(rays_o);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(t_min);
CHECK_INPUT(t_max);
CHECK_INPUT(aabb);
CHECK_INPUT(occ_binary);
const int n_rays = rays_o.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// helper counter
torch::Tensor num_steps = torch::zeros(
{n_rays}, rays_o.options().dtype(torch::kInt32));
// count number of samples per ray
marching_steps_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// density grid
aabb.data_ptr<float>(),
resolution[0].cast<int>(),
resolution[1].cast<int>(),
resolution[2].cast<int>(),
occ_binary.data_ptr<bool>(),
// sampling
dt,
// outputs
num_steps.data_ptr<int>()
);
torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
torch::Tensor packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
// std::cout << "num_steps" << num_steps.dtype() << std::endl;
// std::cout << "cum_steps" << cum_steps.dtype() << std::endl;
// std::cout << "packed_info" << packed_info.dtype() << std::endl;
// output frustum samples
int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
torch::Tensor frustum_starts = torch::zeros({total_steps, 1}, rays_o.options());
torch::Tensor frustum_ends = torch::zeros({total_steps, 1}, rays_o.options());
marching_forward_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// density grid
aabb.data_ptr<float>(),
resolution[0].cast<int>(),
resolution[1].cast<int>(),
resolution[2].cast<int>(),
occ_binary.data_ptr<bool>(),
// sampling
dt,
packed_info.data_ptr<int>(),
// outputs
frustum_starts.data_ptr<float>(),
frustum_ends.data_ptr<float>()
);
return {packed_info, frustum_starts, frustum_ends};
}
torch::Tensor unpack_to_ray_indices(const torch::Tensor packed_info) {
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
const int n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
torch::Tensor ray_indices = torch::zeros(
{n_samples}, packed_info.options().dtype(torch::kInt32));
ray_indices_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
ray_indices.data_ptr<int>()
);
return ray_indices;
}
torch::Tensor query_occ(
const torch::Tensor samples,
// density grid
const torch::Tensor aabb,
const pybind11::list resolution,
const torch::Tensor occ_binary
) {
DEVICE_GUARD(samples);
CHECK_INPUT(samples);
const int n_samples = samples.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads);
torch::Tensor occs = torch::zeros(
{n_samples}, samples.options().dtype(torch::kBool));
occ_query_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_samples,
samples.data_ptr<float>(),
// density grid
aabb.data_ptr<float>(),
resolution[0].cast<int>(),
resolution[1].cast<int>(),
resolution[2].cast<int>(),
occ_binary.data_ptr<bool>(),
// outputs
occs.data_ptr<bool>()
);
return occs;
}
#include "include/helpers_cuda.h"
template <typename scalar_t>
__global__ void volumetric_rendering_steps_kernel(
const uint32_t n_rays,
const int* packed_info, // input ray & point indices.
const scalar_t* starts, // input start t
const scalar_t* ends, // input end t
const scalar_t* sigmas, // input density after activation
// output: should be all zero (false) initialized
int* num_steps,
bool* selector
) {
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0) return;
starts += base;
ends += base;
sigmas += base;
num_steps += i;
selector += base;
// accumulated rendering
scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f;
int j = 0;
for (; j < steps; ++j) {
if (T < EPSILON) {
break;
}
const scalar_t delta = ends[j] - starts[j];
const scalar_t alpha = 1.f - __expf(-sigmas[j] * delta);
const scalar_t weight = alpha * T;
T *= (1.f - alpha);
selector[j] = true;
}
num_steps[0] = j;
return;
}
template <typename scalar_t>
__global__ void volumetric_rendering_weights_forward_kernel(
const uint32_t n_rays,
const int* packed_info, // input ray & point indices.
const scalar_t* starts, // input start t
const scalar_t* ends, // input end t
const scalar_t* sigmas, // input density after activation
// should be all-zero initialized
scalar_t* weights // output
) {
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0) return;
starts += base;
ends += base;
sigmas += base;
weights += base;
// accumulated rendering
scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f;
for (int j = 0; j < steps; ++j) {
if (T < EPSILON) {
break;
}
const scalar_t delta = ends[j] - starts[j];
const scalar_t alpha = 1.f - __expf(-sigmas[j] * delta);
const scalar_t weight = alpha * T;
weights[j] = weight;
T *= (1.f - alpha);
}
}
template <typename scalar_t>
__global__ void volumetric_rendering_weights_backward_kernel(
const uint32_t n_rays,
const int* packed_info, // input ray & point indices.
const scalar_t* starts, // input start t
const scalar_t* ends, // input end t
const scalar_t* sigmas, // input density after activation
const scalar_t* weights, // forward output
const scalar_t* grad_weights, // input
scalar_t* grad_sigmas // output
) {
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0) return;
starts += base;
ends += base;
sigmas += base;
weights += base;
grad_weights += base;
grad_sigmas += base;
scalar_t accum = 0;
for (int j = 0; j < steps; ++j) {
accum += grad_weights[j] * weights[j];
}
// backward of accumulated rendering
scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f;
for (int j = 0; j < steps; ++j) {
if (T < EPSILON) {
break;
}
const scalar_t delta = ends[j] - starts[j];
const scalar_t alpha = 1.f - __expf(-sigmas[j] * delta);
grad_sigmas[j] = delta * (grad_weights[j] * T - accum);
accum -= grad_weights[j] * weights[j];
T *= (1.f - alpha);
}
}
std::vector<torch::Tensor> volumetric_rendering_steps(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas
) {
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 && sigmas.size(1) == 1);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = sigmas.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor num_steps = torch::zeros({n_rays}, packed_info.options());
torch::Tensor selector = torch::zeros({n_samples}, packed_info.options().dtype(torch::kBool));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"volumetric_marching_steps",
([&]
{ volumetric_rendering_steps_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
num_steps.data_ptr<int>(),
selector.data_ptr<bool>()
);
}));
torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
torch::Tensor compact_packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
return {compact_packed_info, selector};
}
torch::Tensor volumetric_rendering_weights_forward(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas
) {
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 && sigmas.size(1) == 1);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = sigmas.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor weights = torch::zeros({n_samples}, sigmas.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"volumetric_rendering_weights_forward",
([&]
{ volumetric_rendering_weights_forward_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
weights.data_ptr<scalar_t>()
);
}));
return weights;
}
torch::Tensor volumetric_rendering_weights_backward(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas
) {
DEVICE_GUARD(packed_info);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = sigmas.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor grad_sigmas = torch::zeros(sigmas.sizes(), sigmas.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"volumetric_rendering_weights_backward",
([&]
{ volumetric_rendering_weights_backward_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
weights.data_ptr<scalar_t>(),
grad_weights.data_ptr<scalar_t>(),
grad_sigmas.data_ptr<scalar_t>()
);
}));
return grad_sigmas;
}
""" Occupancy field for accelerating volumetric rendering. """ from typing import Callable, List, Union
from typing import Callable, List, Tuple, Union
import torch import torch
from torch import nn import torch.nn as nn
from .contraction import ContractionType, contract_inv
# TODO: add this to the dependency
# from torch_scatter import scatter_max # from torch_scatter import scatter_max
def meshgrid3d( class Grid(nn.Module):
res: List[int], device: Union[torch.device, str] = "cpu" """An abstract Grid class.
) -> torch.Tensor:
"""Create 3D grid coordinates.
Args: The grid is used as a cache of the 3D space to indicate whether each voxel
res: resolutions for {x, y, z} dimensions. area is important or not for the differentiable rendering process. The
ray marching function (see :func:`nerfacc.ray_marching`) would use the
grid to skip the unimportant voxel areas.
Returns: To work with :func:`nerfacc.ray_marching`, three attributes must exist:
torch.long with shape (res[0], res[1], res[2], 3): dense 3D grid coordinates.
- :attr:`roi_aabb`: The axis-aligned bounding box of the region of interest.
- :attr:`binary`: A 3D binarized tensor of shape {resx, resy, resz}, \
with torch.bool data type.
- :attr:`contraction_type`: The contraction type of the grid, indicating how \
the 3D space is mapped to the grid.
""" """
assert len(res) == 3
return (
torch.stack(
torch.meshgrid(
[
torch.arange(res[0]),
torch.arange(res[1]),
torch.arange(res[2]),
],
indexing="ij",
),
dim=-1,
)
.long()
.to(device)
)
def __init__(self, *args, **kwargs):
super().__init__()
self._dummy = torch.nn.Parameter(torch.empty(0))
class OccupancyField(nn.Module): @property
"""Occupancy Field that supports EMA updates. Both 2D and 3D are supported. def device(self) -> torch.device:
return self._dummy.device
Note: @property
Make sure the arguemnts match with the ``num_dim`` -- Either 2D or 3D. def roi_aabb(self) -> torch.Tensor:
"""The axis-aligned bounding box of the region of interest.
Its is a shape (6,) tensor in the format of {minx, miny, minz, maxx, maxy, maxz}.
"""
if hasattr(self, "_roi_aabb"):
return getattr(self, "_roi_aabb")
else:
raise NotImplementedError("please set an attribute named _roi_aabb")
@property
def binary(self) -> torch.Tensor:
"""A 3D binarized tensor with torch.bool data type.
The tensor is of shape (resx, resy, resz), in which each boolen value
represents whether the corresponding voxel should be kept or not.
"""
if hasattr(self, "_binary"):
return getattr(self, "_binary")
else:
raise NotImplementedError("please set an attribute named _binary")
@property
def contraction_type(self) -> ContractionType:
"""The contraction type of the grid.
The contraction type is an indicator of how the 3D space is contracted
to this voxel grid. See :class:`nerfacc.ContractionType` for more details.
"""
if hasattr(self, "_contraction_type"):
return getattr(self, "_contraction_type")
else:
raise NotImplementedError(
"please set an attribute named _contraction_type"
)
class OccupancyGrid(Grid):
"""Occupancy grid: whether each voxel area is occupied or not.
Args: Args:
occ_eval_fn: A Callable function that takes in the un-normalized points x, roi_aabb: The axis-aligned bounding box of the region of interest. Useful for mapping
with shape of (N, 2) or (N, 3) (depends on ``num_dim``), the 3D space to the grid.
and outputs the occupancy of those points with shape of (N, 1). resolution: The resolution of the grid. If an integer is given, the grid is assumed to
aabb: Scene bounding box. If ``num_dim=2`` it should be {min_x, min_y,max_x, max_y}. be a cube. Otherwise, a list or a tensor of shape (3,) is expected. Default: 128.
If ``num_dim=3`` it should be {min_x, min_y, min_z, max_x, max_y, max_z}. contraction_type: The contraction type of the grid. See :class:`nerfacc.ContractionType`
resolution: The field resolution. It can either be a int of a list of ints for more details. Default: :attr:`nerfacc.ContractionType.AABB`.
to specify resolution on each dimension. If ``num_dim=2`` it is for {res_x, res_y}.
If ``num_dim=3`` it is for {res_x, res_y, res_z}. Default is 128.
num_dim: The space dimension. Either 2 or 3. Default is 3.
Attributes:
aabb: Scene bounding box.
occ_grid: The occupancy grid. It is a tensor of shape (num_cells,).
occ_grid_binary: The binary occupancy grid. It is a tensor of shape (num_cells,).
grid_coords: The grid coordinates. It is a tensor of shape (num_cells, num_dim).
grid_indices: The grid indices. It is a tensor of shape (num_cells,).
""" """
aabb: torch.Tensor NUM_DIM: int = 3
occ_grid: torch.Tensor
occ_grid_binary: torch.Tensor
grid_coords: torch.Tensor
grid_indices: torch.Tensor
def __init__( def __init__(
self, self,
occ_eval_fn: Callable, roi_aabb: Union[List[int], torch.Tensor],
aabb: Union[torch.Tensor, List[float]], resolution: Union[int, List[int], torch.Tensor] = 128,
resolution: Union[int, List[int]] = 128, contraction_type: ContractionType = ContractionType.AABB,
num_dim: int = 3,
) -> None: ) -> None:
super().__init__() super().__init__()
self.occ_eval_fn = occ_eval_fn if isinstance(resolution, int):
if not isinstance(aabb, torch.Tensor): resolution = [resolution] * self.NUM_DIM
aabb = torch.tensor(aabb, dtype=torch.float32) if isinstance(resolution, (list, tuple)):
if not isinstance(resolution, (list, tuple)): resolution = torch.tensor(resolution, dtype=torch.int32)
resolution = [resolution] * num_dim assert isinstance(
assert num_dim in [2, 3], "Currently only supports 2D or 3D field." resolution, torch.Tensor
assert aabb.shape == ( ), f"Invalid type: {type(resolution)}"
num_dim * 2, assert resolution.shape == (
), f"shape of aabb ({aabb.shape}) should be num_dim * 2 ({num_dim * 2})." self.NUM_DIM,
assert ( ), f"Invalid shape: {resolution.shape}"
len(resolution) == num_dim
), f"length of resolution ({len(resolution)}) should be num_dim ({num_dim})." if isinstance(roi_aabb, (list, tuple)):
roi_aabb = torch.tensor(roi_aabb, dtype=torch.float32)
self.register_buffer("aabb", aabb) assert isinstance(
self.resolution = resolution roi_aabb, torch.Tensor
self.register_buffer("resolution_tensor", torch.tensor(resolution)) ), f"Invalid type: {type(roi_aabb)}"
self.num_dim = num_dim assert roi_aabb.shape == torch.Size(
self.num_cells = int(torch.tensor(resolution).prod().item()) [self.NUM_DIM * 2]
), f"Invalid shape: {roi_aabb.shape}"
# Stores cell occupancy values ranged in [0, 1].
occ_grid = torch.zeros(self.num_cells) # total number of voxels
self.register_buffer("occ_grid", occ_grid) self.num_cells = int(resolution.prod().item())
occ_grid_binary = torch.zeros(self.num_cells, dtype=torch.bool)
self.register_buffer("occ_grid_binary", occ_grid_binary) # required attributes
self.register_buffer("_roi_aabb", roi_aabb)
self.register_buffer(
"_binary", torch.zeros(resolution.tolist(), dtype=torch.bool)
)
self._contraction_type = contraction_type
# helper attributes
self.register_buffer("resolution", resolution)
self.register_buffer("occs", torch.zeros(self.num_cells))
# Grid coords & indices # Grid coords & indices
grid_coords = meshgrid3d(self.resolution).reshape( grid_coords = _meshgrid3d(resolution).reshape(
self.num_cells, self.num_dim self.num_cells, self.NUM_DIM
) )
self.register_buffer("grid_coords", grid_coords) self.register_buffer("grid_coords", grid_coords)
grid_indices = torch.arange(self.num_cells) grid_indices = torch.arange(self.num_cells)
...@@ -116,13 +143,14 @@ class OccupancyField(nn.Module): ...@@ -116,13 +143,14 @@ class OccupancyField(nn.Module):
@torch.no_grad() @torch.no_grad()
def _sample_uniform_and_occupied_cells(self, n: int) -> torch.Tensor: def _sample_uniform_and_occupied_cells(self, n: int) -> torch.Tensor:
"""Samples both n uniform and occupied cells.""" """Samples both n uniform and occupied cells."""
device = self.occ_grid.device uniform_indices = torch.randint(
self.num_cells, (n,), device=self.device
uniform_indices = torch.randint(self.num_cells, (n,), device=device) )
occupied_indices = torch.nonzero(self._binary.flatten())[:, 0]
occupied_indices = torch.nonzero(self.occ_grid_binary)[:, 0]
if n < len(occupied_indices): if n < len(occupied_indices):
selector = torch.randint(len(occupied_indices), (n,), device=device) selector = torch.randint(
len(occupied_indices), (n,), device=self.device
)
occupied_indices = occupied_indices[selector] occupied_indices = occupied_indices[selector]
indices = torch.cat([uniform_indices, occupied_indices], dim=0) indices = torch.cat([uniform_indices, occupied_indices], dim=0)
return indices return indices
...@@ -131,6 +159,7 @@ class OccupancyField(nn.Module): ...@@ -131,6 +159,7 @@ class OccupancyField(nn.Module):
def _update( def _update(
self, self,
step: int, step: int,
occ_eval_fn: Callable,
occ_thre: float = 0.01, occ_thre: float = 0.01,
ema_decay: float = 0.95, ema_decay: float = 0.95,
warmup_steps: int = 256, warmup_steps: int = 256,
...@@ -147,92 +176,47 @@ class OccupancyField(nn.Module): ...@@ -147,92 +176,47 @@ class OccupancyField(nn.Module):
grid_coords = self.grid_coords[indices] grid_coords = self.grid_coords[indices]
x = ( x = (
grid_coords + torch.rand_like(grid_coords, dtype=torch.float32) grid_coords + torch.rand_like(grid_coords, dtype=torch.float32)
) / self.resolution_tensor ) / self.resolution
bb_min, bb_max = torch.split( # voxel coordinates [0, 1]^3 -> world
self.aabb, [self.num_dim, self.num_dim], dim=0 x = contract_inv(
x,
roi=self._roi_aabb,
type=self._contraction_type,
) )
x = x * (bb_max - bb_min) + bb_min occ = occ_eval_fn(x).squeeze(-1)
occ = self.occ_eval_fn(x).squeeze(-1)
# ema update # ema update
self.occ_grid[indices] = torch.maximum( self.occs[indices] = torch.maximum(self.occs[indices] * ema_decay, occ)
self.occ_grid[indices] * ema_decay, occ
)
# suppose to use scatter max but emperically it is almost the same. # suppose to use scatter max but emperically it is almost the same.
# self.occ_grid, _ = scatter_max( # self.occs, _ = scatter_max(
# occ, indices, dim=0, out=self.occ_grid * ema_decay # occ, indices, dim=0, out=self.occs * ema_decay
# ) # )
self.occ_grid_binary = self.occ_grid > torch.clamp( self._binary = (
self.occ_grid.mean(), max=occ_thre self.occs > torch.clamp(self.occs.mean(), max=occ_thre)
) ).reshape(self._binary.shape)
@torch.no_grad()
def query_occ(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Query the occupancy, given samples.
Args:
x: Samples with shape (..., 2) or (..., 3).
Returns:
float and binary occupancy values with shape (...) respectively.
"""
assert (
x.shape[-1] == self.num_dim
), "The samples are not drawn from a proper space!"
resolution = torch.tensor(self.resolution).to(self.occ_grid.device)
bb_min, bb_max = torch.split(
self.aabb, [self.num_dim, self.num_dim], dim=0
)
x = (x - bb_min) / (bb_max - bb_min)
selector = ((x > 0.0) & (x < 1.0)).all(dim=-1)
grid_coords = torch.floor(x * resolution).long()
if self.num_dim == 2:
grid_indices = (
grid_coords[..., 0] * self.resolution[-1] + grid_coords[..., 1]
)
elif self.num_dim == 3:
grid_indices = (
grid_coords[..., 0] * self.resolution[-1] * self.resolution[-2]
+ grid_coords[..., 1] * self.resolution[-1]
+ grid_coords[..., 2]
)
else:
raise NotImplementedError("Currently only supports 2D or 3D field.")
occs = torch.zeros(x.shape[:-1], device=x.device)
occs[selector] = self.occ_grid[grid_indices[selector]]
occs_binary = torch.zeros(
x.shape[:-1], device=x.device, dtype=torch.bool
)
occs_binary[selector] = self.occ_grid_binary[grid_indices[selector]]
return occs, occs_binary
@torch.no_grad() @torch.no_grad()
def every_n_step( def every_n_step(
self, self,
step: int, step: int,
occ_eval_fn: Callable,
occ_thre: float = 1e-2, occ_thre: float = 1e-2,
ema_decay: float = 0.95, ema_decay: float = 0.95,
warmup_steps: int = 256, warmup_steps: int = 256,
n: int = 16, n: int = 16,
): ) -> None:
"""Update the field every n steps during training. """Update the grid every n steps during training.
This function is designed for training only. If for some reason you want to
manually update the field, please use the ``_update()`` function instead.
Args: Args:
step: Current training step. step: Current training step.
occ_thre: Threshold to binarize the occupancy field. occ_eval_fn: A function that takes in sample locations :math:`(N, 3)` and
ema_decay: The decay rate for EMA updates. returns the occupancy values :math:`(N, 1)` at those locations.
occ_thre: Threshold used to binarize the occupancy grid. Default: 1e-2.
ema_decay: The decay rate for EMA updates. Default: 0.95.
warmup_steps: Sample all cells during the warmup stage. After the warmup warmup_steps: Sample all cells during the warmup stage. After the warmup
stage we change the sampling strategy to 1/4 uniformly sampled cells stage we change the sampling strategy to 1/4 uniformly sampled cells
together with 1/4 occupied cells. together with 1/4 occupied cells. Default: 256.
n: Update the field every n steps. n: Update the grid every n steps. Default: 16.
Returns:
None
""" """
if not self.training: if not self.training:
raise RuntimeError( raise RuntimeError(
...@@ -243,18 +227,31 @@ class OccupancyField(nn.Module): ...@@ -243,18 +227,31 @@ class OccupancyField(nn.Module):
if step % n == 0 and self.training: if step % n == 0 and self.training:
self._update( self._update(
step=step, step=step,
occ_eval_fn=occ_eval_fn,
occ_thre=occ_thre, occ_thre=occ_thre,
ema_decay=ema_decay, ema_decay=ema_decay,
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
) )
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Query the occupancy, given samples.
Args: def _meshgrid3d(
x: Samples with shape (..., 2) or (..., 3). res: torch.Tensor, device: Union[torch.device, str] = "cpu"
) -> torch.Tensor:
Returns: """Create 3D grid coordinates."""
float and binary occupancy values with shape (...) respectively. assert len(res) == 3
""" res = res.tolist()
return self.query_occ(x) return (
torch.stack(
torch.meshgrid(
[
torch.arange(res[0]),
torch.arange(res[1]),
torch.arange(res[2]),
],
indexing="ij",
),
dim=-1,
)
.long()
.to(device)
)
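A hedged usage sketch for the OccupancyGrid above, as it might be driven from a training loop. The `occ_eval_fn` below is a stand-in for querying a real radiance field, the import assumes these are the exported public names, and a CUDA device is assumed since the contraction ops run on GPU.

import torch
from nerfacc import ContractionType, OccupancyGrid  # assumes these public names

occ_grid = OccupancyGrid(
    roi_aabb=[-1.0, -1.0, -1.0, 1.0, 1.0, 1.0],
    resolution=128,
    contraction_type=ContractionType.AABB,
).cuda()

render_step_size = 5e-3

def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
    # placeholder: a real model would return its density at x, shape (N, 1)
    return torch.ones_like(x[:, :1]) * render_step_size

for step in range(1000):
    # EMA-update the cached occupancies and re-binarize the grid every 16 steps
    occ_grid.every_n_step(step=step, occ_eval_fn=occ_eval_fn, occ_thre=1e-2)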
""" Full volumetric rendering pipeline. """
from typing import Callable, List, Optional, Tuple
import torch
from .utils import (
unpack_to_ray_indices,
volumetric_marching,
volumetric_rendering_accumulate,
volumetric_rendering_steps,
volumetric_rendering_weights,
)
def volumetric_rendering_pipeline(
sigma_fn: Callable,
rgb_sigma_fn: Callable,
rays_o: torch.Tensor,
rays_d: torch.Tensor,
scene_aabb: torch.Tensor,
scene_resolution: Optional[List[int]] = None,
scene_occ_binary: Optional[torch.Tensor] = None,
render_bkgd: Optional[torch.Tensor] = None,
render_step_size: float = 1e-3,
near_plane: float = 0.0,
stratified: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
"""Differentiable volumetric rendering pipeline.
This function integrates the following individual functions:
- ray_aabb_intersect
- volumetric_marching
- volumetric_rendering_steps
- volumetric_rendering_weights
- volumetric_rendering_accumulate
Args:
sigma_fn: A function that takes in the {frustum starts (N, 1), frustum ends (N, 1), and
ray indices (N,)} and returns the post-activation sigma values (N, 1).
rgb_sigma_fn: A function that takes in the {frustum starts (N, 1), frustum ends (N, 1), and
ray indices (N,)} and returns the post-activation rgb values (N, 3) and sigma values (N, 1).
rays_o: The origin of the rays (n_rays, 3).
rays_d: The normalized direction of the rays (n_rays, 3).
scene_aabb: The scene axis-aligned bounding box {xmin, ymin, zmin, xmax, ymax, zmax}.
scene_resolution: The scene resolution (3,). Defaults to None.
scene_occ_binary: The scene occupancy binary tensor used to skip samples (n_cells,). Defaults to None.
render_bkgd: The background color (3,). Default: None.
render_step_size: The step size for the volumetric rendering. Default: 1e-3.
near_plane: The near plane for the volumetric rendering. Default: 0.0.
stratified: Whether to use stratified sampling. Default: False.
Returns:
Ray colors (n_rays, 3), opacities (n_rays, 1), the number of marching samples, and the number of rendering samples.
"""
n_rays = rays_o.shape[0]
if scene_occ_binary is None:
scene_occ_binary = torch.ones(
(1),
dtype=torch.bool,
device=rays_o.device,
)
scene_resolution = [1, 1, 1]
if scene_resolution is None:
assert scene_occ_binary is not None and scene_occ_binary.dim() == 3
scene_resolution = scene_occ_binary.shape
rays_o = rays_o.contiguous()
rays_d = rays_d.contiguous()
scene_aabb = scene_aabb.contiguous()
scene_occ_binary = scene_occ_binary.contiguous()
with torch.no_grad():
# Ray marching and occupancy check.
assert scene_resolution is not None
packed_info, frustum_starts, frustum_ends = volumetric_marching(
rays_o,
rays_d,
aabb=scene_aabb,
scene_resolution=scene_resolution,
scene_occ_binary=scene_occ_binary,
render_step_size=render_step_size,
near_plane=near_plane,
stratified=stratified,
)
n_marching_samples = frustum_starts.shape[0]
ray_indices = unpack_to_ray_indices(packed_info)
# Query sigma without gradients
sigmas = sigma_fn(frustum_starts, frustum_ends, ray_indices)
# Ray marching and rendering check.
packed_info, frustum_starts, frustum_ends = volumetric_rendering_steps(
packed_info,
sigmas,
frustum_starts,
frustum_ends,
)
n_rendering_samples = frustum_starts.shape[0]
ray_indices = unpack_to_ray_indices(packed_info)
# Query sigma and color with gradients
rgbs, sigmas = rgb_sigma_fn(frustum_starts, frustum_ends, ray_indices)
assert rgbs.shape[-1] == 3, f"rgbs must have 3 channels, got {rgbs.shape}"
assert (
sigmas.shape[-1] == 1
), f"sigmas must have 1 channel, got {sigmas.shape}"
# Rendering: compute weights and ray indices.
weights = volumetric_rendering_weights(
packed_info, sigmas, frustum_starts, frustum_ends
)
# Rendering: accumulate rgbs and opacities along the rays.
colors = volumetric_rendering_accumulate(
weights, ray_indices, values=rgbs, n_rays=n_rays
)
opacities = volumetric_rendering_accumulate(
weights, ray_indices, values=None, n_rays=n_rays
)
# depths = volumetric_rendering_accumulate(
# weights,
# ray_indices,
# values=(frustum_starts + frustum_ends) / 2.0,
# n_rays=n_rays,
# )
if render_bkgd is not None:
render_bkgd = render_bkgd.contiguous()
colors = colors + render_bkgd * (1.0 - opacities)
return colors, opacities, n_marching_samples, n_rendering_samples
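A hedged end-to-end sketch of calling the pipeline above. The two callbacks return constant densities and colors purely for illustration (a real radiance field would be queried instead), the import assumes the function is exported under this name, and a CUDA device is assumed.

import torch
from nerfacc import volumetric_rendering_pipeline  # assumed exported name

def sigma_fn(frustum_starts, frustum_ends, ray_indices):
    return torch.full_like(frustum_starts, 0.5)  # (N, 1) densities

def rgb_sigma_fn(frustum_starts, frustum_ends, ray_indices):
    rgbs = torch.ones(frustum_starts.shape[0], 3, device=frustum_starts.device)
    return rgbs, torch.full_like(frustum_starts, 0.5)  # (N, 3), (N, 1)

rays_o = torch.zeros(1024, 3, device="cuda")
rays_d = torch.nn.functional.normalize(torch.randn(1024, 3, device="cuda"), dim=-1)
scene_aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device="cuda")

colors, opacities, n_marching, n_rendering = volumetric_rendering_pipeline(
    sigma_fn=sigma_fn,
    rgb_sigma_fn=rgb_sigma_fn,
    rays_o=rays_o,
    rays_d=rays_d,
    scene_aabb=scene_aabb,
    render_step_size=5e-3,
)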