Unverified commit 10315043, authored by Youtian Lin and committed by GitHub

Accelerate Instant-NGP inference (#197)



* add mark_invisible_cells in occ_grid

add test mode for traverse_grids

* add data type to mark_invisible_cells

* add test for mark_invisible_cells & test mode traverse_grids

* upd comments

* ndr trial

* merge traverse_grids with traverse_grids_test in C

* fix format

* Revert "fix format"

This reverts commit 6233fc4e3a4ea4643ace6a96d28ad32c6a3a444f.

* revert benchmarks changes

* remove MLP updates

* Revert "remove MLP updates"

This reverts commit c37d199463e1d1fd2d28e061137391fe1b01b537.

* revert benchmarks changes

* add assert in traverse_grids

* Revert "add assert in traverse_grids"

This reverts commit c93eaad20b7e53f8dca49a3f9af97b549b7092fc.

* revert benchmarks changes

* reduce mem for traverse grid with over_allocate=True

* cleanup doc

* final cleanup with mark_invisible_cells

* ndr trial

* fix occ grid invisible cell filtering

---------
Co-authored-by: Ruilong Li <ruilongli94@gmail.com>
parent 0be61b2a
......@@ -27,11 +27,50 @@ inline __device__ float _calc_dt(
return clamp(t * cone_angle, dt_min, dt_max);
}
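For reference, the clamp above is the entire cone-angle stepping rule; a minimal Python re-expression (illustrative sketch, not part of the source):

def calc_dt(t, cone_angle, dt_min, dt_max):
    # Step size grows linearly with distance t, clamped to [dt_min, dt_max].
    # cone_angle == 0.0 collapses to uniform marching with dt_min (the step_size).
    return min(max(t * cone_angle, dt_min), dt_max)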
/* Ray traversal within multiple voxel grids.
About rays:
Each ray is defined by its origin (rays_o) and unit direction (rays_d). We also allow
an optional boolean ray mask (rays_mask) to indicate whether some rays should be skipped.
About voxel grids:
We support ray traversal through one or more voxel grids (n_grids). Each grid is defined
by an axis-aligned AABB (aabbs) and a binary occupancy grid (binaries) with a resolution of
{resx, resy, resz}. Currently, we assume all grids have the same resolution. Note that the
ordering of the grids matters when grids overlap, because the grid in front has higher
priority when examining occupancy status (e.g., the first grid's occupancy status
overwrites the second grid's wherever they overlap).
About ray-grid intersections:
We require the ray-grid intersections to be precomputed and sorted. Specifically, each hit
ray-grid pair has two intersections, one where the ray enters the grid and one where it leaves.
With multiple grids there are in total 2 * n_grids intersections per ray, sorted by distance
to the ray origin (t_sorted). A boolean array (hits) indicates whether each ray-grid pair is
hit, and an int64 array (t_indices) gives the grid id (0-indexed) for each intersection.
About ray traversal:
The ray is traversed through the grids in the order of the sorted intersections. Per-ray
near and far planes (near_planes, far_planes) may be specified. Early termination can be
controlled by setting the maximum number of traversal steps via traverse_steps_limit. An
optional step size (step_size) may also be specified. If step_size <= 0.0, we record an
interval for every voxel cell the ray passes through; otherwise we march through the grids
with step_size. When step_size > 0.0, a cone angle (cone_angle) may also be provided to
linearly increase the step size as the ray moves away from the origin (see _calc_dt()).
cone_angle must always be >= 0.0; 0.0 means uniform marching with step_size.
About outputs:
The traversal intervals and samples are stored in `intervals` and `samples`, respectively.
We also return where the traversal actually terminates (terminate_planes). This is useful
when traverse_steps_limit is set (traverse_steps_limit > 0), as the ray may not reach the
far plane or the boundary of the grids.
*/
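For concreteness, the sorted-intersection inputs described above are produced on the Python side of this PR roughly as follows (a sketch mirroring the wrapper in nerfacc's grid module further down):

# Each of t_mins, t_maxs, hits has shape [n_rays, n_grids].
t_mins, t_maxs, hits = ray_aabb_intersect(rays_o, rays_d, aabbs)
# Every hit ray-grid pair contributes an entry and an exit point, so each ray has
# 2 * n_grids candidate intersections, sorted here by distance to the ray origin.
t_sorted, t_indices = torch.sort(torch.cat([t_mins, t_maxs], dim=-1), dim=-1)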
__global__ void traverse_grids_kernel(
// rays
int32_t n_rays,
float *rays_o, // [n_rays, 3]
float *rays_d, // [n_rays, 3]
bool *rays_mask, // [n_rays]
// grids
int32_t n_grids,
int3 resolution,
......@@ -42,20 +81,24 @@ __global__ void traverse_grids_kernel(
float *t_sorted, // [n_rays, n_grids * 2]
int64_t *t_indices, // [n_rays, n_grids * 2]
// options
float *near_planes,
float *far_planes,
float *near_planes, // [n_rays]
float *far_planes, // [n_rays]
float step_size,
float cone_angle,
int32_t traverse_steps_limit,
// outputs
bool first_pass,
PackedRaySegmentsSpec intervals,
PackedRaySegmentsSpec samples)
PackedRaySegmentsSpec samples,
float *terminate_planes)
{
float eps = 1e-6f;
// parallelize over rays
for (int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n_rays; tid += blockDim.x * gridDim.x)
{
if (rays_mask != nullptr && !rays_mask[tid]) continue;
// skip rays that are empty.
if (intervals.chunk_cnts != nullptr)
if (!first_pass && intervals.chunk_cnts[tid] == 0) continue;
......@@ -138,7 +181,7 @@ __global__ void traverse_grids_kernel(
// );
const int3 overflow_index = final_index + step_index;
while (true) {
while (traverse_steps_limit <= 0 || n_samples < traverse_steps_limit) {
float t_traverse = min(tdist.x, min(tdist.y, tdist.z));
t_traverse = fminf(t_traverse, this_tmax);
int64_t cell_id = (
......@@ -162,7 +205,7 @@ __global__ void traverse_grids_kernel(
continuous = false;
} else {
// this cell is not empty, so we need to traverse it.
while (true) {
while (traverse_steps_limit <= 0 || n_samples < traverse_steps_limit) {
float t_next;
if (step_size <= 0.0f) {
t_next = t_traverse;
......@@ -207,10 +250,11 @@ __global__ void traverse_grids_kernel(
int64_t idx = chunk_start_bin + n_samples;
samples.vals[idx] = (t_next + t_last) * 0.5f;
samples.ray_indices[idx] = tid;
samples.is_valid[idx] = true;
}
n_samples++;
}
n_samples++;
continuous = true;
t_last = t_next;
if (t_next >= t_traverse) break;
......@@ -227,17 +271,16 @@ __global__ void traverse_grids_kernel(
}
}
}
if (first_pass) {
if (intervals.chunk_cnts != nullptr)
intervals.chunk_cnts[tid] = n_intervals;
if (samples.chunk_cnts != nullptr)
samples.chunk_cnts[tid] = n_samples;
}
if (terminate_planes != nullptr)
terminate_planes[tid] = t_last;
if (intervals.chunk_cnts != nullptr)
intervals.chunk_cnts[tid] = n_intervals;
if (samples.chunk_cnts != nullptr)
samples.chunk_cnts[tid] = n_samples;
}
}
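The host function below launches this kernel in one of two modes. A rough Python-flavored sketch of the control flow (the helper names here are hypothetical, for orientation only):

if over_allocate:
    # Single pass: reserve a worst case per unmasked ray up front
    # (traverse_steps_limit samples, 2 * traverse_steps_limit interval edges),
    # launch once, then recompute chunk_starts from the actual per-ray counts.
    allocate(intervals_per_ray=2 * traverse_steps_limit,
             samples_per_ray=traverse_steps_limit)
    launch_traverse(first_pass=False, rays_mask=rays_mask)
    compute_chunk_start()
else:
    # Two passes: count segments per ray, allocate exactly, then fill them in.
    launch_traverse(first_pass=True)   # writes only chunk_cnts
    allocate_from_counts()             # cumsum(chunk_cnts) -> chunk_starts
    launch_traverse(first_pass=False)  # writes vals / ray_indices / masks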
__global__ void ray_aabb_intersect_kernel(
const int32_t n_rays, float *rays_o, float *rays_d, float near, float far,
const int32_t n_aabbs, float *aabbs,
......@@ -274,16 +317,17 @@ __global__ void ray_aabb_intersect_kernel(
} // namespace
std::vector<RaySegmentsSpec> traverse_grids(
std::tuple<RaySegmentsSpec, RaySegmentsSpec, torch::Tensor> traverse_grids(
// rays
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
const torch::Tensor rays_mask, // [n_rays]
// grids
const torch::Tensor binaries, // [n_grids, resx, resy, resz]
const torch::Tensor aabbs, // [n_grids, 6]
// intersections
const torch::Tensor t_mins, // [n_rays, n_grids]
const torch::Tensor t_maxs, // [n_rays, n_grids]
const torch::Tensor t_sorted, // [n_rays, n_grids * 2]
const torch::Tensor t_indices, // [n_rays, n_grids * 2]
const torch::Tensor hits, // [n_rays, n_grids]
// options
const torch::Tensor near_planes,
......@@ -291,9 +335,15 @@ std::vector<RaySegmentsSpec> traverse_grids(
const float step_size,
const float cone_angle,
const bool compute_intervals,
const bool compute_samples)
const bool compute_samples,
const bool compute_terminate_planes,
const int32_t traverse_steps_limit, // <= 0 means no limit
const bool over_allocate) // over allocate the memory for intervals and samples
{
DEVICE_GUARD(rays_o);
if (over_allocate) {
TORCH_CHECK(traverse_steps_limit > 0, "traverse_steps_limit must be > 0 when over_allocate is true");
}
int32_t n_rays = rays_o.size(0);
int32_t n_grids = binaries.size(0);
......@@ -305,80 +355,122 @@ std::vector<RaySegmentsSpec> traverse_grids(
dim3 threads = dim3(min(max_threads, n_rays));
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.x)));
// Sort the intersections. [n_rays, n_grids * 2]
torch::Tensor t_sorted, t_indices;
if (n_grids > 1) {
std::tie(t_sorted, t_indices) = torch::sort(torch::cat({t_mins, t_maxs}, -1), -1);
}
else {
t_sorted = torch::cat({t_mins, t_maxs}, -1);
t_indices = torch::arange(
0, n_grids * 2, t_mins.options().dtype(torch::kLong)
).expand({n_rays, n_grids * 2}).contiguous();
}
// outputs
RaySegmentsSpec intervals, samples;
torch::Tensor terminate_planes;
if (compute_terminate_planes)
terminate_planes = torch::empty({n_rays}, rays_o.options());
if (over_allocate) {
// over allocate the memory so that we can traverse the grids in a single pass.
if (compute_intervals) {
intervals.chunk_cnts = torch::full({n_rays}, traverse_steps_limit * 2, rays_o.options().dtype(torch::kLong)) * rays_mask;
intervals.memalloc_data_from_chunk(true, true);
}
if (compute_samples) {
samples.chunk_cnts = torch::full({n_rays}, traverse_steps_limit, rays_o.options().dtype(torch::kLong)) * rays_mask;
samples.memalloc_data_from_chunk(false, true, true);
}
// first pass to count the number of segments along each ray.
if (compute_intervals)
intervals.memalloc_cnts(n_rays, rays_o.options(), false);
if (compute_samples)
samples.memalloc_cnts(n_rays, rays_o.options(), false);
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
// outputs
true,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples));
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
rays_mask.data_ptr<bool>(), // [n_rays]
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
traverse_steps_limit,
// outputs
false,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples),
compute_terminate_planes ? terminate_planes.data_ptr<float>() : nullptr);
// update the chunk starts with the actual chunk_cnts from traversal.
intervals.compute_chunk_start();
samples.compute_chunk_start();
} else {
// To allocate exactly the right amount of memory we need to traverse the grids twice.
// The first pass is to count the number of segments along each ray.
// The second pass is to fill the segments.
if (compute_intervals)
intervals.chunk_cnts = torch::empty({n_rays}, rays_o.options().dtype(torch::kLong));
if (compute_samples)
samples.chunk_cnts = torch::empty({n_rays}, rays_o.options().dtype(torch::kLong));
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
nullptr, /* rays_mask */
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
traverse_steps_limit,
// outputs
true,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples),
nullptr); /* terminate_planes */
// second pass to record the segments.
if (compute_intervals)
intervals.memalloc_data_from_chunk(true, true);
if (compute_samples)
samples.memalloc_data_from_chunk(false, false, true);
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
nullptr, /* rays_mask */
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
traverse_steps_limit,
// outputs
false,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples),
compute_terminate_planes ? terminate_planes.data_ptr<float>() : nullptr);
}
// second pass to record the segments.
if (compute_intervals)
intervals.memalloc_data(true, true);
if (compute_samples)
samples.memalloc_data(false, false);
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
// outputs
false,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples));
return {intervals, samples};
return {intervals, samples, terminate_planes};
}
......
......@@ -11,6 +11,7 @@ struct RaySegmentsSpec {
torch::Tensor ray_indices; // [n_edges]
torch::Tensor is_left; // [n_edges] have n_bins true values
torch::Tensor is_right; // [n_edges] have n_bins true values
torch::Tensor is_valid; // [n_edges] marks the entries that were actually written (used with over-allocation)
inline void check() {
CHECK_INPUT(vals);
......@@ -42,6 +43,11 @@ struct RaySegmentsSpec {
TORCH_CHECK(is_right.ndimension() == 1);
TORCH_CHECK(vals.numel() == is_right.numel());
}
if (is_valid.defined()) {
CHECK_INPUT(is_valid);
TORCH_CHECK(is_valid.ndimension() == 1);
TORCH_CHECK(vals.numel() == is_valid.numel());
}
}
inline void memalloc_cnts(int32_t n_rays, at::TensorOptions options, bool zero_init = true) {
......@@ -53,30 +59,49 @@ struct RaySegmentsSpec {
}
}
inline int64_t memalloc_data(bool alloc_masks = true, bool zero_init = true) {
inline void memalloc_data(int32_t size, bool alloc_masks = true, bool zero_init = true, bool alloc_valid = false) {
TORCH_CHECK(chunk_cnts.defined());
TORCH_CHECK(!chunk_starts.defined());
TORCH_CHECK(!vals.defined());
torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
int64_t n_edges = cumsum[-1].item<int64_t>();
chunk_starts = cumsum - chunk_cnts;
if (zero_init) {
vals = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kLong));
vals = torch::zeros({size}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::zeros({size}, chunk_cnts.options().dtype(torch::kLong));
if (alloc_masks) {
is_left = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_left = torch::zeros({size}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::zeros({size}, chunk_cnts.options().dtype(torch::kBool));
}
} else {
vals = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kLong));
vals = torch::empty({size}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::empty({size}, chunk_cnts.options().dtype(torch::kLong));
if (alloc_masks) {
is_left = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_left = torch::empty({size}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::empty({size}, chunk_cnts.options().dtype(torch::kBool));
}
}
if (alloc_valid) {
is_valid = torch::zeros({size}, chunk_cnts.options().dtype(torch::kBool));
}
}
inline int64_t memalloc_data_from_chunk(bool alloc_masks = true, bool zero_init = true, bool alloc_valid = false) {
TORCH_CHECK(chunk_cnts.defined());
TORCH_CHECK(!chunk_starts.defined());
torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
int64_t n_edges = cumsum[-1].item<int64_t>();
chunk_starts = cumsum - chunk_cnts;
memalloc_data(n_edges, alloc_masks, zero_init, alloc_valid);
return 1;
}
// compute chunk_starts from chunk_cnts
inline int64_t compute_chunk_start() {
TORCH_CHECK(chunk_cnts.defined());
// TORCH_CHECK(!chunk_starts.defined());
torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
chunk_starts = cumsum - chunk_cnts;
return 1;
}
};
\ No newline at end of file
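The chunk bookkeeping shared by both allocation paths is an exclusive prefix sum over the per-ray counts; a small PyTorch sketch of what memalloc_data_from_chunk and compute_chunk_start compute (illustrative values):

import torch

chunk_cnts = torch.tensor([3, 0, 2, 4])   # segments per ray
cumsum = torch.cumsum(chunk_cnts, dim=0)  # tensor([3, 3, 5, 9])
chunk_starts = cumsum - chunk_cnts        # tensor([0, 3, 3, 5]), exclusive prefix sum
n_edges = int(cumsum[-1])                 # 9, total buffer size to allocate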
......@@ -17,6 +17,7 @@ struct PackedRaySegmentsSpec {
ray_indices(spec.ray_indices.defined() ? spec.ray_indices.data_ptr<int64_t>() : nullptr),
is_left(spec.is_left.defined() ? spec.is_left.data_ptr<bool>() : nullptr),
is_right(spec.is_right.defined() ? spec.is_right.data_ptr<bool>() : nullptr),
is_valid(spec.is_valid.defined() ? spec.is_valid.data_ptr<bool>() : nullptr),
// for dimensions
n_edges(spec.vals.defined() ? spec.vals.numel() : 0),
n_rays(spec.chunk_cnts.defined() ? spec.chunk_cnts.size(0) : 0), // for flattened tensor
......@@ -31,6 +32,7 @@ struct PackedRaySegmentsSpec {
int64_t* ray_indices;
bool* is_left;
bool* is_right;
bool* is_valid;
int64_t n_edges;
int32_t n_rays;
......
......@@ -46,16 +46,17 @@ std::vector<torch::Tensor> ray_aabb_intersect(
const float near_plane,
const float far_plane,
const float miss_value);
std::vector<RaySegmentsSpec> traverse_grids(
std::tuple<RaySegmentsSpec, RaySegmentsSpec, torch::Tensor> traverse_grids(
// rays
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
const torch::Tensor rays_mask, // [n_rays]
// grids
const torch::Tensor binaries, // [n_grids, resx, resy, resz]
const torch::Tensor aabbs, // [n_grids, 6]
// intersections
const torch::Tensor t_mins, // [n_rays, n_grids]
const torch::Tensor t_maxs, // [n_rays, n_grids]
const torch::Tensor t_sorted, // [n_rays, n_grids * 2]
const torch::Tensor t_indices, // [n_rays, n_grids * 2]
const torch::Tensor hits, // [n_rays, n_grids]
// options
const torch::Tensor near_planes,
......@@ -63,7 +64,10 @@ std::vector<RaySegmentsSpec> traverse_grids(
const float step_size,
const float cone_angle,
const bool compute_intervals,
const bool compute_samples);
const bool compute_samples,
const bool compute_terminate_planes,
const int32_t traverse_steps_limit, // <= 0 means no limit
const bool over_allocate); // over allocate the memory for intervals and samples
// pdf
std::vector<RaySegmentsSpec> importance_sampling(
......@@ -118,6 +122,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def_readwrite("vals", &RaySegmentsSpec::vals)
.def_readwrite("is_left", &RaySegmentsSpec::is_left)
.def_readwrite("is_right", &RaySegmentsSpec::is_right)
.def_readwrite("is_valid", &RaySegmentsSpec::is_valid)
.def_readwrite("chunk_starts", &RaySegmentsSpec::chunk_starts)
.def_readwrite("chunk_cnts", &RaySegmentsSpec::chunk_cnts)
.def_readwrite("ray_indices", &RaySegmentsSpec::ray_indices);
......
......@@ -43,6 +43,7 @@ class RaySamples:
vals: torch.Tensor
packed_info: Optional[torch.Tensor] = None
ray_indices: Optional[torch.Tensor] = None
is_valid: Optional[torch.Tensor] = None
def _to_cpp(self):
"""
......@@ -69,8 +70,16 @@ class RaySamples:
else:
packed_info = None
ray_indices = spec.ray_indices
if spec.is_valid is not None:
is_valid = spec.is_valid
else:
is_valid = None
vals = spec.vals
return cls(
vals=spec.vals, packed_info=packed_info, ray_indices=ray_indices
vals=vals,
packed_info=packed_info,
ray_indices=ray_indices,
is_valid=is_valid,
)
@property
......
......@@ -161,7 +161,7 @@ class OccGridEstimator(AbstractEstimator):
if stratified:
near_planes += torch.rand_like(near_planes) * render_step_size
intervals, samples = traverse_grids(
intervals, samples, _ = traverse_grids(
rays_o,
rays_d,
self.binaries,
......@@ -258,10 +258,89 @@ class OccGridEstimator(AbstractEstimator):
warmup_steps=warmup_steps,
)
# adapted from https://github.com/kwea123/ngp_pl/blob/master/models/networks.py
@torch.no_grad()
def mark_invisible_cells(
self,
K: Tensor,
c2w: Tensor,
width: int,
height: int,
near_plane: float = 0.0,
chunk: int = 32**3,
) -> None:
"""Mark the cells that aren't covered by the cameras with density -1.
Should only be executed once before training starts.
Args:
K: Camera intrinsics of shape (N, 3, 3) or (1, 3, 3).
c2w: Camera to world poses of shape (N, 3, 4) or (N, 4, 4).
width: Image width in pixels.
height: Image height in pixels.
near_plane: Near plane distance.
chunk: Chunk size used to split the cells (to avoid OOM).
"""
assert K.dim() == 3 and K.shape[1:] == (3, 3)
assert c2w.dim() == 3 and (
c2w.shape[1:] == (3, 4) or c2w.shape[1:] == (4, 4)
)
assert K.shape[0] == c2w.shape[0] or K.shape[0] == 1
N_cams = c2w.shape[0]
w2c_R = c2w[:, :3, :3].transpose(2, 1) # (N_cams, 3, 3)
w2c_T = -w2c_R @ c2w[:, :3, 3:] # (N_cams, 3, 1)
lvl_indices = self._get_all_cells()
for lvl, indices in enumerate(lvl_indices):
grid_coords = self.grid_coords[indices]
for i in range(0, len(indices), chunk):
x = grid_coords[i : i + chunk] / (self.resolution - 1)
indices_chunk = indices[i : i + chunk]
# voxel coordinates [0, 1]^3 -> world
xyzs_w = (
self.aabbs[lvl, :3]
+ x * (self.aabbs[lvl, 3:] - self.aabbs[lvl, :3])
).T
xyzs_c = w2c_R @ xyzs_w + w2c_T # (N_cams, 3, chunk)
uvd = K @ xyzs_c # (N_cams, 3, chunk)
uv = uvd[:, :2] / uvd[:, 2:] # (N_cams, 2, chunk)
in_image = (
(uvd[:, 2] >= 0)
& (uv[:, 0] >= 0)
& (uv[:, 0] < width)
& (uv[:, 1] >= 0)
& (uv[:, 1] < height)
)
covered_by_cam = (
uvd[:, 2] >= near_plane
) & in_image # (N_cams, chunk)
# if the cell is visible to at least one camera
count = covered_by_cam.sum(0) / N_cams
too_near_to_cam = (
uvd[:, 2] < near_plane
) & in_image # (N, chunk)
# if the cell is too close (in front) to any camera
too_near_to_any_cam = too_near_to_cam.any(0)
# a valid cell should be visible to at least one camera and not too close to any camera
valid_mask = (count > 0) & (~too_near_to_any_cam)
cell_ids_base = lvl * self.cells_per_lvl
self.occs[cell_ids_base + indices_chunk] = torch.where(
valid_mask, 0.0, -1.0
)
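# A minimal usage sketch (hypothetical calibration values; K and c2w would come
# from the dataset). Call once before training: cells visible to no camera get
# density -1 and are subsequently skipped by _get_all_cells and
# _sample_uniform_and_occupied_cells below.
#   estimator = OccGridEstimator(roi_aabb=aabb, resolution=128, levels=4).to(device)
#   estimator.mark_invisible_cells(K, c2w, width=800, height=800, near_plane=0.2)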
@torch.no_grad()
def _get_all_cells(self) -> List[Tensor]:
"""Returns all cells of the grid."""
return [self.grid_indices] * self.levels
lvl_indices = []
for lvl in range(self.levels):
# filter out the cells with -1 density (non-visible to any camera)
cell_ids = lvl * self.cells_per_lvl + self.grid_indices
indices = self.grid_indices[self.occs[cell_ids] >= 0.0]
lvl_indices.append(indices)
return lvl_indices
@torch.no_grad()
def _sample_uniform_and_occupied_cells(self, n: int) -> List[Tensor]:
......@@ -271,6 +350,9 @@ class OccGridEstimator(AbstractEstimator):
uniform_indices = torch.randint(
self.cells_per_lvl, (n,), device=self.device
)
# filter out the cells with -1 density (non-visible to any camera)
cell_ids = lvl * self.cells_per_lvl + uniform_indices
uniform_indices = uniform_indices[self.occs[cell_ids] >= 0.0]
occupied_indices = torch.nonzero(self.binaries[lvl].flatten())[:, 0]
if n < len(occupied_indices):
selector = torch.randint(
......@@ -318,9 +400,8 @@ class OccGridEstimator(AbstractEstimator):
# self.occs, _ = scatter_max(
# occ, indices, dim=0, out=self.occs * ema_decay
# )
self.binaries = (
self.occs > torch.clamp(self.occs.mean(), max=occ_thre)
).view(self.binaries.shape)
thre = torch.clamp(self.occs[self.occs >= 0].mean(), max=occ_thre)
self.binaries = (self.occs > thre).view(self.binaries.shape)
def _meshgrid3d(
......
......@@ -103,7 +103,14 @@ def traverse_grids(
far_planes: Optional[Tensor] = None, # [n_rays]
step_size: Optional[float] = 1e-3,
cone_angle: Optional[float] = 0.0,
) -> Tuple[RayIntervals, RaySamples]:
traverse_steps_limit: Optional[int] = None,
over_allocate: Optional[bool] = False,
rays_mask: Optional[Tensor] = None, # [n_rays]
# pre-compute intersections
t_sorted: Optional[Tensor] = None, # [n_rays, n_grids * 2]
t_indices: Optional[Tensor] = None, # [n_rays, n_grids * 2]
hits: Optional[Tensor] = None, # [n_rays, n_grids]
) -> Tuple[RayIntervals, RaySamples, Tensor]:
"""Ray Traversal within Multiple Grids.
Note:
......@@ -119,29 +126,53 @@ def traverse_grids(
step_size: Optional. Step size for ray traversal. Default to 1e-3.
cone_angle: Optional. Cone angle for linearly-increased step size. 0. means
constant step size. Default: 0.0.
traverse_steps_limit: Optional. Maximum number of samples per ray.
over_allocate: Optional. Whether to over-allocate the memory for the outputs.
rays_mask: Optional. (n_rays,) Skip some rays if given.
t_sorted: Optional. (n_rays, n_grids * 2) Pre-computed sorted t values for each ray-grid pair. Defaults to None.
t_indices: Optional. (n_rays, n_grids * 2) Pre-computed sorted t indices for each ray-grid pair. Defaults to None.
hits: Optional. (n_rays, n_grids) Pre-computed hit flags for each ray-grid pair. Defaults to None.
Returns:
A :class:`RayIntervals` object containing the intervals of the ray traversal,
a :class:`RaySamples` object containing the samples within each interval, and
a :class:`Tensor` of shape (n_rays,) containing the termination t value for each ray.
"""
# Compute ray aabb intersection for all levels of grid. [n_rays, m]
t_mins, t_maxs, hits = ray_aabb_intersect(rays_o, rays_d, aabbs)
if near_planes is None:
near_planes = torch.zeros_like(rays_o[:, 0])
if far_planes is None:
far_planes = torch.full_like(rays_o[:, 0], float("inf"))
intervals, samples = _C.traverse_grids(
if rays_mask is None:
rays_mask = torch.ones_like(rays_o[:, 0], dtype=torch.bool)
if traverse_steps_limit is None:
traverse_steps_limit = -1
if over_allocate:
assert (
traverse_steps_limit > 0
), "traverse_steps_limit must be set if over_allocate is True."
if t_sorted is None or t_indices is None or hits is None:
# Compute ray aabb intersection for all levels of grid. [n_rays, m]
t_mins, t_maxs, hits = ray_aabb_intersect(rays_o, rays_d, aabbs)
# Sort the t values for each ray. [n_rays, m * 2]
t_sorted, t_indices = torch.sort(
torch.cat([t_mins, t_maxs], dim=-1), dim=-1
)
# Traverse the grids.
intervals, samples, termination_planes = _C.traverse_grids(
# rays
rays_o.contiguous(), # [n_rays, 3]
rays_d.contiguous(), # [n_rays, 3]
rays_mask.contiguous(), # [n_rays]
# grids
binaries.contiguous(), # [m, resx, resy, resz]
aabbs.contiguous(), # [m, 6]
# intersections
t_mins.contiguous(), # [n_rays, m]
t_maxs.contiguous(), # [n_rays, m]
t_sorted.contiguous(), # [n_rays, m * 2]
t_indices.contiguous(), # [n_rays, m * 2]
hits.contiguous(), # [n_rays, m]
# options
near_planes.contiguous(), # [n_rays]
......@@ -150,8 +181,15 @@ def traverse_grids(
cone_angle,
True,
True,
True,
traverse_steps_limit,
over_allocate,
)
return (
RayIntervals._from_cpp(intervals),
RaySamples._from_cpp(samples),
termination_planes,
)
return RayIntervals._from_cpp(intervals), RaySamples._from_cpp(samples)
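A usage sketch of the new arguments, mirroring the test-mode loop in the tests added below (the per-round step budget of 4000 is arbitrary):

# Precompute the intersections once, then traverse in fixed-size step budgets,
# resuming each round from the previous round's termination planes.
t_mins, t_maxs, hits = ray_aabb_intersect(rays_o, rays_d, aabbs)
t_sorted, t_indices = torch.sort(torch.cat([t_mins, t_maxs], dim=-1), dim=-1)
near_planes, rays_mask = None, None
while True:
    intervals, samples, near_planes = traverse_grids(
        rays_o, rays_d, binaries, aabbs,
        near_planes=near_planes,  # resume where the last round terminated
        traverse_steps_limit=4000,
        over_allocate=True,
        rays_mask=rays_mask,
        t_sorted=t_sorted, t_indices=t_indices, hits=hits,
    )
    # Rays that used the full budget have not reached the far plane yet.
    rays_mask = samples.packed_info[:, 1] == 4000
    if not rays_mask.any():
        break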
def _enlarge_aabb(aabb, factor: float) -> Tensor:
......
......@@ -545,3 +545,29 @@ def accumulate_along_rays(
else:
outputs = torch.sum(src, dim=-2)
return outputs
def accumulate_along_rays_(
weights: Tensor,
values: Optional[Tensor] = None,
ray_indices: Optional[Tensor] = None,
outputs: Optional[Tensor] = None,
) -> None:
"""Accumulate volumetric values along the ray.
Inplace version of :func:`accumulate_along_rays`.
"""
if values is None:
src = weights[..., None]
else:
assert values.dim() == weights.dim() + 1
assert weights.shape == values.shape[:-1]
src = weights[..., None] * values
if ray_indices is not None:
assert weights.dim() == 1, "weights must be flattened"
assert (
outputs.dim() == 2 and outputs.shape[-1] == src.shape[-1]
), "outputs must be of shape (n_rays, D)"
outputs.index_add_(0, ray_indices, src)
else:
outputs.add_(src.sum(dim=-2))
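A quick sketch contrasting the in-place variant with the functional one (illustrative shapes; assumes a CUDA device is available):

n_rays, n_samples = 8, 128
weights = torch.rand(n_samples, device="cuda")
rgbs = torch.rand(n_samples, 3, device="cuda")
ray_indices = torch.randint(n_rays, (n_samples,), device="cuda")

# Functional version allocates and returns its output:
colors = accumulate_along_rays(weights, rgbs, ray_indices, n_rays)

# In-place version scatters into a caller-provided buffer, so results can be
# accumulated across several traversal rounds without reallocating:
out = torch.zeros(n_rays, 3, device="cuda")
accumulate_along_rays_(weights, rgbs, ray_indices, out)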
......@@ -54,7 +54,7 @@ def test_traverse_grids():
binaries = torch.rand((n_aabbs, 32, 32, 32), device=device) > 0.5
intervals, samples = traverse_grids(rays_o, rays_d, binaries, aabbs)
intervals, samples, _ = traverse_grids(rays_o, rays_d, binaries, aabbs)
ray_indices = samples.ray_indices
t_starts = intervals.vals[intervals.is_left]
......@@ -68,6 +68,69 @@ def test_traverse_grids():
assert selector.all(), selector.float().mean()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_traverse_grids_test_mode():
from nerfacc.grid import _enlarge_aabb, traverse_grids
from nerfacc.volrend import accumulate_along_rays
torch.manual_seed(42)
n_rays = 10
n_aabbs = 4
rays_mask = torch.ones((n_rays,), device=device, dtype=torch.bool)
rays_o = torch.randn((n_rays, 3), device=device)
rays_d = torch.randn((n_rays, 3), device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
base_aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
aabbs = torch.stack(
[_enlarge_aabb(base_aabb, 2**i) for i in range(n_aabbs)]
)
binaries = torch.rand((n_aabbs, 32, 32, 32), device=device) > 0.5
# ref results: train mode
intervals, samples, _ = traverse_grids(rays_o, rays_d, binaries, aabbs)
ray_indices = samples.ray_indices
t_starts = intervals.vals[intervals.is_left]
t_ends = intervals.vals[intervals.is_right]
accum_t_starts = accumulate_along_rays(t_starts, None, ray_indices, n_rays)
accum_t_ends = accumulate_along_rays(t_ends, None, ray_indices, n_rays)
# test mode
_accum_t_starts, _accum_t_ends = 0.0, 0.0
_terminate_planes = None
_rays_mask = None
for _ in range(2):
_intervals, _samples, _terminate_planes = traverse_grids(
rays_o,
rays_d,
binaries,
aabbs,
near_planes=_terminate_planes,
traverse_steps_limit=4000,
over_allocate=True,
rays_mask=_rays_mask,
)
# only keep rays that are not yet terminated (i.e., rays that used the full step limit)
_rays_mask = _samples.packed_info[:, 1] == 4000
_ray_indices = _samples.ray_indices[_samples.is_valid]
_t_starts = _intervals.vals[_intervals.is_left]
_t_ends = _intervals.vals[_intervals.is_right]
_accum_t_starts += accumulate_along_rays(
_t_starts, None, _ray_indices, n_rays
)
_accum_t_ends += accumulate_along_rays(
_t_ends, None, _ray_indices, n_rays
)
# there shouldn't be any rays that are not terminated
assert (~_rays_mask).all()
# TODO: figure out where this small diff comes from
assert torch.allclose(_accum_t_starts, accum_t_starts, atol=1e-1)
assert torch.allclose(accum_t_ends, _accum_t_ends, atol=1e-1)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_traverse_grids_with_near_far_planes():
from nerfacc.grid import traverse_grids
......@@ -83,7 +146,7 @@ def test_traverse_grids_with_near_far_planes():
far_planes = torch.tensor([1.5], device=device)
step_size = 0.05
intervals, samples = traverse_grids(
intervals, samples, _ = traverse_grids(
rays_o=rays_o,
rays_d=rays_d,
binaries=binaries,
......@@ -140,8 +203,40 @@ def test_sampling_with_min_max_distances():
assert (t_ends <= (t_max[ray_indices] + render_step_size / 2)).all()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_mark_invisible_cells():
from nerfacc import OccGridEstimator
levels = 4
resolution = 32
width = 100
height = 100
fx, fy = width, height
cx, cy = width / 2, height / 2
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
grid_estimator = OccGridEstimator(
roi_aabb=aabb, resolution=resolution, levels=levels
).to(device)
K = torch.tensor([[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]], device=device)
pose = torch.tensor(
[[[-1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 2.5]]],
device=device,
)
grid_estimator.mark_invisible_cells(K, pose, width, height)
assert (grid_estimator.occs == -1).sum() == 77660
assert (grid_estimator.occs == 0).sum() == 53412
if __name__ == "__main__":
test_ray_aabb_intersect()
test_traverse_grids()
test_traverse_grids_with_near_far_planes()
test_sampling_with_min_max_distances()
test_mark_invisible_cells()
test_traverse_grids_test_mode()