Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
5009fc12
Commit
5009fc12
authored
Sep 08, 2022
by
Ruilong Li
Browse files
add occ field
parent
298ffd02
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
353 additions
and
97 deletions
+353
-97
nerfacc/__init__.py
nerfacc/__init__.py
+2
-96
nerfacc/grid.py
nerfacc/grid.py
+101
-0
nerfacc/occupancy_field.py
nerfacc/occupancy_field.py
+153
-0
nerfacc/volumetric_rendering.py
nerfacc/volumetric_rendering.py
+96
-0
setup.py
setup.py
+1
-1
No files found.
nerfacc/__init__.py
View file @
5009fc12
import
math
from
.occupancy_field
import
OccupancyField
from
typing
import
Callable
,
Tuple
from
.volumetric_rendering
import
volumetric_rendering
import
torch
from
.cuda
import
VolumeRenderer
,
ray_aabb_intersect
,
ray_marching
def
volumetric_rendering
(
query_fn
:
Callable
,
rays_o
:
torch
.
Tensor
,
rays_d
:
torch
.
Tensor
,
scene_aabb
:
torch
.
Tensor
,
scene_occ_binary
:
torch
.
Tensor
,
scene_resolution
:
Tuple
[
int
,
int
,
int
],
render_bkgd
:
torch
.
Tensor
=
None
,
render_n_samples
:
int
=
1024
,
**
kwargs
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""A *fast* version of differentiable volumetric rendering."""
device
=
rays_o
.
device
if
render_bkgd
is
None
:
render_bkgd
=
torch
.
ones
(
3
,
device
=
device
)
scene_resolution
=
torch
.
tensor
(
scene_resolution
,
dtype
=
torch
.
int
,
device
=
device
)
rays_o
=
rays_o
.
contiguous
()
rays_d
=
rays_d
.
contiguous
()
scene_aabb
=
scene_aabb
.
contiguous
()
scene_occ_binary
=
scene_occ_binary
.
contiguous
()
render_bkgd
=
render_bkgd
.
contiguous
()
n_rays
=
rays_o
.
shape
[
0
]
render_total_samples
=
n_rays
*
render_n_samples
render_step_size
=
(
(
scene_aabb
[
3
:]
-
scene_aabb
[:
3
]).
max
()
*
math
.
sqrt
(
3
)
/
render_n_samples
)
with
torch
.
no_grad
():
# TODO: avoid clamp here. kinda stupid
t_min
,
t_max
=
ray_aabb_intersect
(
rays_o
,
rays_d
,
scene_aabb
)
t_min
=
torch
.
clamp
(
t_min
,
max
=
1e10
)
t_max
=
torch
.
clamp
(
t_max
,
max
=
1e10
)
(
packed_info
,
frustum_origins
,
frustum_dirs
,
frustum_starts
,
frustum_ends
,
)
=
ray_marching
(
# rays
rays_o
,
rays_d
,
t_min
,
t_max
,
# density grid
scene_aabb
,
scene_resolution
,
scene_occ_binary
,
# sampling
render_total_samples
,
render_n_samples
,
render_step_size
,
)
# squeeze valid samples
total_samples
=
max
(
packed_info
[:,
-
1
].
sum
(),
1
)
frustum_origins
=
frustum_origins
[:
total_samples
]
frustum_dirs
=
frustum_dirs
[:
total_samples
]
frustum_starts
=
frustum_starts
[:
total_samples
]
frustum_ends
=
frustum_ends
[:
total_samples
]
frustum_positions
=
(
frustum_origins
+
frustum_dirs
*
(
frustum_starts
+
frustum_ends
)
/
2.0
)
query_results
=
query_fn
(
frustum_positions
,
frustum_dirs
,
**
kwargs
)
rgbs
,
densities
=
query_results
[
0
],
query_results
[
1
]
(
accumulated_weight
,
accumulated_depth
,
accumulated_color
,
alive_ray_mask
,
)
=
VolumeRenderer
.
apply
(
packed_info
,
frustum_starts
,
frustum_ends
,
densities
.
contiguous
(),
rgbs
.
contiguous
(),
)
accumulated_depth
=
torch
.
clip
(
accumulated_depth
,
t_min
[:,
None
],
t_max
[:,
None
])
accumulated_color
=
accumulated_color
+
render_bkgd
*
(
1.0
-
accumulated_weight
)
return
accumulated_color
,
accumulated_depth
,
accumulated_weight
,
alive_ray_mask
nerfacc/grid.py
0 → 100644
View file @
5009fc12
from
typing
import
List
,
Tuple
import
torch
import
torch.nn.functional
as
F
def
query_grid
(
x
:
torch
.
Tensor
,
aabb
:
torch
.
Tensor
,
grid
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Query values in the grid field given the coordinates.
Args:
x: 2D / 3D coordinates, with shape of [..., 2 or 3]
aabb: 2D / 3D bounding box of the field, with shape of [4 or 6]
grid: Grid with shape [res_x, res_y, res_z, D] or [res_x, res_y, D]
Returns:
values with shape [..., D]
"""
output_shape
=
list
(
x
.
shape
[:
-
1
])
+
[
grid
.
shape
[
-
1
]]
if
x
.
shape
[
-
1
]
==
2
and
aabb
.
shape
==
(
4
,)
and
grid
.
ndim
==
3
:
# 2D case
grid
=
grid
.
permute
(
2
,
1
,
0
).
unsqueeze
(
0
)
# [1, D, res_y, res_x]
x
=
(
x
.
view
(
1
,
-
1
,
1
,
2
)
-
aabb
[
0
:
2
])
/
(
aabb
[
2
:
4
]
-
aabb
[
0
:
2
])
elif
x
.
shape
[
-
1
]
==
3
and
aabb
.
shape
==
(
6
,)
and
grid
.
ndim
==
4
:
# 3D case
grid
=
grid
.
permute
(
3
,
2
,
1
,
0
).
unsqueeze
(
0
)
# [1, D, res_z, res_y, res_x]
x
=
(
x
.
view
(
1
,
-
1
,
1
,
1
,
3
)
-
aabb
[
0
:
3
])
/
(
aabb
[
3
:
6
]
-
aabb
[
0
:
3
])
else
:
raise
ValueError
(
"The shapes of the inputs do not match to either 2D case or 3D case! "
f
"Got x:
{
x
.
shape
}
; aabb:
{
aabb
.
shape
}
; grid:
{
grid
.
shape
}
."
)
v
=
F
.
grid_sample
(
grid
,
x
*
2.0
-
1.0
,
align_corners
=
True
,
padding_mode
=
"zeros"
,
)
v
=
v
.
reshape
(
output_shape
)
return
v
def
meshgrid
(
resolution
:
List
[
int
]):
if
len
(
resolution
)
==
2
:
return
meshgrid2d
(
resolution
)
elif
len
(
resolution
)
==
3
:
return
meshgrid3d
(
resolution
)
else
:
raise
ValueError
(
resolution
)
def
meshgrid2d
(
res
:
Tuple
[
int
,
int
],
device
:
torch
.
device
=
"cpu"
):
"""Create 2D grid coordinates.
Args:
res (Tuple[int, int]): resolutions for {x, y} dimensions.
Returns:
torch.long with shape (res[0], res[1], 2): dense 2D grid coordinates.
"""
return
(
torch
.
stack
(
torch
.
meshgrid
(
[
torch
.
arange
(
res
[
0
]),
torch
.
arange
(
res
[
1
]),
],
indexing
=
"ij"
,
),
dim
=-
1
,
)
.
long
()
.
to
(
device
)
)
def
meshgrid3d
(
res
:
Tuple
[
int
,
int
,
int
],
device
:
torch
.
device
=
"cpu"
):
"""Create 3D grid coordinates.
Args:
res (Tuple[int, int, int]): resolutions for {x, y, z} dimensions.
Returns:
torch.long with shape (res[0], res[1], res[2], 3): dense 3D grid coordinates.
"""
return
(
torch
.
stack
(
torch
.
meshgrid
(
[
torch
.
arange
(
res
[
0
]),
torch
.
arange
(
res
[
1
]),
torch
.
arange
(
res
[
2
]),
],
indexing
=
"ij"
,
),
dim
=-
1
,
)
.
long
()
.
to
(
device
)
)
nerfacc/occupancy_field.py
0 → 100644
View file @
5009fc12
from
typing
import
Callable
,
List
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
.grid
import
meshgrid
class
OccupancyField
(
nn
.
Module
):
"""Occupancy Field."""
def
__init__
(
self
,
# Shape (N, 3) -> (N, 1). Values are in range [0, 1]: density * step_size
occ_eval_fn
:
Callable
,
aabb
:
Union
[
torch
.
Tensor
,
List
[
float
]],
resolution
:
Union
[
int
,
List
[
int
]],
# cell resolution
num_dim
:
int
=
3
,
)
->
None
:
# def occ_eval_fn(x):
# step_size = (rays.far - rays.near).max() / self.num_samples
# densities, _ = self.radiance_field.query_density(x)
# occ = densities * step_size
# return occ
super
().
__init__
()
self
.
occ_eval_fn
=
occ_eval_fn
if
not
isinstance
(
aabb
,
torch
.
Tensor
):
aabb
=
torch
.
tensor
(
aabb
,
dtype
=
torch
.
float32
)
if
not
isinstance
(
resolution
,
(
list
,
tuple
)):
resolution
=
[
resolution
]
*
num_dim
assert
aabb
.
shape
==
(
num_dim
*
2
,
),
f
"shape of aabb (
{
aabb
.
shape
}
) should be num_dim * 2 (
{
num_dim
*
2
}
)."
assert
(
len
(
resolution
)
==
num_dim
),
f
"length of resolution (
{
len
(
resolution
)
}
) should be num_dim (
{
num_dim
}
)."
self
.
register_buffer
(
"aabb"
,
aabb
)
self
.
resolution
=
resolution
self
.
num_dim
=
num_dim
self
.
num_cells
=
torch
.
tensor
(
resolution
).
prod
().
item
()
# Stores cell occupancy values ranged in [0, 1].
occ_grid
=
torch
.
zeros
(
self
.
num_cells
)
self
.
register_buffer
(
"occ_grid"
,
occ_grid
)
occ_grid_binary
=
torch
.
zeros
(
self
.
num_cells
,
dtype
=
torch
.
bool
)
self
.
register_buffer
(
"occ_grid_binary"
,
occ_grid_binary
)
# Used for thresholding occ_grid
occ_grid_mean
=
occ_grid
.
mean
()
self
.
register_buffer
(
"occ_grid_mean"
,
occ_grid_mean
)
grid_coords
=
meshgrid
(
self
.
resolution
).
reshape
(
self
.
num_cells
,
self
.
num_dim
)
self
.
register_buffer
(
"grid_coords"
,
grid_coords
)
grid_indices
=
torch
.
arange
(
self
.
num_cells
)
self
.
register_buffer
(
"grid_indices"
,
grid_indices
)
@
torch
.
no_grad
()
def
get_all_cells
(
self
,
)
->
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""Returns all cells of the grid."""
return
self
.
grid_indices
@
torch
.
no_grad
()
def
sample_uniform_and_occupied_cells
(
self
,
n
:
int
)
->
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""Samples both n uniform and occupied cells (per level)."""
device
=
self
.
occ_grid
.
device
uniform_indices
=
torch
.
randint
(
self
.
num_cells
,
(
n
,),
device
=
device
)
occupied_indices
=
torch
.
nonzero
(
self
.
occ_grid_binary
)[:,
0
]
if
n
<
len
(
occupied_indices
):
selector
=
torch
.
randint
(
len
(
occupied_indices
),
(
n
,),
device
=
device
)
occupied_indices
=
occupied_indices
[
selector
]
indices
=
torch
.
cat
([
uniform_indices
,
occupied_indices
],
dim
=
0
)
return
indices
@
torch
.
no_grad
()
def
update
(
self
,
step
:
int
,
occ_threshold
:
float
=
0.01
,
ema_decay
:
float
=
0.95
,
)
->
None
:
"""Update the occ_grid (as well as occ_bitfield) in EMA way."""
resolution
=
torch
.
tensor
(
self
.
resolution
).
to
(
self
.
occ_grid
.
device
)
# sample cells
if
step
<
256
:
indices
=
self
.
get_all_cells
()
else
:
N
=
resolution
.
prod
().
item
()
//
4
indices
=
self
.
sample_uniform_and_occupied_cells
(
N
)
# infer occupancy: density * step_size
tmp_occ_grid
=
-
torch
.
ones_like
(
self
.
occ_grid
)
grid_coords
=
self
.
grid_coords
[
indices
]
x
=
(
grid_coords
+
torch
.
rand_like
(
grid_coords
.
float
()))
/
resolution
bb_min
,
bb_max
=
torch
.
split
(
self
.
aabb
,
[
self
.
num_dim
,
self
.
num_dim
],
dim
=
0
)
x
=
x
*
(
bb_max
-
bb_min
)
+
bb_min
tmp_occ_grid
[
indices
]
=
self
.
occ_eval_fn
(
x
).
squeeze
(
-
1
)
# ema update
ema_mask
=
(
self
.
occ_grid
>=
0
)
&
(
tmp_occ_grid
>=
0
)
self
.
occ_grid
[
ema_mask
]
=
torch
.
maximum
(
self
.
occ_grid
[
ema_mask
]
*
ema_decay
,
tmp_occ_grid
[
ema_mask
]
)
self
.
occ_grid_mean
=
self
.
occ_grid
.
mean
()
self
.
occ_grid_binary
=
self
.
occ_grid
>
min
(
self
.
occ_grid_mean
.
item
(),
occ_threshold
)
@
torch
.
no_grad
()
def
query_occ
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Query the occ_grid."""
resolution
=
torch
.
tensor
(
self
.
resolution
).
to
(
self
.
occ_grid
.
device
)
bb_min
,
bb_max
=
torch
.
split
(
self
.
aabb
,
[
self
.
num_dim
,
self
.
num_dim
],
dim
=
0
)
x
=
(
x
-
bb_min
)
/
(
bb_max
-
bb_min
)
selector
=
((
x
>
0.0
)
&
(
x
<
1.0
)).
all
(
dim
=-
1
)
grid_coords
=
torch
.
floor
(
x
*
resolution
).
long
()
if
self
.
num_dim
==
2
:
grid_indices
=
(
grid_coords
[...,
0
]
*
self
.
resolution
[
-
1
]
+
grid_coords
[...,
1
]
)
elif
self
.
num_dim
==
3
:
grid_indices
=
(
grid_coords
[...,
0
]
*
self
.
resolution
[
-
1
]
*
self
.
resolution
[
-
2
]
+
grid_coords
[...,
1
]
*
self
.
resolution
[
-
1
]
+
grid_coords
[...,
2
]
)
occs
=
torch
.
zeros
(
x
.
shape
[:
-
1
],
device
=
x
.
device
)
occs
[
selector
]
=
self
.
occ_grid
[
grid_indices
[
selector
]]
occs_binary
=
torch
.
zeros
(
x
.
shape
[:
-
1
],
device
=
x
.
device
,
dtype
=
bool
)
occs_binary
[
selector
]
=
self
.
occ_grid_binary
[
grid_indices
[
selector
]]
return
occs
,
occs_binary
@
torch
.
no_grad
()
def
every_n_step
(
self
,
step
:
int
,
n
:
int
=
16
):
if
step
%
n
==
0
and
self
.
training
:
self
.
update
(
step
=
step
,
occ_threshold
=
0.01
,
ema_decay
=
0.95
,
)
nerfacc/volumetric_rendering.py
0 → 100644
View file @
5009fc12
import
math
from
typing
import
Callable
,
Tuple
import
torch
from
.cuda
import
VolumeRenderer
,
ray_aabb_intersect
,
ray_marching
def
volumetric_rendering
(
query_fn
:
Callable
,
rays_o
:
torch
.
Tensor
,
rays_d
:
torch
.
Tensor
,
scene_aabb
:
torch
.
Tensor
,
scene_occ_binary
:
torch
.
Tensor
,
scene_resolution
:
Tuple
[
int
,
int
,
int
],
render_bkgd
:
torch
.
Tensor
=
None
,
render_n_samples
:
int
=
1024
,
**
kwargs
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""A *fast* version of differentiable volumetric rendering."""
device
=
rays_o
.
device
if
render_bkgd
is
None
:
render_bkgd
=
torch
.
ones
(
3
,
device
=
device
)
scene_resolution
=
torch
.
tensor
(
scene_resolution
,
dtype
=
torch
.
int
,
device
=
device
)
rays_o
=
rays_o
.
contiguous
()
rays_d
=
rays_d
.
contiguous
()
scene_aabb
=
scene_aabb
.
contiguous
()
scene_occ_binary
=
scene_occ_binary
.
contiguous
()
render_bkgd
=
render_bkgd
.
contiguous
()
n_rays
=
rays_o
.
shape
[
0
]
render_total_samples
=
n_rays
*
render_n_samples
render_step_size
=
(
(
scene_aabb
[
3
:]
-
scene_aabb
[:
3
]).
max
()
*
math
.
sqrt
(
3
)
/
render_n_samples
)
with
torch
.
no_grad
():
# TODO: avoid clamp here. kinda stupid
t_min
,
t_max
=
ray_aabb_intersect
(
rays_o
,
rays_d
,
scene_aabb
)
t_min
=
torch
.
clamp
(
t_min
,
max
=
1e10
)
t_max
=
torch
.
clamp
(
t_max
,
max
=
1e10
)
(
packed_info
,
frustum_origins
,
frustum_dirs
,
frustum_starts
,
frustum_ends
,
)
=
ray_marching
(
# rays
rays_o
,
rays_d
,
t_min
,
t_max
,
# density grid
scene_aabb
,
scene_resolution
,
scene_occ_binary
,
# sampling
render_total_samples
,
render_n_samples
,
render_step_size
,
)
# squeeze valid samples
total_samples
=
max
(
packed_info
[:,
-
1
].
sum
(),
1
)
frustum_origins
=
frustum_origins
[:
total_samples
]
frustum_dirs
=
frustum_dirs
[:
total_samples
]
frustum_starts
=
frustum_starts
[:
total_samples
]
frustum_ends
=
frustum_ends
[:
total_samples
]
frustum_positions
=
(
frustum_origins
+
frustum_dirs
*
(
frustum_starts
+
frustum_ends
)
/
2.0
)
query_results
=
query_fn
(
frustum_positions
,
frustum_dirs
,
**
kwargs
)
rgbs
,
densities
=
query_results
[
0
],
query_results
[
1
]
(
accumulated_weight
,
accumulated_depth
,
accumulated_color
,
alive_ray_mask
,
)
=
VolumeRenderer
.
apply
(
packed_info
,
frustum_starts
,
frustum_ends
,
densities
.
contiguous
(),
rgbs
.
contiguous
(),
)
accumulated_depth
=
torch
.
clip
(
accumulated_depth
,
t_min
[:,
None
],
t_max
[:,
None
])
accumulated_color
=
accumulated_color
+
render_bkgd
*
(
1.0
-
accumulated_weight
)
return
accumulated_color
,
accumulated_depth
,
accumulated_weight
,
alive_ray_mask
setup.py
View file @
5009fc12
...
@@ -3,7 +3,7 @@ from setuptools import find_packages, setup
...
@@ -3,7 +3,7 @@ from setuptools import find_packages, setup
setup
(
setup
(
name
=
"nerfacc"
,
name
=
"nerfacc"
,
description
=
"NeRF accelerated rendering"
,
description
=
"NeRF accelerated rendering"
,
version
=
"0.0.
2
"
,
version
=
"0.0.
3
"
,
python_requires
=
">=3.9"
,
python_requires
=
">=3.9"
,
packages
=
find_packages
(
exclude
=
(
"tests*"
,)),
packages
=
find_packages
(
exclude
=
(
"tests*"
,)),
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment