OpenDAS / nerfacc · Commits

Commit 6093da8f — "better with tips"
Authored Sep 10, 2022 by Ruilong Li
Parent: b5a2af68
Showing 4 changed files with 117 additions and 141 deletions (+117 −141):
- README.md (+9 −2)
- examples/datasets/nerf_synthetic.py (+83 −47)
- examples/datasets/utils.py (+0 −85)
- examples/trainval.py (+25 −7)
README.md
```diff
@@ -16,5 +16,12 @@ Tested with the default settings on the Lego test set.
 | - | - | - | - | - | - |
 | instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
 | torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
-| ours | train (30K steps) | 33.31 | 298 sec | 6.4 fps | TITAN RTX |
-| ours | trainval (30K steps) | 34.45 | 290 sec | 6.6 fps | TITAN RTX |
+| ours | train (30K steps) | 34.40 | 296 sec | 6.2 fps | TITAN RTX |
+| ours | trainval (30K steps) | 35.42 | 291 sec | 6.4 fps | TITAN RTX |
+
+## Tips:
+
+1. Sampling rays over all images per iteration (`batch_over_images=True`) is better: `PSNR 33.31 -> 33.75`.
+2. Using a scheduler (`MultiStepLR(optimizer, milestones=[20000, 30000], gamma=0.1)`) to adjust the learning rate gives `PSNR 33.75 -> 34.40`.
+3. Increasing the chunk size (`chunk: 8192 -> 81920`) during inference gives a speedup: `FPS 4.x -> 6.2`.
```
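Tip 3 works by rendering more rays per forward pass at test time, so fewer Python-loop iterations and kernel launches are needed per image. A minimal sketch of the chunking pattern, assuming a `render_rays` callable (illustrative, not the repo's actual API):

```python
import torch

def render_image_chunked(render_rays, origins, viewdirs, chunk=81920):
    """Render flattened rays in `chunk`-sized pieces.

    A larger chunk (8192 -> 81920) amortizes per-call overhead, which is
    where the FPS gain comes from, at the cost of higher peak GPU memory.
    """
    outs = []
    for i in range(0, origins.shape[0], chunk):
        outs.append(render_rays(origins[i : i + chunk], viewdirs[i : i + chunk]))
    return torch.cat(outs, dim=0)
```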
examples/datasets/nerf_synthetic.py
```diff
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import collections
 import json
 import os

 import imageio.v2 as imageio
 import numpy as np
 import torch
+import torch.nn.functional as F

-from .utils import Cameras, generate_rays
+Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
+
+
+def namedtuple_map(fn, tup):
+    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
+    return type(tup)(*(None if x is None else fn(x) for x in tup))


 def _load_renderings(root_fp: str, subject_id: str, split: str):
```
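The `namedtuple_map` helper moves here so that `trainval.py` can import it from the dataset module (see that file's hunk below). It applies a function to every field of a namedtuple while preserving the tuple type; a quick usage sketch with the `Rays` tuple defined above (values are illustrative):

```python
rays = Rays(origins=torch.zeros(8192, 3), viewdirs=torch.ones(8192, 3))
# move every field of the namedtuple in one call, keeping the Rays type
rays = namedtuple_map(lambda x: x.to("cuda"), rays)
```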
```diff
@@ -33,8 +39,8 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
         camtoworlds.append(frame["transform_matrix"])
         images.append(rgba)

-    images = np.stack(images, axis=0).astype(np.float32)
-    camtoworlds = np.stack(camtoworlds, axis=0).astype(np.float32)
+    images = np.stack(images, axis=0)
+    camtoworlds = np.stack(camtoworlds, axis=0)
     h, w = images.shape[1:3]
     camera_angle_x = float(meta["camera_angle_x"])
```
```diff
@@ -69,6 +75,7 @@ class SubjectLoader(torch.utils.data.Dataset):
         num_rays: int = None,
         near: float = None,
         far: float = None,
+        batch_over_images: bool = True,
     ):
         super().__init__()
         assert split in self.SPLITS, "%s" % split
```
```diff
@@ -80,6 +87,7 @@ class SubjectLoader(torch.utils.data.Dataset):
         self.far = self.FAR if far is None else far
         self.training = (num_rays is not None) and (split in ["train", "trainval"])
         self.color_bkgd_aug = color_bkgd_aug
+        self.batch_over_images = batch_over_images
         if split == "trainval":
             _images_train, _camtoworlds_train, _focal_train = _load_renderings(
                 root_fp, subject_id, "train"
```
```diff
@@ -94,11 +102,22 @@ class SubjectLoader(torch.utils.data.Dataset):
             self.images, self.camtoworlds, self.focal = _load_renderings(
                 root_fp, subject_id, split
             )
+        self.images = torch.from_numpy(self.images).to(torch.uint8)
+        self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
+        self.K = torch.tensor(
+            [
+                [self.focal, 0, self.WIDTH / 2.0],
+                [0, self.focal, self.HEIGHT / 2.0],
+                [0, 0, 1],
+            ],
+            dtype=torch.float32,
+        )  # (3, 3)
         assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)

     def __len__(self):
         return len(self.images)

+    @torch.no_grad()
     def __getitem__(self, index):
         data = self.fetch_data(index)
         data = self.preprocess(data)
```
```diff
@@ -111,14 +130,14 @@ class SubjectLoader(torch.utils.data.Dataset):
         if self.training:
             if self.color_bkgd_aug == "random":
-                color_bkgd = torch.rand(3)
+                color_bkgd = torch.rand(3, device=self.images.device)
             elif self.color_bkgd_aug == "white":
-                color_bkgd = torch.ones(3)
+                color_bkgd = torch.ones(3, device=self.images.device)
             elif self.color_bkgd_aug == "black":
-                color_bkgd = torch.zeros(3)
+                color_bkgd = torch.zeros(3, device=self.images.device)
         else:
             # just use white during inference
-            color_bkgd = torch.ones(3)
+            color_bkgd = torch.ones(3, device=self.images.device)

         pixels = pixels * alpha + color_bkgd * (1.0 - alpha)
         return {
```
```diff
@@ -130,48 +149,65 @@ class SubjectLoader(torch.utils.data.Dataset):
     def fetch_data(self, index):
         """Fetch the data (it maybe cached for multiple batches)."""
-        camera_id = index
-        K = np.array(
-            [
-                [self.focal, 0, self.WIDTH / 2.0],
-                [0, self.focal, self.HEIGHT / 2.0],
-                [0, 0, 1],
-            ]
-        ).astype(np.float32)
-        w2c = np.linalg.inv(self.camtoworlds[camera_id])
-        rgba = self.images[camera_id]
-        # create pixels
-        rgba = torch.from_numpy(rgba).float() / 255.0
-        # create rays from camera
-        cameras = Cameras(
-            intrins=torch.from_numpy(K).float(),
-            extrins=torch.from_numpy(w2c).float(),
-            distorts=None,
-            width=self.WIDTH,
-            height=self.HEIGHT,
-        )
-        if self.num_rays is not None:
-            x = torch.randint(0, self.WIDTH, size=(self.num_rays,))
-            y = torch.randint(0, self.HEIGHT, size=(self.num_rays,))
-            pixels_xy = torch.stack([x, y], dim=-1)
-            rgba = rgba[y, x, :]
-        else:
-            pixels_xy = None
-        # Be careful: This dataset's camera coordinate is not the same as
-        # opencv's camera coordinate! It is actually opengl.
-        rays = generate_rays(
-            cameras, opencv_format=False, pixels_xy=pixels_xy,
-        )
+        # load data
+        if self.training:
+            if self.batch_over_images:
+                image_id = torch.randint(
+                    0,
+                    len(self.images),
+                    size=(self.num_rays,),
+                    device=self.images.device,
+                )
+            else:
+                image_id = [index]
+            x = torch.randint(
+                0, self.WIDTH, size=(self.num_rays,), device=self.images.device
+            )
+            y = torch.randint(
+                0, self.HEIGHT, size=(self.num_rays,), device=self.images.device
+            )
+        else:
+            image_id = [index]
+            x, y = torch.meshgrid(
+                torch.arange(self.WIDTH, device=self.images.device),
+                torch.arange(self.HEIGHT, device=self.images.device),
+                indexing="xy",
+            )
+            x = x.flatten()
+            y = y.flatten()
+
+        # generate rays
+        rgba = self.images[image_id, y, x] / 255.0  # (num_rays, 4)
+        c2w = self.camtoworlds[image_id]  # (num_rays, 3, 4)
+        camera_dirs = F.pad(
+            torch.stack(
+                [
+                    (x - self.K[0, 2] + 0.5) / self.K[0, 0],
+                    (y - self.K[1, 2] + 0.5) / self.K[1, 1],
+                ],
+                dim=-1,
+            ),
+            (0, 1),
+            value=1,
+        )  # [num_rays, 3]
+        camera_dirs[..., [1, 2]] *= -1  # opengl format
+        # [n_cams, height, width, 3]
+        directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
+        origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
+        viewdirs = directions / torch.linalg.norm(
+            directions, dim=-1, keepdims=True
+        )
+
+        if self.training:
+            origins = torch.reshape(origins, (self.num_rays, 3))
+            viewdirs = torch.reshape(viewdirs, (self.num_rays, 3))
+            rgba = torch.reshape(rgba, (self.num_rays, 4))
+        else:
+            # full image
+            origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3))
+            viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3))
+            rgba = torch.reshape(rgba, (self.HEIGHT, self.WIDTH, 4))
+
+        rays = Rays(origins=origins, viewdirs=viewdirs)
+
         return {
-            "camera_id": camera_id,
-            "rgba": rgba,  # [h, w, 4] or [num_rays, 4]
-            "rays": rays,  # [h, w] or [num_rays, 4]
+            "rgba": rgba,  # [h, w, 4] or [num_rays, 4]
+            "rays": rays,  # [h, w, 3] or [num_rays, 3]
         }
```
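A note on the sign flip kept from the old code (`camera_dirs[..., [1, 2]] *= -1  # opengl format`): the Blender-generated NeRF synthetic scenes store OpenGL-style camera-to-world matrices (camera looks down -z, +y up), while the pixel-to-direction formula `(x - cx + 0.5) / fx` is written for OpenCV-style axes (+z forward, +y down), so the y and z components must be negated before applying `c2w`. A self-contained sketch for a single pixel (the numbers are purely illustrative):

```python
import torch

# hypothetical 800x800 intrinsics with the principal point at the center
K = torch.tensor([[800.0, 0.0, 400.0], [0.0, 800.0, 400.0], [0.0, 0.0, 1.0]])
x, y = 400, 400  # a pixel at the image center

# OpenCV-style direction: +z forward, +y down
d = torch.tensor([(x - K[0, 2] + 0.5) / K[0, 0], (y - K[1, 2] + 0.5) / K[1, 1], 1.0])
d[[1, 2]] *= -1  # flip y and z -> OpenGL-style: -z forward, +y up
print(d)  # tensor([ 0.0006, -0.0006, -1.0000]), looking down -z as expected
```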
examples/datasets/utils.py (deleted, 100644 → 0)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
import collections
import math

import torch
import torch.nn.functional as F

Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
Cameras = collections.namedtuple(
    "Cameras", ("intrins", "extrins", "distorts", "width", "height")
)


def namedtuple_map(fn, tup):
    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
    return type(tup)(*(None if x is None else fn(x) for x in tup))


def homo(points: torch.Tensor) -> torch.Tensor:
    """Get the homogeneous coordinates."""
    return F.pad(points, (0, 1), value=1)


def transform_cameras(cameras: Cameras, resize_factor: float) -> Cameras:
    intrins = cameras.intrins
    intrins[..., :2, :] = intrins[..., :2, :] * resize_factor
    width = int(cameras.width * resize_factor + 0.5)
    height = int(cameras.height * resize_factor + 0.5)
    return Cameras(
        intrins=intrins,
        extrins=cameras.extrins,
        distorts=cameras.distorts,
        width=width,
        height=height,
    )


def generate_rays(
    cameras: Cameras,
    opencv_format: bool = True,
    pixels_xy: torch.Tensor = None,
) -> Rays:
    """Generate rays for a single or multiple cameras.

    :params cameras: [(n_cams,)]
    :returns: Rays
        [(n_cams,) height, width] if pixels_xy is None
        [(n_cams,) num_pixels] if pixels_xy is given
    """
    if pixels_xy is not None:
        K = cameras.intrins[..., None, :, :]
        c2w = cameras.extrins[..., None, :, :].inverse()
        x, y = pixels_xy[..., 0], pixels_xy[..., 1]
    else:
        K = cameras.intrins[..., None, None, :, :]
        c2w = cameras.extrins[..., None, None, :, :].inverse()
        x, y = torch.meshgrid(
            torch.arange(cameras.width, dtype=K.dtype),
            torch.arange(cameras.height, dtype=K.dtype),
            indexing="xy",
        )  # [height, width]

    camera_dirs = homo(
        torch.stack(
            [
                (x - K[..., 0, 2] + 0.5) / K[..., 0, 0],
                (y - K[..., 1, 2] + 0.5) / K[..., 1, 1],
            ],
            dim=-1,
        )
    )  # [n_cams, height, width, 3]
    if not opencv_format:
        camera_dirs[..., [1, 2]] *= -1

    # [n_cams, height, width, 3]
    directions = (camera_dirs[..., None, :] * c2w[..., :3, :3]).sum(dim=-1)
    origins = torch.broadcast_to(c2w[..., :3, -1], directions.shape)
    viewdirs = directions / torch.linalg.norm(
        directions, dim=-1, keepdims=True
    )

    rays = Rays(
        origins=origins,  # [n_cams, height, width, 3]
        viewdirs=viewdirs,  # [n_cams, height, width, 3]
    )
    return rays
```
examples/trainval.py
```diff
@@ -5,8 +5,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 import tqdm

-from datasets.nerf_synthetic import SubjectLoader
-from datasets.utils import namedtuple_map
+from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
 from radiance_fields.ngp import NGPradianceField
 from nerfacc import OccupancyField, volumetric_rendering
```
```diff
@@ -67,19 +66,26 @@ if __name__ == "__main__":
         split="train",
         num_rays=8192,
     )
+    # train_dataset.images = train_dataset.images.to(device)
+    # train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
+    # train_dataset.K = train_dataset.K.to(device)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
-        num_workers=1,
+        num_workers=4,
         batch_size=None,
         persistent_workers=True,
         shuffle=True,
     )
     test_dataset = SubjectLoader(
         subject_id="lego",
         root_fp="/home/ruilongli/data/nerf_synthetic/",
         split="test",
         num_rays=None,
     )
+    # test_dataset.images = test_dataset.images.to(device)
+    # test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
+    # test_dataset.K = test_dataset.K.to(device)
     test_dataloader = torch.utils.data.DataLoader(
         test_dataset,
         num_workers=4,
```
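One detail worth flagging in this hunk: `batch_size=None` disables the DataLoader's automatic batching, so each `__getitem__` call on `SubjectLoader` already returns a full batch of `num_rays` samples and the workers only add prefetching. A sketch of the consuming side (assuming a `device` defined earlier in the script, and that the batch dict keeps the `rays` key produced by `fetch_data`):

```python
for data in train_dataloader:
    # `data` is already a complete batch of num_rays rays; just move it over
    rays = namedtuple_map(lambda x: x.to(device), data["rays"])
```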
```diff
@@ -102,9 +108,9 @@ if __name__ == "__main__":
     )
     optimizer = torch.optim.Adam(radiance_field.parameters(), lr=3e-3, eps=1e-15)
-    # scheduler = torch.optim.lr_scheduler.MultiStepLR(
-    #     optimizer, milestones=[10000, 20000, 30000], gamma=0.1
-    # )
+    scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        optimizer, milestones=[20000, 30000], gamma=0.1
+    )

     # setup occupancy field with eval function
     def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
```
```diff
@@ -156,7 +162,7 @@ if __name__ == "__main__":
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-            # scheduler.step()
+            scheduler.step()

             if step % 50 == 0:
                 elapsed_time = time.time() - tic
```
```diff
@@ -189,6 +195,18 @@ if __name__ == "__main__":
     # elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
     # evaluation: psnr_avg=33.305334663391115 (6.42 it/s)
+    # "train" batch_over_images=True
+    # elapsed_time=335.21s (data=68.99s) | step=30000 | loss=0.00028
+    # evaluation: psnr_avg=33.74970862388611 (6.23 it/s)
+    # "train" batch_over_images=True, schedule
+    # elapsed_time=296.30s (data=54.38s) | step=30000 | loss=0.00022
+    # evaluation: psnr_avg=34.3978275680542 (6.22 it/s)
     # "trainval"
     # elapsed_time=289.94s (data=51.99s) | step=30000 | loss=0.00021
     # evaluation: psnr_avg=34.44980221748352 (6.61 it/s)
+    # "trainval" batch_over_images=True, schedule
+    # elapsed_time=291.42s (data=52.82s) | step=30000 | loss=0.00020
+    # evaluation: psnr_avg=35.41630497932434 (6.40 it/s)
```
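For reference, the `psnr_avg` values in these logs follow the standard PSNR definition for images scaled to [0, 1]; a minimal sketch, assuming that convention:

```python
import torch

def psnr(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    """Peak signal-to-noise ratio in dB for [0, 1] images (peak = 1)."""
    mse = torch.mean((pred - gt) ** 2)
    return -10.0 * torch.log10(mse)
```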