Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
9f90842b
Commit
9f90842b
authored
Dec 15, 2023
by
Ruilong Li
Browse files
cleanup training code
parent
e6647a00
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
267 additions
and
248 deletions
+267
-248
examples/train_ngp_nerf_occ.py
examples/train_ngp_nerf_occ.py
+228
-213
examples/utils.py
examples/utils.py
+39
-35
No files found.
examples/train_ngp_nerf_occ.py
View file @
9f90842b
...
...
@@ -24,239 +24,254 @@ from examples.utils import (
)
from
nerfacc.estimators.occ_grid
import
OccGridEstimator
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--data_root"
,
type
=
str
,
# default=str(pathlib.Path.cwd() / "data/360_v2"),
default
=
str
(
pathlib
.
Path
.
cwd
()
/
"data/nerf_synthetic"
),
help
=
"the root dir of the dataset"
,
)
parser
.
add_argument
(
"--train_split"
,
type
=
str
,
default
=
"train"
,
choices
=
[
"train"
,
"trainval"
],
help
=
"which train split to use"
,
)
parser
.
add_argument
(
"--scene"
,
type
=
str
,
default
=
"lego"
,
choices
=
NERF_SYNTHETIC_SCENES
+
MIPNERF360_UNBOUNDED_SCENES
,
help
=
"which scene to use"
,
)
args
=
parser
.
parse_args
()
def
run
(
args
):
device
=
"cuda:0"
set_random_seed
(
42
)
device
=
"cuda:0"
set_random_seed
(
42
)
if
args
.
scene
in
MIPNERF360_UNBOUNDED_SCENES
:
from
datasets.nerf_360_v2
import
SubjectLoader
if
args
.
scene
in
MIPNERF360_UNBOUNDED_SCENES
:
from
datasets.nerf_360_v2
import
SubjectLoader
# training parameters
max_steps
=
20000
init_batch_size
=
1024
target_sample_batch_size
=
1
<<
18
weight_decay
=
0.0
# scene parameters
aabb
=
torch
.
tensor
([
-
1.0
,
-
1.0
,
-
1.0
,
1.0
,
1.0
,
1.0
],
device
=
device
)
near_plane
=
0.2
far_plane
=
1.0e10
# dataset parameters
train_dataset_kwargs
=
{
"color_bkgd_aug"
:
"random"
,
"factor"
:
4
}
test_dataset_kwargs
=
{
"factor"
:
4
}
# model parameters
grid_resolution
=
128
grid_nlvl
=
4
# render parameters
render_step_size
=
1e-3
alpha_thre
=
1e-2
cone_angle
=
0.004
# training parameters
max_steps
=
20000
init_batch_size
=
1024
target_sample_batch_size
=
1
<<
18
weight_decay
=
0.0
# scene parameters
aabb
=
torch
.
tensor
([
-
1.0
,
-
1.0
,
-
1.0
,
1.0
,
1.0
,
1.0
],
device
=
device
)
near_plane
=
0.2
far_plane
=
1.0e10
# dataset parameters
train_dataset_kwargs
=
{
"color_bkgd_aug"
:
"random"
,
"factor"
:
4
}
test_dataset_kwargs
=
{
"factor"
:
4
}
# model parameters
grid_resolution
=
128
grid_nlvl
=
4
# render parameters
render_step_size
=
1e-3
alpha_thre
=
1e-2
cone_angle
=
0.004
else
:
from
datasets.nerf_synthetic
import
SubjectLoader
else
:
from
datasets.nerf_synthetic
import
SubjectLoader
# training parameters
max_steps
=
20000
init_batch_size
=
1024
target_sample_batch_size
=
1
<<
18
weight_decay
=
(
1e-5
if
args
.
scene
in
[
"materials"
,
"ficus"
,
"drums"
]
else
1e-6
)
# scene parameters
aabb
=
torch
.
tensor
([
-
1.5
,
-
1.5
,
-
1.5
,
1.5
,
1.5
,
1.5
],
device
=
device
)
near_plane
=
0.0
far_plane
=
1.0e10
# dataset parameters
train_dataset_kwargs
=
{}
test_dataset_kwargs
=
{}
# model parameters
grid_resolution
=
128
grid_nlvl
=
1
# render parameters
render_step_size
=
5e-3
alpha_thre
=
0.0
cone_angle
=
0.0
# training parameters
max_steps
=
20000
init_batch_size
=
1024
target_sample_batch_size
=
1
<<
18
weight_decay
=
(
1e-5
if
args
.
scene
in
[
"materials"
,
"ficus"
,
"drums"
]
else
1e-6
train_dataset
=
SubjectLoader
(
subject_id
=
args
.
scene
,
root_fp
=
args
.
data_root
,
split
=
args
.
train_split
,
num_rays
=
init_batch_size
,
device
=
device
,
**
train_dataset_kwargs
,
)
# scene parameters
aabb
=
torch
.
tensor
([
-
1.5
,
-
1.5
,
-
1.5
,
1.5
,
1.5
,
1.5
],
device
=
device
)
near_plane
=
0.0
far_plane
=
1.0e10
# dataset parameters
train_dataset_kwargs
=
{}
test_dataset_kwargs
=
{}
# model parameters
grid_resolution
=
128
grid_nlvl
=
1
# render parameters
render_step_size
=
5e-3
alpha_thre
=
0.0
cone_angle
=
0.0
train
_dataset
=
SubjectLoader
(
subject_id
=
args
.
scene
,
root_fp
=
args
.
data_root
,
split
=
args
.
train_split
,
num_rays
=
init_batch_siz
e
,
device
=
device
,
**
t
rain
_dataset_kwargs
,
)
test
_dataset
=
SubjectLoader
(
subject_id
=
args
.
scene
,
root_fp
=
args
.
data_root
,
split
=
"test"
,
num_rays
=
Non
e
,
device
=
device
,
**
t
est
_dataset_kwargs
,
)
test_dataset
=
SubjectLoader
(
subject_id
=
args
.
scene
,
root_fp
=
args
.
data_root
,
split
=
"test"
,
num_rays
=
None
,
device
=
device
,
**
test_dataset_kwargs
,
)
estimator
=
OccGridEstimator
(
roi_aabb
=
aabb
,
resolution
=
grid_resolution
,
levels
=
grid_nlvl
).
to
(
device
)
estimator
=
OccGridEstimator
(
roi_aabb
=
aabb
,
resolution
=
grid_resolution
,
levels
=
grid_nlvl
).
to
(
device
)
# setup the radiance field we want to train.
grad_scaler
=
torch
.
cuda
.
amp
.
GradScaler
(
2
**
10
)
radiance_field
=
NGPRadianceField
(
aabb
=
estimator
.
aabbs
[
-
1
]).
to
(
device
)
optimizer
=
torch
.
optim
.
Adam
(
radiance_field
.
parameters
(),
lr
=
1e-2
,
eps
=
1e-15
,
weight_decay
=
weight_decay
)
scheduler
=
torch
.
optim
.
lr_scheduler
.
ChainedScheduler
(
[
torch
.
optim
.
lr_scheduler
.
LinearLR
(
optimizer
,
start_factor
=
0.01
,
total_iters
=
100
),
torch
.
optim
.
lr_scheduler
.
MultiStepLR
(
optimizer
,
milestones
=
[
max_steps
//
2
,
max_steps
*
3
//
4
,
max_steps
*
9
//
10
,
],
gamma
=
0.33
,
),
]
)
lpips_net
=
LPIPS
(
net
=
"vgg"
).
to
(
device
)
lpips_norm_fn
=
lambda
x
:
x
[
None
,
...].
permute
(
0
,
3
,
1
,
2
)
*
2
-
1
lpips_fn
=
lambda
x
,
y
:
lpips_net
(
lpips_norm_fn
(
x
),
lpips_norm_fn
(
y
)).
mean
()
# setup the radiance field we want to train.
grad_scaler
=
torch
.
cuda
.
amp
.
GradScaler
(
2
**
10
)
radiance_field
=
NGPRadianceField
(
aabb
=
estimator
.
aabbs
[
-
1
]).
to
(
device
)
optimizer
=
torch
.
optim
.
Adam
(
radiance_field
.
parameters
(),
lr
=
1e-2
,
eps
=
1e-15
,
weight_decay
=
weight_decay
)
scheduler
=
torch
.
optim
.
lr_scheduler
.
ChainedScheduler
(
[
torch
.
optim
.
lr_scheduler
.
LinearLR
(
optimizer
,
start_factor
=
0.01
,
total_iters
=
100
),
torch
.
optim
.
lr_scheduler
.
MultiStepLR
(
optimizer
,
milestones
=
[
max_steps
//
2
,
max_steps
*
3
//
4
,
max_steps
*
9
//
10
,
],
gamma
=
0.33
,
),
]
)
lpips_net
=
LPIPS
(
net
=
"vgg"
).
to
(
device
)
lpips_norm_fn
=
lambda
x
:
x
[
None
,
...].
permute
(
0
,
3
,
1
,
2
)
*
2
-
1
lpips_fn
=
lambda
x
,
y
:
lpips_net
(
lpips_norm_fn
(
x
),
lpips_norm_fn
(
y
)).
mean
()
# training
tic
=
time
.
time
()
for
step
in
range
(
max_steps
+
1
):
radiance_field
.
train
()
estimator
.
train
()
# training
tic
=
time
.
time
()
for
step
in
range
(
max_steps
+
1
):
radiance_field
.
train
()
estimator
.
train
()
i
=
torch
.
randint
(
0
,
len
(
train_dataset
),
(
1
,)).
item
()
data
=
train_dataset
[
i
]
i
=
torch
.
randint
(
0
,
len
(
train_dataset
),
(
1
,)).
item
()
data
=
train_dataset
[
i
]
render_bkgd
=
data
[
"color_bkgd"
]
rays
=
data
[
"rays"
]
pixels
=
data
[
"pixels"
]
render_bkgd
=
data
[
"color_bkgd"
]
rays
=
data
[
"rays"
]
pixels
=
data
[
"pixels"
]
def
occ_eval_fn
(
x
):
density
=
radiance_field
.
query_density
(
x
)
return
density
*
render_step_size
def
occ_eval_fn
(
x
):
density
=
radiance_field
.
query_density
(
x
)
return
density
*
render_step_size
# update occupancy grid
estimator
.
update_every_n_steps
(
step
=
step
,
occ_eval_fn
=
occ_eval_fn
,
occ_thre
=
1e-2
,
)
# update occupancy grid
estimator
.
update_every_n_steps
(
step
=
step
,
occ_eval_fn
=
occ_eval_fn
,
occ_thre
=
1e-2
,
)
# render
rgb
,
acc
,
depth
,
n_rendering_samples
=
render_image_with_occgrid
(
radiance_field
,
estimator
,
rays
,
# rendering options
near_plane
=
near_plane
,
render_step_size
=
render_step_size
,
render_bkgd
=
render_bkgd
,
cone_angle
=
cone_angle
,
alpha_thre
=
alpha_thre
,
)
if
n_rendering_samples
==
0
:
continue
# render
rgb
,
acc
,
depth
,
n_rendering_samples
=
render_image_with_occgrid
(
radiance_field
,
estimator
,
rays
,
# rendering options
near_plane
=
near_plane
,
render_step_size
=
render_step_size
,
render_bkgd
=
render_bkgd
,
cone_angle
=
cone_angle
,
alpha_thre
=
alpha_thre
,
)
if
n_rendering_samples
==
0
:
continue
if
target_sample_batch_size
>
0
:
# dynamic batch size for rays to keep sample batch size constant.
num_rays
=
len
(
pixels
)
num_rays
=
int
(
num_rays
*
(
target_sample_batch_size
/
float
(
n_rendering_samples
))
)
train_dataset
.
update_num_rays
(
num_rays
)
if
target_sample_batch_size
>
0
:
# dynamic batch size for rays to keep sample batch size constant.
num_rays
=
len
(
pixels
)
num_rays
=
int
(
num_rays
*
(
target_sample_batch_size
/
float
(
n_rendering_samples
))
)
train_dataset
.
update_num_rays
(
num_rays
)
# compute loss
loss
=
F
.
smooth_l1_loss
(
rgb
,
pixels
)
# compute loss
loss
=
F
.
smooth_l1_loss
(
rgb
,
pixels
)
optimizer
.
zero_grad
()
# do not unscale it because we are using Adam.
grad_scaler
.
scale
(
loss
).
backward
()
optimizer
.
step
()
scheduler
.
step
()
optimizer
.
zero_grad
()
# do not unscale it because we are using Adam.
grad_scaler
.
scale
(
loss
).
backward
()
optimizer
.
step
()
scheduler
.
step
()
if
step
%
10000
==
0
:
elapsed_time
=
time
.
time
()
-
tic
loss
=
F
.
mse_loss
(
rgb
,
pixels
)
psnr
=
-
10.0
*
torch
.
log
(
loss
)
/
np
.
log
(
10.0
)
print
(
f
"elapsed_time=
{
elapsed_time
:.
2
f
}
s | step=
{
step
}
| "
f
"loss=
{
loss
:.
5
f
}
| psnr=
{
psnr
:.
2
f
}
| "
f
"n_rendering_samples=
{
n_rendering_samples
:
d
}
| num_rays=
{
len
(
pixels
):
d
}
| "
f
"max_depth=
{
depth
.
max
():.
3
f
}
| "
)
if
step
%
10000
==
0
:
elapsed_time
=
time
.
time
()
-
tic
loss
=
F
.
mse_loss
(
rgb
,
pixels
)
psnr
=
-
10.0
*
torch
.
log
(
loss
)
/
np
.
log
(
10.0
)
print
(
f
"elapsed_time=
{
elapsed_time
:.
2
f
}
s | step=
{
step
}
| "
f
"loss=
{
loss
:.
5
f
}
| psnr=
{
psnr
:.
2
f
}
| "
f
"n_rendering_samples=
{
n_rendering_samples
:
d
}
| num_rays=
{
len
(
pixels
):
d
}
| "
f
"max_depth=
{
depth
.
max
():.
3
f
}
| "
)
if
step
>
0
and
step
%
max_steps
==
0
:
# evaluation
radiance_field
.
eval
()
estimator
.
eval
()
psnrs
=
[]
lpips
=
[]
with
torch
.
no_grad
():
for
i
in
tqdm
.
tqdm
(
range
(
len
(
test_dataset
))):
data
=
test_dataset
[
i
]
render_bkgd
=
data
[
"color_bkgd"
]
rays
=
data
[
"rays"
]
pixels
=
data
[
"pixels"
]
if
step
>
0
and
step
%
max_steps
==
0
:
# evaluation
radiance_field
.
eval
()
estimator
.
eval
()
# rendering
# rgb, acc, depth, _ = render_image_with_occgrid_test(
# 1024,
# # scene
# radiance_field,
# estimator,
# rays,
# # rendering options
# near_plane=near_plane,
# render_step_size=render_step_size,
# render_bkgd=render_bkgd,
# cone_angle=cone_angle,
# alpha_thre=alpha_thre,
# )
rgb
,
acc
,
depth
,
_
=
render_image_with_occgrid
(
radiance_field
,
estimator
,
rays
,
# rendering options
near_plane
=
near_plane
,
render_step_size
=
render_step_size
,
render_bkgd
=
render_bkgd
,
cone_angle
=
cone_angle
,
alpha_thre
=
alpha_thre
,
)
mse
=
F
.
mse_loss
(
rgb
,
pixels
)
psnr
=
-
10.0
*
torch
.
log
(
mse
)
/
np
.
log
(
10.0
)
psnrs
.
append
(
psnr
.
item
())
lpips
.
append
(
lpips_fn
(
rgb
,
pixels
).
item
())
# if i == 0:
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_error.png",
# (
# (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
# ).astype(np.uint8),
# )
psnr_avg
=
sum
(
psnrs
)
/
len
(
psnrs
)
lpips_avg
=
sum
(
lpips
)
/
len
(
lpips
)
print
(
f
"evaluation: psnr_avg=
{
psnr_avg
}
, lpips_avg=
{
lpips_avg
}
"
)
psnrs
=
[]
lpips
=
[]
with
torch
.
no_grad
():
for
i
in
tqdm
.
tqdm
(
range
(
len
(
test_dataset
))):
data
=
test_dataset
[
i
]
render_bkgd
=
data
[
"color_bkgd"
]
rays
=
data
[
"rays"
]
pixels
=
data
[
"pixels"
]
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--data_root"
,
type
=
str
,
# default=str(pathlib.Path.cwd() / "data/360_v2"),
default
=
str
(
pathlib
.
Path
.
cwd
()
/
"data/nerf_synthetic"
),
help
=
"the root dir of the dataset"
,
)
parser
.
add_argument
(
"--train_split"
,
type
=
str
,
default
=
"train"
,
choices
=
[
"train"
,
"trainval"
],
help
=
"which train split to use"
,
)
parser
.
add_argument
(
"--scene"
,
type
=
str
,
default
=
"lego"
,
choices
=
NERF_SYNTHETIC_SCENES
+
MIPNERF360_UNBOUNDED_SCENES
,
help
=
"which scene to use"
,
)
args
=
parser
.
parse_args
()
# rendering
rgb
,
acc
,
depth
,
_
=
render_image_with_occgrid_test
(
1024
,
# scene
radiance_field
,
estimator
,
rays
,
# rendering options
near_plane
=
near_plane
,
render_step_size
=
render_step_size
,
render_bkgd
=
render_bkgd
,
cone_angle
=
cone_angle
,
alpha_thre
=
alpha_thre
,
)
mse
=
F
.
mse_loss
(
rgb
,
pixels
)
psnr
=
-
10.0
*
torch
.
log
(
mse
)
/
np
.
log
(
10.0
)
psnrs
.
append
(
psnr
.
item
())
lpips
.
append
(
lpips_fn
(
rgb
,
pixels
).
item
())
# if i == 0:
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_error.png",
# (
# (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
# ).astype(np.uint8),
# )
psnr_avg
=
sum
(
psnrs
)
/
len
(
psnrs
)
lpips_avg
=
sum
(
lpips
)
/
len
(
lpips
)
print
(
f
"evaluation: psnr_avg=
{
psnr_avg
}
, lpips_avg=
{
lpips_avg
}
"
)
run
(
args
)
\ No newline at end of file
examples/utils.py
View file @
9f90842b
...
...
@@ -79,38 +79,6 @@ def render_image_with_occgrid(
else
:
num_rays
,
_
=
rays_shape
def
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
):
t_origins
=
chunk_rays
.
origins
[
ray_indices
]
t_dirs
=
chunk_rays
.
viewdirs
[
ray_indices
]
positions
=
t_origins
+
t_dirs
*
(
t_starts
+
t_ends
)[:,
None
]
/
2.0
if
timestamps
is
not
None
:
# dnerf
t
=
(
timestamps
[
ray_indices
]
if
radiance_field
.
training
else
timestamps
.
expand_as
(
positions
[:,
:
1
])
)
sigmas
=
radiance_field
.
query_density
(
positions
,
t
)
else
:
sigmas
=
radiance_field
.
query_density
(
positions
)
return
sigmas
.
squeeze
(
-
1
)
def
rgb_sigma_fn
(
t_starts
,
t_ends
,
ray_indices
):
t_origins
=
chunk_rays
.
origins
[
ray_indices
]
t_dirs
=
chunk_rays
.
viewdirs
[
ray_indices
]
positions
=
t_origins
+
t_dirs
*
(
t_starts
+
t_ends
)[:,
None
]
/
2.0
if
timestamps
is
not
None
:
# dnerf
t
=
(
timestamps
[
ray_indices
]
if
radiance_field
.
training
else
timestamps
.
expand_as
(
positions
[:,
:
1
])
)
rgbs
,
sigmas
=
radiance_field
(
positions
,
t
,
t_dirs
)
else
:
rgbs
,
sigmas
=
radiance_field
(
positions
,
t_dirs
)
return
rgbs
,
sigmas
.
squeeze
(
-
1
)
results
=
[]
chunk
=
(
torch
.
iinfo
(
torch
.
int32
).
max
...
...
@@ -119,9 +87,45 @@ def render_image_with_occgrid(
)
for
i
in
range
(
0
,
num_rays
,
chunk
):
chunk_rays
=
namedtuple_map
(
lambda
r
:
r
[
i
:
i
+
chunk
],
rays
)
rays_o
=
chunk_rays
.
origins
rays_d
=
chunk_rays
.
viewdirs
def
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
):
t_origins
=
rays_o
[
ray_indices
]
t_dirs
=
rays_d
[
ray_indices
]
positions
=
t_origins
+
t_dirs
*
(
t_starts
+
t_ends
)[:,
None
]
/
2.0
if
timestamps
is
not
None
:
# dnerf
t
=
(
timestamps
[
ray_indices
]
if
radiance_field
.
training
else
timestamps
.
expand_as
(
positions
[:,
:
1
])
)
sigmas
=
radiance_field
.
query_density
(
positions
,
t
)
else
:
sigmas
=
radiance_field
.
query_density
(
positions
)
return
sigmas
.
squeeze
(
-
1
)
def
rgb_sigma_fn
(
t_starts
,
t_ends
,
ray_indices
):
t_origins
=
rays_o
[
ray_indices
]
t_dirs
=
rays_d
[
ray_indices
]
positions
=
t_origins
+
t_dirs
*
(
t_starts
+
t_ends
)[:,
None
]
/
2.0
if
timestamps
is
not
None
:
# dnerf
t
=
(
timestamps
[
ray_indices
]
if
radiance_field
.
training
else
timestamps
.
expand_as
(
positions
[:,
:
1
])
)
rgbs
,
sigmas
=
radiance_field
(
positions
,
t
,
t_dirs
)
else
:
rgbs
,
sigmas
=
radiance_field
(
positions
,
t_dirs
)
return
rgbs
,
sigmas
.
squeeze
(
-
1
)
ray_indices
,
t_starts
,
t_ends
=
estimator
.
sampling
(
chunk_rays
.
origins
,
chunk_rays
.
viewdirs
,
rays_o
,
rays_d
,
sigma_fn
=
sigma_fn
,
near_plane
=
near_plane
,
far_plane
=
far_plane
,
...
...
@@ -134,7 +138,7 @@ def render_image_with_occgrid(
t_starts
,
t_ends
,
ray_indices
,
n_rays
=
chunk_rays
.
origins
.
shape
[
0
],
n_rays
=
rays_o
.
shape
[
0
],
rgb_sigma_fn
=
rgb_sigma_fn
,
render_bkgd
=
render_bkgd
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment