Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
3eef9208
Commit
3eef9208
authored
Jan 02, 2024
by
Ruilong Li
Browse files
format
parent
88a6aec6
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
59 additions
and
20 deletions
+59
-20
examples/datasets/nerf_360_v2.py
examples/datasets/nerf_360_v2.py
+13
-3
examples/datasets/nerf_synthetic.py
examples/datasets/nerf_synthetic.py
+13
-3
examples/train_ngp_nerf_occ.py
examples/train_ngp_nerf_occ.py
+9
-3
nerfacc/estimators/n3tree.py
nerfacc/estimators/n3tree.py
+24
-11
No files found.
examples/datasets/nerf_360_v2.py
View file @
3eef9208
...
@@ -276,7 +276,9 @@ class SubjectLoader(torch.utils.data.Dataset):
...
@@ -276,7 +276,9 @@ class SubjectLoader(torch.utils.data.Dataset):
if
self
.
training
:
if
self
.
training
:
if
self
.
color_bkgd_aug
==
"random"
:
if
self
.
color_bkgd_aug
==
"random"
:
color_bkgd
=
torch
.
rand
(
3
,
device
=
self
.
images
.
device
,
generator
=
self
.
g
)
color_bkgd
=
torch
.
rand
(
3
,
device
=
self
.
images
.
device
,
generator
=
self
.
g
)
elif
self
.
color_bkgd_aug
==
"white"
:
elif
self
.
color_bkgd_aug
==
"white"
:
color_bkgd
=
torch
.
ones
(
3
,
device
=
self
.
images
.
device
)
color_bkgd
=
torch
.
ones
(
3
,
device
=
self
.
images
.
device
)
elif
self
.
color_bkgd_aug
==
"black"
:
elif
self
.
color_bkgd_aug
==
"black"
:
...
@@ -311,10 +313,18 @@ class SubjectLoader(torch.utils.data.Dataset):
...
@@ -311,10 +313,18 @@ class SubjectLoader(torch.utils.data.Dataset):
else
:
else
:
image_id
=
[
index
]
*
num_rays
image_id
=
[
index
]
*
num_rays
x
=
torch
.
randint
(
x
=
torch
.
randint
(
0
,
self
.
width
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
0
,
self
.
width
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
,
)
)
y
=
torch
.
randint
(
y
=
torch
.
randint
(
0
,
self
.
height
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
0
,
self
.
height
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
,
)
)
else
:
else
:
image_id
=
[
index
]
image_id
=
[
index
]
...
...
examples/datasets/nerf_synthetic.py
View file @
3eef9208
...
@@ -143,7 +143,9 @@ class SubjectLoader(torch.utils.data.Dataset):
...
@@ -143,7 +143,9 @@ class SubjectLoader(torch.utils.data.Dataset):
if
self
.
training
:
if
self
.
training
:
if
self
.
color_bkgd_aug
==
"random"
:
if
self
.
color_bkgd_aug
==
"random"
:
color_bkgd
=
torch
.
rand
(
3
,
device
=
self
.
images
.
device
,
generator
=
self
.
g
)
color_bkgd
=
torch
.
rand
(
3
,
device
=
self
.
images
.
device
,
generator
=
self
.
g
)
elif
self
.
color_bkgd_aug
==
"white"
:
elif
self
.
color_bkgd_aug
==
"white"
:
color_bkgd
=
torch
.
ones
(
3
,
device
=
self
.
images
.
device
)
color_bkgd
=
torch
.
ones
(
3
,
device
=
self
.
images
.
device
)
elif
self
.
color_bkgd_aug
==
"black"
:
elif
self
.
color_bkgd_aug
==
"black"
:
...
@@ -179,10 +181,18 @@ class SubjectLoader(torch.utils.data.Dataset):
...
@@ -179,10 +181,18 @@ class SubjectLoader(torch.utils.data.Dataset):
else
:
else
:
image_id
=
[
index
]
*
num_rays
image_id
=
[
index
]
*
num_rays
x
=
torch
.
randint
(
x
=
torch
.
randint
(
0
,
self
.
WIDTH
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
0
,
self
.
WIDTH
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
,
)
)
y
=
torch
.
randint
(
y
=
torch
.
randint
(
0
,
self
.
HEIGHT
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
0
,
self
.
HEIGHT
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
,
)
)
else
:
else
:
image_id
=
[
index
]
image_id
=
[
index
]
...
...
examples/train_ngp_nerf_occ.py
View file @
3eef9208
...
@@ -24,6 +24,7 @@ from examples.utils import (
...
@@ -24,6 +24,7 @@ from examples.utils import (
)
)
from
nerfacc.estimators.occ_grid
import
OccGridEstimator
from
nerfacc.estimators.occ_grid
import
OccGridEstimator
def
run
(
args
):
def
run
(
args
):
device
=
"cuda:0"
device
=
"cuda:0"
set_random_seed
(
42
)
set_random_seed
(
42
)
...
@@ -102,7 +103,10 @@ def run(args):
...
@@ -102,7 +103,10 @@ def run(args):
grad_scaler
=
torch
.
cuda
.
amp
.
GradScaler
(
2
**
10
)
grad_scaler
=
torch
.
cuda
.
amp
.
GradScaler
(
2
**
10
)
radiance_field
=
NGPRadianceField
(
aabb
=
estimator
.
aabbs
[
-
1
]).
to
(
device
)
radiance_field
=
NGPRadianceField
(
aabb
=
estimator
.
aabbs
[
-
1
]).
to
(
device
)
optimizer
=
torch
.
optim
.
Adam
(
optimizer
=
torch
.
optim
.
Adam
(
radiance_field
.
parameters
(),
lr
=
1e-2
,
eps
=
1e-15
,
weight_decay
=
weight_decay
radiance_field
.
parameters
(),
lr
=
1e-2
,
eps
=
1e-15
,
weight_decay
=
weight_decay
,
)
)
scheduler
=
torch
.
optim
.
lr_scheduler
.
ChainedScheduler
(
scheduler
=
torch
.
optim
.
lr_scheduler
.
ChainedScheduler
(
[
[
...
@@ -167,7 +171,8 @@ def run(args):
...
@@ -167,7 +171,8 @@ def run(args):
# dynamic batch size for rays to keep sample batch size constant.
# dynamic batch size for rays to keep sample batch size constant.
num_rays
=
len
(
pixels
)
num_rays
=
len
(
pixels
)
num_rays
=
int
(
num_rays
=
int
(
num_rays
*
(
target_sample_batch_size
/
float
(
n_rendering_samples
))
num_rays
*
(
target_sample_batch_size
/
float
(
n_rendering_samples
))
)
)
train_dataset
.
update_num_rays
(
num_rays
)
train_dataset
.
update_num_rays
(
num_rays
)
...
@@ -249,6 +254,7 @@ def run(args):
...
@@ -249,6 +254,7 @@ def run(args):
lpips_avg
=
sum
(
lpips
)
/
len
(
lpips
)
lpips_avg
=
sum
(
lpips
)
/
len
(
lpips
)
print
(
f
"evaluation: psnr_avg=
{
psnr_avg
}
, lpips_avg=
{
lpips_avg
}
"
)
print
(
f
"evaluation: psnr_avg=
{
psnr_avg
}
, lpips_avg=
{
lpips_avg
}
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
parser
.
add_argument
(
...
...
nerfacc/estimators/n3tree.py
View file @
3eef9208
...
@@ -2,13 +2,14 @@ import math
...
@@ -2,13 +2,14 @@ import math
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
Tensor
from
..grid
import
_enlarge_aabb
from
..grid
import
_enlarge_aabb
from
..volrend
import
(
from
..volrend
import
(
render_visibility_from_alpha
,
render_visibility_from_alpha
,
render_visibility_from_density
,
render_visibility_from_density
,
)
)
from
.base
import
AbstractEstimator
from
.base
import
AbstractEstimator
from
torch
import
Tensor
try
:
try
:
import
svox
import
svox
...
@@ -43,7 +44,9 @@ class N3TreeEstimator(AbstractEstimator):
...
@@ -43,7 +44,9 @@ class N3TreeEstimator(AbstractEstimator):
)
)
# check the resolution is legal
# check the resolution is legal
assert
isinstance
(
resolution
,
int
),
"N3Tree only supports uniform resolution!"
assert
isinstance
(
resolution
,
int
),
"N3Tree only supports uniform resolution!"
# check the roi_aabb is legal
# check the roi_aabb is legal
if
isinstance
(
roi_aabb
,
(
list
,
tuple
)):
if
isinstance
(
roi_aabb
,
(
list
,
tuple
)):
...
@@ -148,16 +151,18 @@ class N3TreeEstimator(AbstractEstimator):
...
@@ -148,16 +151,18 @@ class N3TreeEstimator(AbstractEstimator):
"""
"""
assert
t_min
is
None
and
t_max
is
None
,
(
assert
(
"Do not supported per-ray min max. Please use near_plane and far_plane instead."
t_min
is
None
and
t_max
is
None
)
)
,
"Do not supported per-ray min max. Please use near_plane and far_plane instead."
if
stratified
:
if
stratified
:
near_plane
+=
torch
.
rand
(()).
item
()
*
render_step_size
near_plane
+=
torch
.
rand
(()).
item
()
*
render_step_size
t_starts
,
t_ends
,
packed_info
,
ray_indices
=
svox
.
volume_sample
(
t_starts
,
t_ends
,
packed_info
,
ray_indices
=
svox
.
volume_sample
(
self
.
tree
,
self
.
tree
,
thresh
=
self
.
thresh
,
thresh
=
self
.
thresh
,
rays
=
svox
.
Rays
(
rays_o
.
contiguous
(),
rays_d
.
contiguous
(),
rays_d
.
contiguous
()),
rays
=
svox
.
Rays
(
rays_o
.
contiguous
(),
rays_d
.
contiguous
(),
rays_d
.
contiguous
()
),
step_size
=
render_step_size
,
step_size
=
render_step_size
,
cone_angle
=
cone_angle
,
cone_angle
=
cone_angle
,
near_plane
=
near_plane
,
near_plane
=
near_plane
,
...
@@ -253,10 +258,16 @@ class N3TreeEstimator(AbstractEstimator):
...
@@ -253,10 +258,16 @@ class N3TreeEstimator(AbstractEstimator):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
_sample_uniform_and_occupied_cells
(
self
,
n
:
int
)
->
List
[
Tensor
]:
def
_sample_uniform_and_occupied_cells
(
self
,
n
:
int
)
->
List
[
Tensor
]:
"""Samples both n uniform and occupied cells."""
"""Samples both n uniform and occupied cells."""
uniform_indices
=
torch
.
randint
(
len
(
self
.
tree
),
(
n
,),
device
=
self
.
device
)
uniform_indices
=
torch
.
randint
(
occupied_indices
=
torch
.
nonzero
(
self
.
tree
[:].
values
>=
self
.
thresh
)[:,
0
]
len
(
self
.
tree
),
(
n
,),
device
=
self
.
device
)
occupied_indices
=
torch
.
nonzero
(
self
.
tree
[:].
values
>=
self
.
thresh
)[
:,
0
]
if
n
<
len
(
occupied_indices
):
if
n
<
len
(
occupied_indices
):
selector
=
torch
.
randint
(
len
(
occupied_indices
),
(
n
,),
device
=
self
.
device
)
selector
=
torch
.
randint
(
len
(
occupied_indices
),
(
n
,),
device
=
self
.
device
)
occupied_indices
=
occupied_indices
[
selector
]
occupied_indices
=
occupied_indices
[
selector
]
indices
=
torch
.
cat
([
uniform_indices
,
occupied_indices
],
dim
=
0
)
indices
=
torch
.
cat
([
uniform_indices
,
occupied_indices
],
dim
=
0
)
return
indices
return
indices
...
@@ -275,7 +286,9 @@ class N3TreeEstimator(AbstractEstimator):
...
@@ -275,7 +286,9 @@ class N3TreeEstimator(AbstractEstimator):
x
=
self
.
tree
.
sample
(
1
).
squeeze
(
1
)
x
=
self
.
tree
.
sample
(
1
).
squeeze
(
1
)
occ
=
occ_eval_fn
(
x
).
squeeze
(
-
1
)
occ
=
occ_eval_fn
(
x
).
squeeze
(
-
1
)
sel
=
(
*
self
.
tree
.
_all_leaves
().
T
,)
sel
=
(
*
self
.
tree
.
_all_leaves
().
T
,)
self
.
tree
.
data
.
data
[
sel
]
=
torch
.
maximum
(
self
.
tree
.
data
.
data
[
sel
]
*
ema_decay
,
occ
[:,
None
])
self
.
tree
.
data
.
data
[
sel
]
=
torch
.
maximum
(
self
.
tree
.
data
.
data
[
sel
]
*
ema_decay
,
occ
[:,
None
]
)
else
:
else
:
N
=
len
(
self
.
tree
)
//
4
N
=
len
(
self
.
tree
)
//
4
indices
=
self
.
_sample_uniform_and_occupied_cells
(
N
)
indices
=
self
.
_sample_uniform_and_occupied_cells
(
N
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment