Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
e813bcaa
Commit
e813bcaa
authored
Sep 08, 2022
by
Ruilong Li
Browse files
occ field doc
parent
1ed5257d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
65 additions
and
28 deletions
+65
-28
nerfacc/_grid.py
nerfacc/_grid.py
+0
-0
nerfacc/occupancy_field.py
nerfacc/occupancy_field.py
+65
-28
No files found.
nerfacc/grid.py
→
nerfacc/
_
grid.py
View file @
e813bcaa
File moved
nerfacc/occupancy_field.py
View file @
e813bcaa
...
...
@@ -3,33 +3,42 @@ from typing import Callable, List, Tuple, Union
import
torch
from
torch
import
nn
from
.grid
import
meshgrid
from
.
_
grid
import
meshgrid
class
OccupancyField
(
nn
.
Module
):
"""Occupancy Field."""
"""Occupancy Field that supports EMA updates.
It supports both 2D and 3D cases, where in the 2D cases the occupancy field
is basically a segmentation mask.
Args:
occ_eval_fn: A Callable function that takes in the un-normalized points x,
with shape of (N, 2) or (N, 3) (depends on `num_dim`), and outputs
the occupancy of those points with shape of (N, 1).
aabb: Scene bounding box. {min_x, min_y, (min_z), max_x, max_y, (max_z)}.
It can be either a list or a torch.Tensor.
resolution: The field resolution. It can either be a int of a list of ints
to specify resolution on each dimention. {res_x, res_y, (res_z)}. Default
is 128.
num_dim: The space dimension. Either 2 or 3. Default is 3. Note other arguments
should match with the space dimension being set here.
"""
def
__init__
(
self
,
# Shape (N, 3) -> (N, 1). Values are in range [0, 1]: density * step_size
occ_eval_fn
:
Callable
,
aabb
:
Union
[
torch
.
Tensor
,
List
[
float
]],
resolution
:
Union
[
int
,
List
[
int
]]
,
# cell resolution
resolution
:
Union
[
int
,
List
[
int
]]
=
128
,
num_dim
:
int
=
3
,
)
->
None
:
# def occ_eval_fn(x):
# step_size = (rays.far - rays.near).max() / self.num_samples
# densities, _ = self.radiance_field.query_density(x)
# occ = densities * step_size
# return occ
super
().
__init__
()
self
.
occ_eval_fn
=
occ_eval_fn
if
not
isinstance
(
aabb
,
torch
.
Tensor
):
aabb
=
torch
.
tensor
(
aabb
,
dtype
=
torch
.
float32
)
if
not
isinstance
(
resolution
,
(
list
,
tuple
)):
resolution
=
[
resolution
]
*
num_dim
assert
num_dim
in
[
2
,
3
],
"Currently only supports 2D or 3D field."
assert
aabb
.
shape
==
(
num_dim
*
2
,
),
f
"shape of aabb (
{
aabb
.
shape
}
) should be num_dim * 2 (
{
num_dim
*
2
}
)."
...
...
@@ -45,7 +54,6 @@ class OccupancyField(nn.Module):
# Stores cell occupancy values ranged in [0, 1].
occ_grid
=
torch
.
zeros
(
self
.
num_cells
)
self
.
register_buffer
(
"occ_grid"
,
occ_grid
)
occ_grid_binary
=
torch
.
zeros
(
self
.
num_cells
,
dtype
=
torch
.
bool
)
self
.
register_buffer
(
"occ_grid_binary"
,
occ_grid_binary
)
...
...
@@ -53,23 +61,20 @@ class OccupancyField(nn.Module):
occ_grid_mean
=
occ_grid
.
mean
()
self
.
register_buffer
(
"occ_grid_mean"
,
occ_grid_mean
)
# Grid coords & indices
grid_coords
=
meshgrid
(
self
.
resolution
).
reshape
(
self
.
num_cells
,
self
.
num_dim
)
self
.
register_buffer
(
"grid_coords"
,
grid_coords
)
grid_indices
=
torch
.
arange
(
self
.
num_cells
)
self
.
register_buffer
(
"grid_indices"
,
grid_indices
)
@
torch
.
no_grad
()
def
get_all_cells
(
self
,
)
->
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
def
_get_all_cells
(
self
)
->
torch
.
Tensor
:
"""Returns all cells of the grid."""
return
self
.
grid_indices
@
torch
.
no_grad
()
def
sample_uniform_and_occupied_cells
(
self
,
n
:
int
)
->
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""Samples both n uniform and occupied cells (per level)."""
def
_sample_uniform_and_occupied_cells
(
self
,
n
:
int
)
->
torch
.
Tensor
:
"""Samples both n uniform and occupied cells."""
device
=
self
.
occ_grid
.
device
uniform_indices
=
torch
.
randint
(
self
.
num_cells
,
(
n
,),
device
=
device
)
...
...
@@ -88,16 +93,26 @@ class OccupancyField(nn.Module):
step
:
int
,
occ_threshold
:
float
=
0.01
,
ema_decay
:
float
=
0.95
,
warmup_steps
:
int
=
256
,
)
->
None
:
"""Update the occ_grid (as well as occ_bitfield) in EMA way."""
"""Update the occ field in the EMA way.
Args:
step: Current training step.
occ_threshold: Threshold to binarize the occupancy field.
ema_decay: The decay rate for EMA updates.
warmup_steps: Sample all cells during the warmup stage. After the warmup
stage we change the sampling strategy to 1/4 unifromly sampled cells
together with 1/4 occupied cells.
"""
resolution
=
torch
.
tensor
(
self
.
resolution
).
to
(
self
.
occ_grid
.
device
)
# sample cells
if
step
<
256
:
indices
=
self
.
get_all_cells
()
if
step
<
warmup_steps
:
indices
=
self
.
_
get_all_cells
()
else
:
N
=
resolution
.
prod
().
item
()
//
4
indices
=
self
.
sample_uniform_and_occupied_cells
(
N
)
indices
=
self
.
_
sample_uniform_and_occupied_cells
(
N
)
# infer occupancy: density * step_size
tmp_occ_grid
=
-
torch
.
ones_like
(
self
.
occ_grid
)
...
...
@@ -118,8 +133,19 @@ class OccupancyField(nn.Module):
)
@
torch
.
no_grad
()
def
query_occ
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Query the occ_grid."""
def
query_occ
(
self
,
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Query the occupancy, given samples.
Args:
x: Samples with shape (..., 2) or (..., 3).
Returns:
float occupancy values with shape (...),
binary occupancy values with shape (...)
"""
assert
(
x
.
shape
[
-
1
]
==
self
.
num_dim
),
"The samples are not drawn from a proper space!"
resolution
=
torch
.
tensor
(
self
.
resolution
).
to
(
self
.
occ_grid
.
device
)
bb_min
,
bb_max
=
torch
.
split
(
self
.
aabb
,
[
self
.
num_dim
,
self
.
num_dim
],
dim
=
0
)
...
...
@@ -144,10 +170,21 @@ class OccupancyField(nn.Module):
return
occs
,
occs_binary
@
torch
.
no_grad
()
def
every_n_step
(
self
,
step
:
int
,
n
:
int
=
16
):
def
every_n_step
(
self
,
step
:
int
,
occ_thre
:
float
=
1e-2
,
ema_decay
:
float
=
0.95
,
n
:
int
=
16
,
):
if
not
self
.
training
:
raise
RuntimeError
(
"You should only call this function during training. Please call update() "
"directly if you want to update the field during inference."
)
if
step
%
n
==
0
and
self
.
training
:
self
.
update
(
step
=
step
,
occ_threshold
=
0.01
,
ema_decay
=
0.95
,
occ_threshold
=
occ_thre
,
ema_decay
=
ema_decay
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment