Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
3eef9208
Commit
3eef9208
authored
Jan 02, 2024
by
Ruilong Li
Browse files
format
parent
88a6aec6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
59 additions
and
20 deletions
+59
-20
examples/datasets/nerf_360_v2.py
examples/datasets/nerf_360_v2.py
+13
-3
examples/datasets/nerf_synthetic.py
examples/datasets/nerf_synthetic.py
+13
-3
examples/train_ngp_nerf_occ.py
examples/train_ngp_nerf_occ.py
+9
-3
nerfacc/estimators/n3tree.py
nerfacc/estimators/n3tree.py
+24
-11
No files found.
examples/datasets/nerf_360_v2.py
View file @
3eef9208
...
@@ -276,7 +276,9 @@ class SubjectLoader(torch.utils.data.Dataset):
...
@@ -276,7 +276,9 @@ class SubjectLoader(torch.utils.data.Dataset):
if
self
.
training
:
if
self
.
training
:
if
self
.
color_bkgd_aug
==
"random"
:
if
self
.
color_bkgd_aug
==
"random"
:
color_bkgd
=
torch
.
rand
(
3
,
device
=
self
.
images
.
device
,
generator
=
self
.
g
)
color_bkgd
=
torch
.
rand
(
3
,
device
=
self
.
images
.
device
,
generator
=
self
.
g
)
elif
self
.
color_bkgd_aug
==
"white"
:
elif
self
.
color_bkgd_aug
==
"white"
:
color_bkgd
=
torch
.
ones
(
3
,
device
=
self
.
images
.
device
)
color_bkgd
=
torch
.
ones
(
3
,
device
=
self
.
images
.
device
)
elif
self
.
color_bkgd_aug
==
"black"
:
elif
self
.
color_bkgd_aug
==
"black"
:
...
@@ -311,10 +313,18 @@ class SubjectLoader(torch.utils.data.Dataset):
...
@@ -311,10 +313,18 @@ class SubjectLoader(torch.utils.data.Dataset):
else
:
else
:
image_id
=
[
index
]
*
num_rays
image_id
=
[
index
]
*
num_rays
x
=
torch
.
randint
(
x
=
torch
.
randint
(
0
,
self
.
width
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
0
,
self
.
width
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
,
)
)
y
=
torch
.
randint
(
y
=
torch
.
randint
(
0
,
self
.
height
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
0
,
self
.
height
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
,
)
)
else
:
else
:
image_id
=
[
index
]
image_id
=
[
index
]
...
...
examples/datasets/nerf_synthetic.py
View file @
3eef9208
...
@@ -143,7 +143,9 @@ class SubjectLoader(torch.utils.data.Dataset):
...
@@ -143,7 +143,9 @@ class SubjectLoader(torch.utils.data.Dataset):
if
self
.
training
:
if
self
.
training
:
if
self
.
color_bkgd_aug
==
"random"
:
if
self
.
color_bkgd_aug
==
"random"
:
color_bkgd
=
torch
.
rand
(
3
,
device
=
self
.
images
.
device
,
generator
=
self
.
g
)
color_bkgd
=
torch
.
rand
(
3
,
device
=
self
.
images
.
device
,
generator
=
self
.
g
)
elif
self
.
color_bkgd_aug
==
"white"
:
elif
self
.
color_bkgd_aug
==
"white"
:
color_bkgd
=
torch
.
ones
(
3
,
device
=
self
.
images
.
device
)
color_bkgd
=
torch
.
ones
(
3
,
device
=
self
.
images
.
device
)
elif
self
.
color_bkgd_aug
==
"black"
:
elif
self
.
color_bkgd_aug
==
"black"
:
...
@@ -179,10 +181,18 @@ class SubjectLoader(torch.utils.data.Dataset):
...
@@ -179,10 +181,18 @@ class SubjectLoader(torch.utils.data.Dataset):
else
:
else
:
image_id
=
[
index
]
*
num_rays
image_id
=
[
index
]
*
num_rays
x
=
torch
.
randint
(
x
=
torch
.
randint
(
0
,
self
.
WIDTH
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
0
,
self
.
WIDTH
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
,
)
)
y
=
torch
.
randint
(
y
=
torch
.
randint
(
0
,
self
.
HEIGHT
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
0
,
self
.
HEIGHT
,
size
=
(
num_rays
,),
device
=
self
.
images
.
device
,
generator
=
self
.
g
,
)
)
else
:
else
:
image_id
=
[
index
]
image_id
=
[
index
]
...
...
examples/train_ngp_nerf_occ.py
View file @
3eef9208
...
@@ -24,6 +24,7 @@ from examples.utils import (
...
@@ -24,6 +24,7 @@ from examples.utils import (
)
)
from
nerfacc.estimators.occ_grid
import
OccGridEstimator
from
nerfacc.estimators.occ_grid
import
OccGridEstimator
def
run
(
args
):
def
run
(
args
):
device
=
"cuda:0"
device
=
"cuda:0"
set_random_seed
(
42
)
set_random_seed
(
42
)
...
@@ -102,7 +103,10 @@ def run(args):
...
@@ -102,7 +103,10 @@ def run(args):
grad_scaler
=
torch
.
cuda
.
amp
.
GradScaler
(
2
**
10
)
grad_scaler
=
torch
.
cuda
.
amp
.
GradScaler
(
2
**
10
)
radiance_field
=
NGPRadianceField
(
aabb
=
estimator
.
aabbs
[
-
1
]).
to
(
device
)
radiance_field
=
NGPRadianceField
(
aabb
=
estimator
.
aabbs
[
-
1
]).
to
(
device
)
optimizer
=
torch
.
optim
.
Adam
(
optimizer
=
torch
.
optim
.
Adam
(
radiance_field
.
parameters
(),
lr
=
1e-2
,
eps
=
1e-15
,
weight_decay
=
weight_decay
radiance_field
.
parameters
(),
lr
=
1e-2
,
eps
=
1e-15
,
weight_decay
=
weight_decay
,
)
)
scheduler
=
torch
.
optim
.
lr_scheduler
.
ChainedScheduler
(
scheduler
=
torch
.
optim
.
lr_scheduler
.
ChainedScheduler
(
[
[
...
@@ -167,7 +171,8 @@ def run(args):
...
@@ -167,7 +171,8 @@ def run(args):
# dynamic batch size for rays to keep sample batch size constant.
# dynamic batch size for rays to keep sample batch size constant.
num_rays
=
len
(
pixels
)
num_rays
=
len
(
pixels
)
num_rays
=
int
(
num_rays
=
int
(
num_rays
*
(
target_sample_batch_size
/
float
(
n_rendering_samples
))
num_rays
*
(
target_sample_batch_size
/
float
(
n_rendering_samples
))
)
)
train_dataset
.
update_num_rays
(
num_rays
)
train_dataset
.
update_num_rays
(
num_rays
)
...
@@ -249,6 +254,7 @@ def run(args):
...
@@ -249,6 +254,7 @@ def run(args):
lpips_avg
=
sum
(
lpips
)
/
len
(
lpips
)
lpips_avg
=
sum
(
lpips
)
/
len
(
lpips
)
print
(
f
"evaluation: psnr_avg=
{
psnr_avg
}
, lpips_avg=
{
lpips_avg
}
"
)
print
(
f
"evaluation: psnr_avg=
{
psnr_avg
}
, lpips_avg=
{
lpips_avg
}
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -274,4 +280,4 @@ if __name__ == "__main__":
...
@@ -274,4 +280,4 @@ if __name__ == "__main__":
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
run
(
args
)
run
(
args
)
\ No newline at end of file
nerfacc/estimators/n3tree.py
View file @
3eef9208
...
@@ -2,13 +2,14 @@ import math
...
@@ -2,13 +2,14 @@ import math
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
Tensor
from
..grid
import
_enlarge_aabb
from
..grid
import
_enlarge_aabb
from
..volrend
import
(
from
..volrend
import
(
render_visibility_from_alpha
,
render_visibility_from_alpha
,
render_visibility_from_density
,
render_visibility_from_density
,
)
)
from
.base
import
AbstractEstimator
from
.base
import
AbstractEstimator
from
torch
import
Tensor
try
:
try
:
import
svox
import
svox
...
@@ -21,7 +22,7 @@ except ImportError:
...
@@ -21,7 +22,7 @@ except ImportError:
class
N3TreeEstimator
(
AbstractEstimator
):
class
N3TreeEstimator
(
AbstractEstimator
):
"""Use N3Tree to implement Occupancy Grid.
"""Use N3Tree to implement Occupancy Grid.
This allows more flexible topologies than the cascaded grid. However, it is
This allows more flexible topologies than the cascaded grid. However, it is
slower to create samples from the tree than the cascaded grid. By default,
slower to create samples from the tree than the cascaded grid. By default,
it has the same topology as the cascaded grid but `self.tree` can be
it has the same topology as the cascaded grid but `self.tree` can be
...
@@ -43,7 +44,9 @@ class N3TreeEstimator(AbstractEstimator):
...
@@ -43,7 +44,9 @@ class N3TreeEstimator(AbstractEstimator):
)
)
# check the resolution is legal
# check the resolution is legal
assert
isinstance
(
resolution
,
int
),
"N3Tree only supports uniform resolution!"
assert
isinstance
(
resolution
,
int
),
"N3Tree only supports uniform resolution!"
# check the roi_aabb is legal
# check the roi_aabb is legal
if
isinstance
(
roi_aabb
,
(
list
,
tuple
)):
if
isinstance
(
roi_aabb
,
(
list
,
tuple
)):
...
@@ -148,16 +151,18 @@ class N3TreeEstimator(AbstractEstimator):
...
@@ -148,16 +151,18 @@ class N3TreeEstimator(AbstractEstimator):
"""
"""
assert
t_min
is
None
and
t_max
is
None
,
(
assert
(
"Do not supported per-ray min max. Please use near_plane and far_plane instead."
t_min
is
None
and
t_max
is
None
)
)
,
"Do not supported per-ray min max. Please use near_plane and far_plane instead."
if
stratified
:
if
stratified
:
near_plane
+=
torch
.
rand
(()).
item
()
*
render_step_size
near_plane
+=
torch
.
rand
(()).
item
()
*
render_step_size
t_starts
,
t_ends
,
packed_info
,
ray_indices
=
svox
.
volume_sample
(
t_starts
,
t_ends
,
packed_info
,
ray_indices
=
svox
.
volume_sample
(
self
.
tree
,
self
.
tree
,
thresh
=
self
.
thresh
,
thresh
=
self
.
thresh
,
rays
=
svox
.
Rays
(
rays_o
.
contiguous
(),
rays_d
.
contiguous
(),
rays_d
.
contiguous
()),
rays
=
svox
.
Rays
(
rays_o
.
contiguous
(),
rays_d
.
contiguous
(),
rays_d
.
contiguous
()
),
step_size
=
render_step_size
,
step_size
=
render_step_size
,
cone_angle
=
cone_angle
,
cone_angle
=
cone_angle
,
near_plane
=
near_plane
,
near_plane
=
near_plane
,
...
@@ -253,10 +258,16 @@ class N3TreeEstimator(AbstractEstimator):
...
@@ -253,10 +258,16 @@ class N3TreeEstimator(AbstractEstimator):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
_sample_uniform_and_occupied_cells
(
self
,
n
:
int
)
->
List
[
Tensor
]:
def
_sample_uniform_and_occupied_cells
(
self
,
n
:
int
)
->
List
[
Tensor
]:
"""Samples both n uniform and occupied cells."""
"""Samples both n uniform and occupied cells."""
uniform_indices
=
torch
.
randint
(
len
(
self
.
tree
),
(
n
,),
device
=
self
.
device
)
uniform_indices
=
torch
.
randint
(
occupied_indices
=
torch
.
nonzero
(
self
.
tree
[:].
values
>=
self
.
thresh
)[:,
0
]
len
(
self
.
tree
),
(
n
,),
device
=
self
.
device
)
occupied_indices
=
torch
.
nonzero
(
self
.
tree
[:].
values
>=
self
.
thresh
)[
:,
0
]
if
n
<
len
(
occupied_indices
):
if
n
<
len
(
occupied_indices
):
selector
=
torch
.
randint
(
len
(
occupied_indices
),
(
n
,),
device
=
self
.
device
)
selector
=
torch
.
randint
(
len
(
occupied_indices
),
(
n
,),
device
=
self
.
device
)
occupied_indices
=
occupied_indices
[
selector
]
occupied_indices
=
occupied_indices
[
selector
]
indices
=
torch
.
cat
([
uniform_indices
,
occupied_indices
],
dim
=
0
)
indices
=
torch
.
cat
([
uniform_indices
,
occupied_indices
],
dim
=
0
)
return
indices
return
indices
...
@@ -275,7 +286,9 @@ class N3TreeEstimator(AbstractEstimator):
...
@@ -275,7 +286,9 @@ class N3TreeEstimator(AbstractEstimator):
x
=
self
.
tree
.
sample
(
1
).
squeeze
(
1
)
x
=
self
.
tree
.
sample
(
1
).
squeeze
(
1
)
occ
=
occ_eval_fn
(
x
).
squeeze
(
-
1
)
occ
=
occ_eval_fn
(
x
).
squeeze
(
-
1
)
sel
=
(
*
self
.
tree
.
_all_leaves
().
T
,)
sel
=
(
*
self
.
tree
.
_all_leaves
().
T
,)
self
.
tree
.
data
.
data
[
sel
]
=
torch
.
maximum
(
self
.
tree
.
data
.
data
[
sel
]
*
ema_decay
,
occ
[:,
None
])
self
.
tree
.
data
.
data
[
sel
]
=
torch
.
maximum
(
self
.
tree
.
data
.
data
[
sel
]
*
ema_decay
,
occ
[:,
None
]
)
else
:
else
:
N
=
len
(
self
.
tree
)
//
4
N
=
len
(
self
.
tree
)
//
4
indices
=
self
.
_sample_uniform_and_occupied_cells
(
N
)
indices
=
self
.
_sample_uniform_and_occupied_cells
(
N
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment