Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
a74f7e7b
Commit
a74f7e7b
authored
Nov 20, 2022
by
Ruilong Li
Browse files
proposal_sampling_with_filter: 7k; 229s; loss 64; 35.25db; 63 rays
parent
1aeee0a9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
220 additions
and
127 deletions
+220
-127
examples/train_ngp_nerf_proposal.py
examples/train_ngp_nerf_proposal.py
+9
-2
nerfacc/ray_marching.py
nerfacc/ray_marching.py
+24
-125
nerfacc/sampling.py
nerfacc/sampling.py
+187
-0
No files found.
examples/train_ngp_nerf_proposal.py
View file @
a74f7e7b
...
@@ -87,7 +87,15 @@ def render_image(
...
@@ -87,7 +87,15 @@ def render_image(
chunk_rays
.
viewdirs
,
chunk_rays
.
viewdirs
,
scene_aabb
=
scene_aabb
,
scene_aabb
=
scene_aabb
,
grid
=
None
,
grid
=
None
,
proposal_nets
=
proposal_nets
,
# proposal density fns: {t_starts, t_ends, ray_indices} -> density
proposal_sigma_fns
=
[
lambda
t_starts
,
t_ends
,
ray_indices
:
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
,
proposal_net
)
for
proposal_net
in
proposal_nets
],
proposal_n_samples
=
[
32
],
proposal_require_grads
=
proposal_nets_require_grads
,
sigma_fn
=
sigma_fn
,
sigma_fn
=
sigma_fn
,
near_plane
=
near_plane
,
near_plane
=
near_plane
,
far_plane
=
far_plane
,
far_plane
=
far_plane
,
...
@@ -95,7 +103,6 @@ def render_image(
...
@@ -95,7 +103,6 @@ def render_image(
stratified
=
radiance_field
.
training
,
stratified
=
radiance_field
.
training
,
cone_angle
=
cone_angle
,
cone_angle
=
cone_angle
,
alpha_thre
=
alpha_thre
,
alpha_thre
=
alpha_thre
,
proposal_nets_require_grads
=
proposal_nets_require_grads
,
)
)
rgb
,
opacity
,
depth
,
weights
=
rendering
(
rgb
,
opacity
,
depth
,
weights
=
rendering
(
t_starts
,
t_starts
,
...
...
nerfacc/ray_marching.py
View file @
a74f7e7b
...
@@ -2,58 +2,9 @@ from typing import Callable, Optional, Tuple
...
@@ -2,58 +2,9 @@ from typing import Callable, Optional, Tuple
import
torch
import
torch
import
nerfacc.cuda
as
_C
from
.cdf
import
ray_resampling
from
.grid
import
Grid
from
.grid
import
Grid
from
.intersection
import
ray_aabb_intersect
from
.intersection
import
ray_aabb_intersect
from
.pack
import
pack_info
,
unpack_info
from
.sampling
import
proposal_sampling_with_filter
,
sample_along_rays
from
.vol_rendering
import
(
render_visibility
,
render_weight_from_alpha
,
render_weight_from_density
,
)
@
torch
.
no_grad
()
def
maybe_filter
(
t_starts
:
torch
.
Tensor
,
t_ends
:
torch
.
Tensor
,
ray_indices
:
torch
.
Tensor
,
n_rays
:
int
,
# sigma/alpha function for skipping invisible space
sigma_fn
:
Optional
[
Callable
]
=
None
,
alpha_fn
:
Optional
[
Callable
]
=
None
,
net
:
Optional
[
torch
.
nn
.
Module
]
=
None
,
early_stop_eps
:
float
=
1e-4
,
alpha_thre
:
float
=
0.0
,
):
alphas
=
None
if
sigma_fn
is
not
None
:
alpha_fn
=
lambda
*
args
:
1.0
-
torch
.
exp
(
-
sigma_fn
(
*
args
)
*
(
t_ends
-
t_starts
)
)
if
alpha_fn
is
not
None
:
alphas
=
alpha_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
(),
net
)
assert
(
alphas
.
shape
==
t_starts
.
shape
),
"alphas must have shape of (N, 1)! Got {}"
.
format
(
alphas
.
shape
)
# Compute visibility of the samples, and filter out invisible samples
masks
=
render_visibility
(
alphas
,
ray_indices
=
ray_indices
,
early_stop_eps
=
early_stop_eps
,
alpha_thre
=
alpha_thre
,
n_rays
=
n_rays
,
)
ray_indices
,
t_starts
,
t_ends
,
alphas
=
(
ray_indices
[
masks
],
t_starts
[
masks
],
t_ends
[
masks
],
alphas
[
masks
],
)
return
ray_indices
,
t_starts
,
t_ends
,
alphas
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
@@ -71,10 +22,12 @@ def ray_marching(
...
@@ -71,10 +22,12 @@ def ray_marching(
# sigma/alpha function for skipping invisible space
# sigma/alpha function for skipping invisible space
sigma_fn
:
Optional
[
Callable
]
=
None
,
sigma_fn
:
Optional
[
Callable
]
=
None
,
alpha_fn
:
Optional
[
Callable
]
=
None
,
alpha_fn
:
Optional
[
Callable
]
=
None
,
proposal_nets
:
Optional
[
torch
.
nn
.
Module
]
=
None
,
# proposal density fns: {t_starts, t_ends, ray_indices} -> density
proposal_sigma_fns
:
Tuple
[
Callable
,
...]
=
[],
proposal_n_samples
:
Tuple
[
int
,
...]
=
[],
proposal_require_grads
:
bool
=
False
,
early_stop_eps
:
float
=
1e-4
,
early_stop_eps
:
float
=
1e-4
,
alpha_thre
:
float
=
0.0
,
alpha_thre
:
float
=
0.0
,
proposal_nets_require_grads
:
bool
=
True
,
# rendering options
# rendering options
near_plane
:
Optional
[
float
]
=
None
,
near_plane
:
Optional
[
float
]
=
None
,
far_plane
:
Optional
[
float
]
=
None
,
far_plane
:
Optional
[
float
]
=
None
,
...
@@ -177,7 +130,6 @@ def ray_marching(
...
@@ -177,7 +130,6 @@ def ray_marching(
sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
"""
"""
torch
.
cuda
.
synchronize
()
n_rays
=
rays_o
.
shape
[
0
]
n_rays
=
rays_o
.
shape
[
0
]
if
not
rays_o
.
is_cuda
:
if
not
rays_o
.
is_cuda
:
...
@@ -209,85 +161,32 @@ def ray_marching(
...
@@ -209,85 +161,32 @@ def ray_marching(
if
stratified
:
if
stratified
:
t_min
=
t_min
+
torch
.
rand_like
(
t_min
)
*
render_step_size
t_min
=
t_min
+
torch
.
rand_like
(
t_min
)
*
render_step_size
# use grid for skipping if given
ray_indices
,
t_starts
,
t_ends
=
sample_along_rays
(
if
grid
is
not
None
:
rays_o
=
rays_o
,
# marching with grid-based skipping
rays_d
=
rays_d
,
packed_info
,
ray_indices
,
t_starts
,
t_ends
=
_C
.
ray_marching_with_grid
(
t_min
=
t_min
,
# rays
t_max
=
t_max
,
rays_o
.
contiguous
(),
step_size
=
render_step_size
,
rays_d
.
contiguous
(),
cone_angle
=
cone_angle
,
t_min
.
contiguous
(),
grid
=
grid
,
t_max
.
contiguous
(),
)
# coontraction and grid
grid
.
roi_aabb
.
contiguous
(),
grid
.
binary
.
contiguous
(),
grid
.
contraction_type
.
to_cpp_version
(),
# sampling
render_step_size
,
cone_angle
,
)
else
:
# marching
packed_info
,
ray_indices
,
t_starts
,
t_ends
=
_C
.
ray_marching
(
# rays
t_min
.
contiguous
(),
t_max
.
contiguous
(),
# sampling
render_step_size
,
cone_angle
,
)
proposal_sample_list
=
[]
if
proposal_nets
is
not
None
:
# resample with proposal nets
for
net
,
num_samples
in
zip
(
proposal_nets
,
[
32
]):
ray_indices
,
t_starts
,
t_ends
,
alphas
=
maybe_filter
(
t_starts
=
t_starts
,
t_ends
=
t_ends
,
ray_indices
=
ray_indices
,
n_rays
=
n_rays
,
sigma_fn
=
sigma_fn
,
alpha_fn
=
alpha_fn
,
net
=
net
,
early_stop_eps
=
early_stop_eps
,
alpha_thre
=
alpha_thre
,
)
packed_info
=
pack_info
(
ray_indices
,
n_rays
=
n_rays
)
if
proposal_nets_require_grads
:
with
torch
.
enable_grad
():
sigmas
=
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
(),
net
=
net
)
weights
=
render_weight_from_density
(
t_starts
,
t_ends
,
sigmas
,
ray_indices
=
ray_indices
)
proposal_sample_list
.
append
(
(
packed_info
,
t_starts
,
t_ends
,
weights
)
)
else
:
weights
=
render_weight_from_alpha
(
alphas
,
ray_indices
=
ray_indices
)
packed_info
,
t_starts
,
t_ends
=
ray_resampling
(
packed_info
,
t_starts
,
t_ends
,
weights
,
n_samples
=
num_samples
)
ray_indices
=
unpack_info
(
packed_info
,
n_samples
=
t_starts
.
shape
[
0
])
ray_indices
,
t_starts
,
t_ends
,
_
=
maybe_filter
(
(
ray_indices
,
t_starts
,
t_ends
,
proposal_samples
,
)
=
proposal_sampling_with_filter
(
t_starts
=
t_starts
,
t_starts
=
t_starts
,
t_ends
=
t_ends
,
t_ends
=
t_ends
,
ray_indices
=
ray_indices
,
ray_indices
=
ray_indices
,
n_rays
=
n_rays
,
n_rays
=
n_rays
,
sigma_fn
=
sigma_fn
,
sigma_fn
=
sigma_fn
,
alpha_fn
=
alpha_fn
,
proposal_sigma_fns
=
proposal_sigma_fns
,
net
=
None
,
proposal_n_samples
=
proposal_n_samples
,
proposal_require_grads
=
proposal_require_grads
,
early_stop_eps
=
early_stop_eps
,
early_stop_eps
=
early_stop_eps
,
alpha_thre
=
alpha_thre
,
alpha_thre
=
alpha_thre
,
)
)
if
proposal_nets
is
not
None
:
return
ray_indices
,
t_starts
,
t_ends
,
proposal_samples
return
ray_indices
,
t_starts
,
t_ends
,
proposal_sample_list
else
:
return
ray_indices
,
t_starts
,
t_ends
nerfacc/sampling.py
0 → 100644
View file @
a74f7e7b
import
math
from
typing
import
Callable
,
Optional
,
Tuple
,
Union
,
overload
import
torch
import
nerfacc.cuda
as
_C
from
.cdf
import
ray_resampling
from
.grid
import
Grid
from
.pack
import
pack_info
,
unpack_info
from
.vol_rendering
import
(
render_transmittance_from_alpha
,
render_weight_from_density
,
)
@
overload
def
sample_along_rays
(
rays_o
:
torch
.
Tensor
,
# [n_rays, 3]
rays_d
:
torch
.
Tensor
,
# [n_rays, 3]
t_min
:
torch
.
Tensor
,
# [n_rays,]
t_max
:
torch
.
Tensor
,
# [n_rays,]
step_size
:
float
,
cone_angle
:
float
=
0.0
,
grid
:
Optional
[
Grid
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Sample along rays with per-ray min max."""
...
@
overload
def
sample_along_rays
(
rays_o
:
torch
.
Tensor
,
# [n_rays, 3]
rays_d
:
torch
.
Tensor
,
# [n_rays, 3]
t_min
:
float
,
t_max
:
float
,
step_size
:
float
,
cone_angle
:
float
=
0.0
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Sample along rays with near far plane."""
...
@
torch
.
no_grad
()
def
sample_along_rays
(
rays_o
:
torch
.
Tensor
,
# [n_rays, 3]
rays_d
:
torch
.
Tensor
,
# [n_rays, 3]
t_min
:
Union
[
float
,
torch
.
Tensor
],
# [n_rays,]
t_max
:
Union
[
float
,
torch
.
Tensor
],
# [n_rays,]
step_size
:
float
,
cone_angle
:
float
=
0.0
,
grid
:
Optional
[
Grid
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Sample intervals along rays."""
if
isinstance
(
t_min
,
float
)
and
isinstance
(
t_max
,
float
):
n_rays
=
rays_o
.
shape
[
0
]
device
=
rays_o
.
device
num_steps
=
math
.
floor
((
t_max
-
t_min
)
/
step_size
)
t_starts
=
(
(
t_min
+
torch
.
arange
(
0
,
num_steps
,
device
=
device
)
*
step_size
)
.
expand
(
n_rays
,
-
1
)
.
reshape
(
-
1
,
1
)
)
t_ends
=
t_starts
+
step_size
ray_indices
=
torch
.
arange
(
0
,
n_rays
,
device
=
device
).
repeat_interleave
(
num_steps
,
dim
=
0
)
else
:
if
grid
is
None
:
packed_info
,
ray_indices
,
t_starts
,
t_ends
=
_C
.
ray_marching
(
# rays
t_min
.
contiguous
(),
t_max
.
contiguous
(),
# sampling
step_size
,
cone_angle
,
)
else
:
(
packed_info
,
ray_indices
,
t_starts
,
t_ends
,
)
=
_C
.
ray_marching_with_grid
(
# rays
rays_o
.
contiguous
(),
rays_d
.
contiguous
(),
t_min
.
contiguous
(),
t_max
.
contiguous
(),
# coontraction and grid
grid
.
roi_aabb
.
contiguous
(),
grid
.
binary
.
contiguous
(),
grid
.
contraction_type
.
to_cpp_version
(),
# sampling
step_size
,
cone_angle
,
)
return
ray_indices
,
t_starts
,
t_ends
@
torch
.
no_grad
()
def
proposal_sampling_with_filter
(
t_starts
:
torch
.
Tensor
,
# [n_samples, 1]
t_ends
:
torch
.
Tensor
,
# [n_samples, 1]
ray_indices
:
torch
.
Tensor
,
# [n_samples,]
n_rays
:
Optional
[
int
]
=
None
,
# compute density of samples: {t_starts, t_ends, ray_indices} -> density
sigma_fn
:
Optional
[
Callable
]
=
None
,
# proposal density fns: {t_starts, t_ends, ray_indices} -> density
proposal_sigma_fns
:
Tuple
[
Callable
,
...]
=
[],
proposal_n_samples
:
Tuple
[
int
,
...]
=
[],
proposal_require_grads
:
bool
=
False
,
# acceleration options
early_stop_eps
:
float
=
1e-4
,
alpha_thre
:
float
=
0.0
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Hueristic marching with proposal fns."""
assert
len
(
proposal_sigma_fns
)
==
len
(
proposal_n_samples
),
(
"proposal_sigma_fns and proposal_n_samples must have the same length, "
f
"but got
{
len
(
proposal_sigma_fns
)
}
and
{
len
(
proposal_n_samples
)
}
."
)
if
n_rays
is
None
:
n_rays
=
ray_indices
.
max
()
+
1
# compute density from proposal fns
proposal_samples
=
[]
for
proposal_fn
,
n_samples
in
zip
(
proposal_sigma_fns
,
proposal_n_samples
):
# compute weights for resampling
sigmas
=
proposal_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
())
assert
(
sigmas
.
shape
==
t_starts
.
shape
),
"sigmas must have shape of (N, 1)! Got {}"
.
format
(
sigmas
.
shape
)
alphas
=
1.0
-
torch
.
exp
(
-
sigmas
*
(
t_ends
-
t_starts
))
transmittance
=
render_transmittance_from_alpha
(
alphas
,
ray_indices
=
ray_indices
,
n_rays
=
n_rays
)
weights
=
alphas
*
transmittance
# Compute visibility for filtering
if
alpha_thre
>
0
or
early_stop_eps
>
0
:
vis
=
(
alphas
>=
alpha_thre
)
&
(
transmittance
>=
early_stop_eps
)
vis
=
vis
.
squeeze
(
-
1
)
ray_indices
,
t_starts
,
t_ends
,
weights
=
(
ray_indices
[
vis
],
t_starts
[
vis
],
t_ends
[
vis
],
weights
[
vis
],
)
packed_info
=
pack_info
(
ray_indices
,
n_rays
=
n_rays
)
# Rerun the proposal function **with** gradients on filtered samples.
if
proposal_require_grads
:
with
torch
.
enable_grad
():
sigmas
=
proposal_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
())
weights
=
render_weight_from_density
(
t_starts
,
t_ends
,
sigmas
,
ray_indices
=
ray_indices
)
proposal_samples
.
append
(
(
packed_info
,
t_starts
,
t_ends
,
weights
)
)
# resampling on filtered samples
packed_info
,
t_starts
,
t_ends
=
ray_resampling
(
packed_info
,
t_starts
,
t_ends
,
weights
,
n_samples
=
n_samples
)
ray_indices
=
unpack_info
(
packed_info
,
t_starts
.
shape
[
0
])
# last round filtering with sigma_fn
if
(
alpha_thre
>
0
or
early_stop_eps
>
0
)
and
(
sigma_fn
is
not
None
):
sigmas
=
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
())
assert
(
sigmas
.
shape
==
t_starts
.
shape
),
"sigmas must have shape of (N, 1)! Got {}"
.
format
(
sigmas
.
shape
)
alphas
=
1.0
-
torch
.
exp
(
-
sigmas
*
(
t_ends
-
t_starts
))
transmittance
=
render_transmittance_from_alpha
(
alphas
,
ray_indices
=
ray_indices
,
n_rays
=
n_rays
)
vis
=
(
alphas
>=
alpha_thre
)
&
(
transmittance
>=
early_stop_eps
)
vis
=
vis
.
squeeze
(
-
1
)
ray_indices
,
t_starts
,
t_ends
=
(
ray_indices
[
vis
],
t_starts
[
vis
],
t_ends
[
vis
],
)
return
ray_indices
,
t_starts
,
t_ends
,
proposal_samples
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment