Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
88a6aec6
Commit
88a6aec6
authored
Jan 02, 2024
by
Ruilong Li
Browse files
n3tree
parent
9f90842b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
299 additions
and
0 deletions
+299
-0
nerfacc/estimators/n3tree.py
nerfacc/estimators/n3tree.py
+299
-0
No files found.
nerfacc/estimators/n3tree.py
0 → 100644
View file @
88a6aec6
import
math
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
..grid
import
_enlarge_aabb
from
..volrend
import
(
render_visibility_from_alpha
,
render_visibility_from_density
,
)
from
.base
import
AbstractEstimator
from
torch
import
Tensor
try
:
import
svox
except
ImportError
:
raise
ImportError
(
"Please install this forked version of svox: "
"pip install git+https://github.com/liruilong940607/svox.git"
)
class
N3TreeEstimator
(
AbstractEstimator
):
"""Use N3Tree to implement Occupancy Grid.
This allows more flexible topologies than the cascaded grid. However, it is
slower to create samples from the tree than the cascaded grid. By default,
it has the same topology as the cascaded grid but `self.tree` can be
modified to have different topologies.
"""
def
__init__
(
self
,
roi_aabb
:
Union
[
List
[
int
],
Tensor
],
resolution
:
Union
[
int
,
List
[
int
],
Tensor
]
=
128
,
levels
:
int
=
1
,
**
kwargs
,
)
->
None
:
super
().
__init__
()
if
"contraction_type"
in
kwargs
:
raise
ValueError
(
"`contraction_type` is not supported anymore for nerfacc >= 0.4.0."
)
# check the resolution is legal
assert
isinstance
(
resolution
,
int
),
"N3Tree only supports uniform resolution!"
# check the roi_aabb is legal
if
isinstance
(
roi_aabb
,
(
list
,
tuple
)):
roi_aabb
=
torch
.
tensor
(
roi_aabb
,
dtype
=
torch
.
float32
)
assert
isinstance
(
roi_aabb
,
Tensor
),
f
"Invalid type:
{
roi_aabb
}
!"
assert
roi_aabb
.
shape
[
0
]
==
6
,
f
"Invalid shape:
{
roi_aabb
}
!"
roi_aabb
=
roi_aabb
.
cpu
()
# to be compatible with the OccupancyGrid
aabbs
=
torch
.
stack
(
[
_enlarge_aabb
(
roi_aabb
,
2
**
i
)
for
i
in
range
(
levels
)],
dim
=
0
)
self
.
register_buffer
(
"aabbs"
,
aabbs
)
# [n_aabbs, 6]
center
=
(
roi_aabb
[:
3
]
+
roi_aabb
[
3
:])
/
2.0
radius
=
(
roi_aabb
[
3
:]
-
roi_aabb
[:
3
])
/
2.0
*
2
**
(
levels
-
1
)
tree_depth
=
int
(
math
.
log2
(
resolution
))
-
1
self
.
tree
=
svox
.
N3Tree
(
N
=
2
,
data_dim
=
1
,
init_refine
=
tree_depth
,
depth_limit
=
20
,
radius
=
radius
.
tolist
(),
center
=
center
.
tolist
(),
)
_aabbs
=
[
_enlarge_aabb
(
roi_aabb
,
2
**
i
)
for
i
in
range
(
levels
-
1
)]
for
aabb
in
_aabbs
[::
-
1
]:
leaf_c
=
self
.
tree
.
corners
+
self
.
tree
.
lengths
*
0.5
sel
=
((
leaf_c
>
aabb
[:
3
])
&
(
leaf_c
<
aabb
[
3
:])).
all
(
dim
=-
1
)
self
.
tree
[
sel
].
refine
()
# print("tree size", len(self.tree), "at resolution", resolution)
self
.
thresh
=
0.0
@
torch
.
no_grad
()
def
sampling
(
self
,
# rays
rays_o
:
Tensor
,
# [n_rays, 3]
rays_d
:
Tensor
,
# [n_rays, 3]
# sigma/alpha function for skipping invisible space
sigma_fn
:
Optional
[
Callable
]
=
None
,
alpha_fn
:
Optional
[
Callable
]
=
None
,
near_plane
:
float
=
0.0
,
far_plane
:
float
=
1e10
,
t_min
:
Optional
[
Tensor
]
=
None
,
# [n_rays]
t_max
:
Optional
[
Tensor
]
=
None
,
# [n_rays]
# rendering options
render_step_size
:
float
=
1e-3
,
early_stop_eps
:
float
=
1e-4
,
alpha_thre
:
float
=
0.0
,
stratified
:
bool
=
False
,
cone_angle
:
float
=
0.0
,
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
"""Sampling with spatial skipping.
Note:
This function is not differentiable to any inputs.
Args:
rays_o: Ray origins of shape (n_rays, 3).
rays_d: Normalized ray directions of shape (n_rays, 3).
sigma_fn: Optional. If provided, the marching will skip the invisible space
by evaluating the density along the ray with `sigma_fn`. It should be a
function that takes in samples {t_starts (N,), t_ends (N,),
ray indices (N,)} and returns the post-activation density values (N,).
You should only provide either `sigma_fn` or `alpha_fn`.
alpha_fn: Optional. If provided, the marching will skip the invisible space
by evaluating the density along the ray with `alpha_fn`. It should be a
function that takes in samples {t_starts (N,), t_ends (N,),
ray indices (N,)} and returns the post-activation opacity values (N,).
You should only provide either `sigma_fn` or `alpha_fn`.
near_plane: Optional. Near plane distance. Default: 0.0.
far_plane: Optional. Far plane distance. Default: 1e10.
t_min: Optional. Per-ray minimum distance. Tensor with shape (n_rays).
If profided, the marching will start from maximum of t_min and near_plane.
t_max: Optional. Per-ray maximum distance. Tensor with shape (n_rays).
If profided, the marching will stop by minimum of t_max and far_plane.
render_step_size: Step size for marching. Default: 1e-3.
early_stop_eps: Early stop threshold for skipping invisible space. Default: 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
stratified: Whether to use stratified sampling. Default: False.
cone_angle: Cone angle for linearly-increased step size. 0. means
constant step size. Default: 0.0.
Returns:
A tuple of {LongTensor, Tensor, Tensor}:
- **ray_indices**: Ray index of each sample. IntTensor with shape (n_samples).
- **t_starts**: Per-sample start distance. Tensor with shape (n_samples,).
- **t_ends**: Per-sample end distance. Tensor with shape (n_samples,).
Examples:
.. code-block:: python
>>> ray_indices, t_starts, t_ends = grid.sampling(
>>> rays_o, rays_d, render_step_size=1e-3)
>>> t_mid = (t_starts + t_ends) / 2.0
>>> sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
"""
assert
t_min
is
None
and
t_max
is
None
,
(
"Do not supported per-ray min max. Please use near_plane and far_plane instead."
)
if
stratified
:
near_plane
+=
torch
.
rand
(()).
item
()
*
render_step_size
t_starts
,
t_ends
,
packed_info
,
ray_indices
=
svox
.
volume_sample
(
self
.
tree
,
thresh
=
self
.
thresh
,
rays
=
svox
.
Rays
(
rays_o
.
contiguous
(),
rays_d
.
contiguous
(),
rays_d
.
contiguous
()),
step_size
=
render_step_size
,
cone_angle
=
cone_angle
,
near_plane
=
near_plane
,
far_plane
=
far_plane
,
)
packed_info
=
packed_info
.
long
()
ray_indices
=
ray_indices
.
long
()
# skip invisible space
if
(
alpha_thre
>
0.0
or
early_stop_eps
>
0.0
)
and
(
sigma_fn
is
not
None
or
alpha_fn
is
not
None
):
alpha_thre
=
min
(
alpha_thre
,
self
.
thresh
)
# Compute visibility of the samples, and filter out invisible samples
if
sigma_fn
is
not
None
:
if
t_starts
.
shape
[
0
]
!=
0
:
sigmas
=
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
)
else
:
sigmas
=
torch
.
empty
((
0
,),
device
=
t_starts
.
device
)
assert
(
sigmas
.
shape
==
t_starts
.
shape
),
"sigmas must have shape of (N,)! Got {}"
.
format
(
sigmas
.
shape
)
masks
=
render_visibility_from_density
(
t_starts
=
t_starts
,
t_ends
=
t_ends
,
sigmas
=
sigmas
,
ray_indices
=
ray_indices
,
n_rays
=
len
(
rays_o
),
early_stop_eps
=
early_stop_eps
,
alpha_thre
=
alpha_thre
,
)
elif
alpha_fn
is
not
None
:
if
t_starts
.
shape
[
0
]
!=
0
:
alphas
=
alpha_fn
(
t_starts
,
t_ends
,
ray_indices
)
else
:
alphas
=
torch
.
empty
((
0
,),
device
=
t_starts
.
device
)
assert
(
alphas
.
shape
==
t_starts
.
shape
),
"alphas must have shape of (N,)! Got {}"
.
format
(
alphas
.
shape
)
masks
=
render_visibility_from_alpha
(
alphas
=
alphas
,
ray_indices
=
ray_indices
,
n_rays
=
len
(
rays_o
),
early_stop_eps
=
early_stop_eps
,
alpha_thre
=
alpha_thre
,
)
ray_indices
,
t_starts
,
t_ends
=
(
ray_indices
[
masks
],
t_starts
[
masks
],
t_ends
[
masks
],
)
return
ray_indices
,
t_starts
,
t_ends
@
torch
.
no_grad
()
def
update_every_n_steps
(
self
,
step
:
int
,
occ_eval_fn
:
Callable
,
occ_thre
:
float
=
1e-2
,
ema_decay
:
float
=
0.95
,
warmup_steps
:
int
=
256
,
n
:
int
=
16
,
)
->
None
:
"""Update the estimator every n steps during training.
Args:
step: Current training step.
occ_eval_fn: A function that takes in sample locations :math:`(N, 3)` and
returns the occupancy values :math:`(N, 1)` at those locations.
occ_thre: Threshold used to binarize the occupancy grid. Default: 1e-2.
ema_decay: The decay rate for EMA updates. Default: 0.95.
warmup_steps: Sample all cells during the warmup stage. After the warmup
stage we change the sampling strategy to 1/4 uniformly sampled cells
together with 1/4 occupied cells. Default: 256.
n: Update the grid every n steps. Default: 16.
"""
if
not
self
.
training
:
raise
RuntimeError
(
"You should only call this function only during training. "
"Please call _update() directly if you want to update the "
"field during inference."
)
if
step
%
n
==
0
and
self
.
training
:
self
.
_update
(
step
=
step
,
occ_eval_fn
=
occ_eval_fn
,
occ_thre
=
occ_thre
,
ema_decay
=
ema_decay
,
warmup_steps
=
warmup_steps
,
)
@
torch
.
no_grad
()
def
_sample_uniform_and_occupied_cells
(
self
,
n
:
int
)
->
List
[
Tensor
]:
"""Samples both n uniform and occupied cells."""
uniform_indices
=
torch
.
randint
(
len
(
self
.
tree
),
(
n
,),
device
=
self
.
device
)
occupied_indices
=
torch
.
nonzero
(
self
.
tree
[:].
values
>=
self
.
thresh
)[:,
0
]
if
n
<
len
(
occupied_indices
):
selector
=
torch
.
randint
(
len
(
occupied_indices
),
(
n
,),
device
=
self
.
device
)
occupied_indices
=
occupied_indices
[
selector
]
indices
=
torch
.
cat
([
uniform_indices
,
occupied_indices
],
dim
=
0
)
return
indices
@
torch
.
no_grad
()
def
_update
(
self
,
step
:
int
,
occ_eval_fn
:
Callable
,
occ_thre
:
float
=
0.01
,
ema_decay
:
float
=
0.95
,
warmup_steps
:
int
=
256
,
)
->
None
:
"""Update the occ field in the EMA way."""
if
step
<
warmup_steps
:
x
=
self
.
tree
.
sample
(
1
).
squeeze
(
1
)
occ
=
occ_eval_fn
(
x
).
squeeze
(
-
1
)
sel
=
(
*
self
.
tree
.
_all_leaves
().
T
,)
self
.
tree
.
data
.
data
[
sel
]
=
torch
.
maximum
(
self
.
tree
.
data
.
data
[
sel
]
*
ema_decay
,
occ
[:,
None
])
else
:
N
=
len
(
self
.
tree
)
//
4
indices
=
self
.
_sample_uniform_and_occupied_cells
(
N
)
x
=
self
.
tree
[
indices
].
sample
(
1
).
squeeze
(
1
)
occ
=
occ_eval_fn
(
x
).
squeeze
(
-
1
)
self
.
tree
[
indices
]
=
torch
.
maximum
(
self
.
tree
[
indices
].
values
*
ema_decay
,
occ
[:,
None
]
)
self
.
thresh
=
min
(
occ_thre
,
self
.
tree
[:].
values
.
mean
().
item
())
if
__name__
==
"__main__"
:
roi_aabb
=
[
-
1.0
,
-
1.0
,
-
1.0
,
1.0
,
1.0
,
1.0
]
resolution
=
128
levels
=
4
estimator
=
N3TreeEstimator
(
roi_aabb
,
resolution
,
levels
)
def
occ_eval_fn
(
x
):
return
torch
.
rand
(
len
(
x
),
1
)
estimator
.
update_every_n_steps
(
0
,
occ_eval_fn
,
occ_thre
=
0.5
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment