Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
b42d62cc
Commit
b42d62cc
authored
Nov 15, 2022
by
Ruilong Li
Browse files
fix tests; trainable; checking efficiency
parent
822d5199
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
61 additions
and
40 deletions
+61
-40
examples/train_ngp_nerf_proposal.py
examples/train_ngp_nerf_proposal.py
+8
-4
nerfacc/ray_marching.py
nerfacc/ray_marching.py
+49
-32
tests/test_pdf_query.py
tests/test_pdf_query.py
+1
-1
tests/test_resampling.py
tests/test_resampling.py
+3
-3
No files found.
examples/train_ngp_nerf_proposal.py
View file @
b42d62cc
...
@@ -28,6 +28,7 @@ def set_random_seed(seed):
...
@@ -28,6 +28,7 @@ def set_random_seed(seed):
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
# @profile
def
render_image
(
def
render_image
(
# scene
# scene
radiance_field
:
torch
.
nn
.
Module
,
radiance_field
:
torch
.
nn
.
Module
,
...
@@ -169,7 +170,7 @@ if __name__ == "__main__":
...
@@ -169,7 +170,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
parser
.
add_argument
(
"--test_chunk_size"
,
"--test_chunk_size"
,
type
=
int
,
type
=
int
,
default
=
8192
,
default
=
1024
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--unbounded"
,
"--unbounded"
,
...
@@ -264,8 +265,8 @@ if __name__ == "__main__":
...
@@ -264,8 +265,8 @@ if __name__ == "__main__":
hidden_dim
=
16
,
hidden_dim
=
16
,
max_res
=
64
,
max_res
=
64
,
geo_feat_dim
=
0
,
geo_feat_dim
=
0
,
n_levels
=
2
,
n_levels
=
4
,
log2_hashmap_size
=
1
7
,
log2_hashmap_size
=
1
9
,
),
),
# NGPradianceField(
# NGPradianceField(
# aabb=args.aabb,
# aabb=args.aabb,
...
@@ -303,6 +304,8 @@ if __name__ == "__main__":
...
@@ -303,6 +304,8 @@ if __name__ == "__main__":
for
epoch
in
range
(
10000000
):
for
epoch
in
range
(
10000000
):
for
i
in
range
(
len
(
train_dataset
)):
for
i
in
range
(
len
(
train_dataset
)):
radiance_field
.
train
()
radiance_field
.
train
()
proposal_nets
.
train
()
data
=
train_dataset
[
i
]
data
=
train_dataset
[
i
]
render_bkgd
=
data
[
"color_bkgd"
]
render_bkgd
=
data
[
"color_bkgd"
]
...
@@ -350,6 +353,7 @@ if __name__ == "__main__":
...
@@ -350,6 +353,7 @@ if __name__ == "__main__":
t_ends
,
t_ends
,
weights
,
weights
,
)
=
proposal_sample_list
[
-
1
]
)
=
proposal_sample_list
[
-
1
]
loss_interval
=
0.0
for
(
for
(
proposal_packed_info
,
proposal_packed_info
,
proposal_t_starts
,
proposal_t_starts
,
...
@@ -365,7 +369,6 @@ if __name__ == "__main__":
...
@@ -365,7 +369,6 @@ if __name__ == "__main__":
proposal_t_starts
,
proposal_t_starts
,
proposal_t_ends
,
proposal_t_ends
,
).
detach
()
).
detach
()
torch
.
cuda
.
synchronize
()
loss_interval
=
(
loss_interval
=
(
torch
.
clamp
(
proposal_weights_gt
-
proposal_weights
,
min
=
0
)
torch
.
clamp
(
proposal_weights_gt
-
proposal_weights
,
min
=
0
)
...
@@ -392,6 +395,7 @@ if __name__ == "__main__":
...
@@ -392,6 +395,7 @@ if __name__ == "__main__":
if
step
>=
0
and
step
%
1000
==
0
and
step
>
0
:
if
step
>=
0
and
step
%
1000
==
0
and
step
>
0
:
# evaluation
# evaluation
radiance_field
.
eval
()
radiance_field
.
eval
()
proposal_nets
.
eval
()
psnrs
=
[]
psnrs
=
[]
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
...
nerfacc/ray_marching.py
View file @
b42d62cc
...
@@ -8,11 +8,12 @@ from .cdf import ray_resampling
...
@@ -8,11 +8,12 @@ from .cdf import ray_resampling
from
.contraction
import
ContractionType
from
.contraction
import
ContractionType
from
.grid
import
Grid
from
.grid
import
Grid
from
.intersection
import
ray_aabb_intersect
from
.intersection
import
ray_aabb_intersect
from
.pack
import
unpack_info
from
.pack
import
pack_info
,
unpack_info
from
.vol_rendering
import
render_visibility
,
render_weight_from_density
from
.vol_rendering
import
render_visibility
,
render_weight_from_density
@
torch
.
no_grad
()
@
torch
.
no_grad
()
# @profile
def
ray_marching
(
def
ray_marching
(
# rays
# rays
rays_o
:
torch
.
Tensor
,
rays_o
:
torch
.
Tensor
,
...
@@ -192,15 +193,60 @@ def ray_marching(
...
@@ -192,15 +193,60 @@ def ray_marching(
cone_angle
,
cone_angle
,
)
)
proposal_sample_list
=
[]
if
proposal_nets
is
not
None
:
if
proposal_nets
is
not
None
:
proposal_sample_list
=
[]
# resample with proposal nets
# resample with proposal nets
for
net
,
num_samples
in
zip
(
proposal_nets
,
[
32
]):
for
net
,
num_samples
in
zip
(
proposal_nets
,
[
64
]):
with
torch
.
no_grad
():
# skip invisible space
if
sigma_fn
is
not
None
or
alpha_fn
is
not
None
:
# Query sigma without gradients
if
sigma_fn
is
not
None
:
sigmas
=
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
(),
net
=
net
)
assert
(
sigmas
.
shape
==
t_starts
.
shape
),
"sigmas must have shape of (N, 1)! Got {}"
.
format
(
sigmas
.
shape
)
alphas
=
1.0
-
torch
.
exp
(
-
sigmas
*
(
t_ends
-
t_starts
))
elif
alpha_fn
is
not
None
:
alphas
=
alpha_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
(),
net
=
net
)
assert
(
alphas
.
shape
==
t_starts
.
shape
),
"alphas must have shape of (N, 1)! Got {}"
.
format
(
alphas
.
shape
)
# Compute visibility of the samples, and filter out invisible samples
masks
=
render_visibility
(
alphas
,
ray_indices
=
ray_indices
,
early_stop_eps
=
early_stop_eps
,
alpha_thre
=
min
(
alphas
.
mean
().
item
(),
1e-1
),
n_rays
=
rays_o
.
shape
[
0
],
)
ray_indices
,
t_starts
,
t_ends
=
(
ray_indices
[
masks
],
t_starts
[
masks
],
t_ends
[
masks
],
)
# print(
# alphas.shape,
# masks.float().sum(),
# alphas.min(),
# alphas.max(),
# )
with
torch
.
enable_grad
():
with
torch
.
enable_grad
():
sigmas
=
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
(),
net
=
net
)
sigmas
=
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
(),
net
=
net
)
weights
=
render_weight_from_density
(
weights
=
render_weight_from_density
(
t_starts
,
t_ends
,
sigmas
,
ray_indices
=
ray_indices
t_starts
,
t_ends
,
sigmas
,
ray_indices
=
ray_indices
)
)
packed_info
=
pack_info
(
ray_indices
,
n_rays
=
rays_o
.
shape
[
0
])
proposal_sample_list
.
append
(
proposal_sample_list
.
append
(
(
packed_info
,
t_starts
,
t_ends
,
weights
)
(
packed_info
,
t_starts
,
t_ends
,
weights
)
)
)
...
@@ -209,35 +255,6 @@ def ray_marching(
...
@@ -209,35 +255,6 @@ def ray_marching(
)
)
ray_indices
=
unpack_info
(
packed_info
,
n_samples
=
t_starts
.
shape
[
0
])
ray_indices
=
unpack_info
(
packed_info
,
n_samples
=
t_starts
.
shape
[
0
])
# skip invisible space
if
sigma_fn
is
not
None
or
alpha_fn
is
not
None
:
# Query sigma without gradients
if
sigma_fn
is
not
None
:
sigmas
=
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
())
assert
(
sigmas
.
shape
==
t_starts
.
shape
),
"sigmas must have shape of (N, 1)! Got {}"
.
format
(
sigmas
.
shape
)
alphas
=
1.0
-
torch
.
exp
(
-
sigmas
*
(
t_ends
-
t_starts
))
elif
alpha_fn
is
not
None
:
alphas
=
alpha_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
())
assert
(
alphas
.
shape
==
t_starts
.
shape
),
"alphas must have shape of (N, 1)! Got {}"
.
format
(
alphas
.
shape
)
# Compute visibility of the samples, and filter out invisible samples
masks
=
render_visibility
(
alphas
,
ray_indices
=
ray_indices
,
early_stop_eps
=
early_stop_eps
,
alpha_thre
=
alpha_thre
,
n_rays
=
rays_o
.
shape
[
0
],
)
ray_indices
,
t_starts
,
t_ends
=
(
ray_indices
[
masks
],
t_starts
[
masks
],
t_ends
[
masks
],
)
if
proposal_nets
is
not
None
:
if
proposal_nets
is
not
None
:
return
ray_indices
,
t_starts
,
t_ends
,
proposal_sample_list
return
ray_indices
,
t_starts
,
t_ends
,
proposal_sample_list
else
:
else
:
...
...
tests/test_pdf_query.py
View file @
b42d62cc
...
@@ -50,7 +50,7 @@ def test_pdf_query():
...
@@ -50,7 +50,7 @@ def test_pdf_query():
render_step_size
=
0.2
,
render_step_size
=
0.2
,
)
)
packed_info
=
pack_info
(
ray_indices
,
n_rays
)
packed_info
=
pack_info
(
ray_indices
,
n_rays
)
weights
=
torch
.
rand
((
t_starts
.
shape
[
0
],),
device
=
device
)
weights
=
torch
.
rand
((
t_starts
.
shape
[
0
],
1
),
device
=
device
)
packed_info_new
=
packed_info
packed_info_new
=
packed_info
t_starts_new
=
t_starts
-
0.3
t_starts_new
=
t_starts
-
0.3
...
...
tests/test_resampling.py
View file @
b42d62cc
...
@@ -105,8 +105,8 @@ def test_resampling():
...
@@ -105,8 +105,8 @@ def test_resampling():
t_starts
=
t
[:,
:
-
1
][
masks
].
unsqueeze
(
-
1
)
t_starts
=
t
[:,
:
-
1
][
masks
].
unsqueeze
(
-
1
)
t_ends
=
t
[:,
1
:][
masks
].
unsqueeze
(
-
1
)
t_ends
=
t
[:,
1
:][
masks
].
unsqueeze
(
-
1
)
w_logits
=
w_logits
[
masks
]
w_logits
=
w_logits
[
masks
]
.
unsqueeze
(
-
1
)
w
=
w
[
masks
]
w
=
w
[
masks
]
.
unsqueeze
(
-
1
)
num_steps
=
masks
.
long
().
sum
(
dim
=-
1
)
num_steps
=
masks
.
long
().
sum
(
dim
=-
1
)
cum_steps
=
torch
.
cumsum
(
num_steps
,
dim
=
0
)
cum_steps
=
torch
.
cumsum
(
num_steps
,
dim
=
0
)
packed_info
=
torch
.
stack
([
cum_steps
-
num_steps
,
num_steps
],
dim
=-
1
).
int
()
packed_info
=
torch
.
stack
([
cum_steps
-
num_steps
,
num_steps
],
dim
=-
1
).
int
()
...
@@ -143,7 +143,7 @@ def test_pdf_query():
...
@@ -143,7 +143,7 @@ def test_pdf_query():
)
)
packed_info
=
pack_info
(
ray_indices
,
rays_o
.
shape
[
0
])
packed_info
=
pack_info
(
ray_indices
,
rays_o
.
shape
[
0
])
weights
=
torch
.
rand
((
t_starts
.
shape
[
0
],),
device
=
device
)
weights
=
torch
.
rand
((
t_starts
.
shape
[
0
],
1
),
device
=
device
)
weights_new
=
ray_pdf_query
(
weights_new
=
ray_pdf_query
(
packed_info
,
packed_info
,
t_starts
,
t_starts
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment