OpenDAS / nerfacc, commit 6093da8f
Authored Sep 10, 2022 by Ruilong Li (parent: b5a2af68)

Commit message: better with tips

Showing 4 changed files with 117 additions and 141 deletions.
Changed files:
- README.md (+9, -2)
- examples/datasets/nerf_synthetic.py (+83, -47)
- examples/datasets/utils.py (+0, -85, deleted)
- examples/trainval.py (+25, -7)
README.md

```diff
@@ -16,5 +16,12 @@ Tested with the default settings on the Lego test set.
 | - | - | - | - | - | - |
 | instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
 | torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
-| ours | train (30K steps) | 33.31 | 298 sec | 6.4 fps | TITAN RTX |
-| ours | trainval (30K steps) | 34.45 | 290 sec | 6.6 fps | TITAN RTX |
\ No newline at end of file
+| ours | train (30K steps) | 34.40 | 296 sec | 6.2 fps | TITAN RTX |
+| ours | trainval (30K steps) | 35.42 | 291 sec | 6.4 fps | TITAN RTX |
+
+## Tips:
+
+1. Sampling rays over all images per iteration (`batch_over_images=True`) is better: `PSNR 33.31 -> 33.75`.
+2. Using a scheduler (`MultiStepLR(optimizer, milestones=[20000, 30000], gamma=0.1)`) to adjust the learning rate gives `PSNR 33.75 -> 34.40`.
+3. Increasing the chunk size (`chunk: 8192 -> 81920`) during inference gives a speedup: `FPS 4.x -> 6.2`.
```
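Tip 3 is about amortizing per-launch overhead at render time: instead of pushing all H x W rays through the network at once, the example evaluates them in fixed-size chunks, and larger chunks run faster up to the memory limit. A minimal sketch of the pattern; `render_fn` and its signature are illustrative stand-ins, not the repo's exact API:

```python
import torch


def render_image_in_chunks(render_fn, origins, viewdirs, chunk=81920):
    """Render a full image's rays in fixed-size chunks.

    `render_fn` is a hypothetical callable mapping (origins, viewdirs) -> RGB.
    Bigger chunks mean fewer kernel launches and better GPU utilization,
    hence the `chunk: 8192 -> 81920` FPS gain in the tips, as long as one
    chunk still fits in GPU memory.
    """
    out = []
    for i in range(0, origins.shape[0], chunk):
        out.append(render_fn(origins[i : i + chunk], viewdirs[i : i + chunk]))
    return torch.cat(out, dim=0)
```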
examples/datasets/nerf_synthetic.py

```diff
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import collections
 import json
 import os

 import imageio.v2 as imageio
 import numpy as np
 import torch
+import torch.nn.functional as F

-from .utils import Cameras, generate_rays
+Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
+
+
+def namedtuple_map(fn, tup):
+    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
+    return type(tup)(*(None if x is None else fn(x) for x in tup))


 def _load_renderings(root_fp: str, subject_id: str, split: str):
@@ -33,8 +39,8 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
         camtoworlds.append(frame["transform_matrix"])
         images.append(rgba)

-    images = np.stack(images, axis=0).astype(np.float32)
-    camtoworlds = np.stack(camtoworlds, axis=0).astype(np.float32)
+    images = np.stack(images, axis=0)
+    camtoworlds = np.stack(camtoworlds, axis=0)
     h, w = images.shape[1:3]
     camera_angle_x = float(meta["camera_angle_x"])
@@ -69,6 +75,7 @@ class SubjectLoader(torch.utils.data.Dataset):
         num_rays: int = None,
         near: float = None,
         far: float = None,
+        batch_over_images: bool = True,
     ):
         super().__init__()
         assert split in self.SPLITS, "%s" % split
@@ -80,6 +87,7 @@ class SubjectLoader(torch.utils.data.Dataset):
         self.far = self.FAR if far is None else far
         self.training = (num_rays is not None) and (split in ["train", "trainval"])
         self.color_bkgd_aug = color_bkgd_aug
+        self.batch_over_images = batch_over_images
         if split == "trainval":
             _images_train, _camtoworlds_train, _focal_train = _load_renderings(
                 root_fp, subject_id, "train"
@@ -94,11 +102,22 @@ class SubjectLoader(torch.utils.data.Dataset):
             self.images, self.camtoworlds, self.focal = _load_renderings(
                 root_fp, subject_id, split
             )
+        self.images = torch.from_numpy(self.images).to(torch.uint8)
+        self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
+        self.K = torch.tensor(
+            [
+                [self.focal, 0, self.WIDTH / 2.0],
+                [0, self.focal, self.HEIGHT / 2.0],
+                [0, 0, 1],
+            ],
+            dtype=torch.float32,
+        )  # (3, 3)
+        assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)

     def __len__(self):
         return len(self.images)

     @torch.no_grad()
     def __getitem__(self, index):
         data = self.fetch_data(index)
         data = self.preprocess(data)
```
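For context on the cached `self.K` above: it is the standard 3x3 pinhole intrinsics matrix, focal length on the diagonal and principal point at the image center, with the focal length derived from the Blender metadata's `camera_angle_x` (for these renders, focal = 0.5 * W / tan(0.5 * camera_angle_x)). A standalone check with illustrative numbers; the `camera_angle_x` value here is the one commonly quoted for the lego scene and should be treated as an assumption:

```python
import math

import torch

# Illustrative values: nerf_synthetic images are 800x800, and the lego
# scene's camera_angle_x is roughly 0.6911 rad, giving focal ~= 1111.1.
width = height = 800
camera_angle_x = 0.6911112070083618
focal = 0.5 * width / math.tan(0.5 * camera_angle_x)

K = torch.tensor(
    [
        [focal, 0, width / 2.0],
        [0, focal, height / 2.0],
        [0, 0, 1],
    ],
    dtype=torch.float32,
)
print(K[0, 0])  # ~1111.1; the principal point sits at (400.0, 400.0)
```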
```diff
@@ -111,14 +130,14 @@ class SubjectLoader(torch.utils.data.Dataset):
         if self.training:
             if self.color_bkgd_aug == "random":
-                color_bkgd = torch.rand(3)
+                color_bkgd = torch.rand(3, device=self.images.device)
             elif self.color_bkgd_aug == "white":
-                color_bkgd = torch.ones(3)
+                color_bkgd = torch.ones(3, device=self.images.device)
             elif self.color_bkgd_aug == "black":
-                color_bkgd = torch.zeros(3)
+                color_bkgd = torch.zeros(3, device=self.images.device)
         else:
             # just use white during inference
-            color_bkgd = torch.ones(3)
+            color_bkgd = torch.ones(3, device=self.images.device)

         pixels = pixels * alpha + color_bkgd * (1.0 - alpha)
         return {
```
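`preprocess` composites each RGBA sample onto the chosen background with straight alpha blending, `pixels * alpha + color_bkgd * (1.0 - alpha)`. A tiny standalone illustration with made-up values:

```python
import torch

# One half-transparent red pixel (straight alpha) over a white background.
rgba = torch.tensor([[1.0, 0.0, 0.0, 0.5]])
pixels, alpha = torch.split(rgba, [3, 1], dim=-1)
color_bkgd = torch.ones(3)

composited = pixels * alpha + color_bkgd * (1.0 - alpha)
print(composited)  # tensor([[1.0000, 0.5000, 0.5000]]): red washed toward white
```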
```diff
@@ -130,48 +149,65 @@ class SubjectLoader(torch.utils.data.Dataset):
     def fetch_data(self, index):
         """Fetch the data (it maybe cached for multiple batches)."""
         # load data
         camera_id = index
-        K = np.array(
-            [
-                [self.focal, 0, self.WIDTH / 2.0],
-                [0, self.focal, self.HEIGHT / 2.0],
-                [0, 0, 1],
-            ]
-        ).astype(np.float32)
-        w2c = np.linalg.inv(self.camtoworlds[camera_id])
-        rgba = self.images[camera_id]
-
-        # create pixels
-        rgba = torch.from_numpy(rgba).float() / 255.0
-
-        # create rays from camera
-        cameras = Cameras(
-            intrins=torch.from_numpy(K).float(),
-            extrins=torch.from_numpy(w2c).float(),
-            distorts=None,
-            width=self.WIDTH,
-            height=self.HEIGHT,
-        )
-        if self.num_rays is not None:
-            x = torch.randint(0, self.WIDTH, size=(self.num_rays,))
-            y = torch.randint(0, self.HEIGHT, size=(self.num_rays,))
-            pixels_xy = torch.stack([x, y], dim=-1)
-            rgba = rgba[y, x, :]
-        else:
-            pixels_xy = None  # full image
-        # Be careful: This dataset's camera coordinate is not the same as
-        # opencv's camera coordinate! It is actually opengl.
-        rays = generate_rays(cameras, opencv_format=False, pixels_xy=pixels_xy)
+        if self.training:
+            if self.batch_over_images:
+                image_id = torch.randint(
+                    0,
+                    len(self.images),
+                    size=(self.num_rays,),
+                    device=self.images.device,
+                )
+            else:
+                image_id = [index]
+            x = torch.randint(
+                0, self.WIDTH, size=(self.num_rays,), device=self.images.device
+            )
+            y = torch.randint(
+                0, self.HEIGHT, size=(self.num_rays,), device=self.images.device
+            )
+        else:
+            image_id = [index]
+            x, y = torch.meshgrid(
+                torch.arange(self.WIDTH, device=self.images.device),
+                torch.arange(self.HEIGHT, device=self.images.device),
+                indexing="xy",
+            )
+            x = x.flatten()
+            y = y.flatten()
+
+        # generate rays
+        rgba = self.images[image_id, y, x] / 255.0  # (num_rays, 4)
+        c2w = self.camtoworlds[image_id]  # (num_rays, 3, 4)
+        camera_dirs = F.pad(
+            torch.stack(
+                [
+                    (x - self.K[0, 2] + 0.5) / self.K[0, 0],
+                    (y - self.K[1, 2] + 0.5) / self.K[1, 1],
+                ],
+                dim=-1,
+            ),
+            (0, 1),
+            value=1,
+        )  # [num_rays, 3]
+        camera_dirs[..., [1, 2]] *= -1  # opengl format
+
+        # [n_cams, height, width, 3]
+        directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
+        origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
+        viewdirs = directions / torch.linalg.norm(
+            directions, dim=-1, keepdims=True
+        )
+
+        if self.training:
+            origins = torch.reshape(origins, (self.num_rays, 3))
+            viewdirs = torch.reshape(viewdirs, (self.num_rays, 3))
+            rgba = torch.reshape(rgba, (self.num_rays, 4))
+        else:
+            origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3))
+            viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3))
+            rgba = torch.reshape(rgba, (self.HEIGHT, self.WIDTH, 4))
+
+        rays = Rays(origins=origins, viewdirs=viewdirs)

         return {
             "camera_id": camera_id,
             "rgba": rgba,  # [h, w, 4] or [num_rays, 4]
-            "rays": rays,  # [h, w] or [num_rays, 4]
+            "rays": rays,  # [h, w, 3] or [num_rays, 3]
         }
```
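The rewritten `fetch_data` builds rays inline rather than going through the deleted `generate_rays` helper: pixel (x, y) maps to the camera-space direction ((x - cx + 0.5) / fx, (y - cy + 0.5) / fy, 1), the y/z sign flip converts from OpenCV to the OpenGL convention the Blender transforms use, and the c2w rotation takes the direction to world space. A standalone sanity check with hypothetical intrinsics, confirming the principal point looks straight down the camera's -z axis:

```python
import torch
import torch.nn.functional as F

# Hypothetical intrinsics: focal 800, principal point at (400, 400).
K = torch.tensor([[800.0, 0.0, 400.0], [0.0, 800.0, 400.0], [0.0, 0.0, 1.0]])
x = torch.tensor([399.5])  # pixel center (x + 0.5) lands on the principal point
y = torch.tensor([399.5])

camera_dirs = F.pad(
    torch.stack(
        [(x - K[0, 2] + 0.5) / K[0, 0], (y - K[1, 2] + 0.5) / K[1, 1]], dim=-1
    ),
    (0, 1),
    value=1,
)
camera_dirs[..., [1, 2]] *= -1  # OpenGL: y up, camera looks along -z
print(camera_dirs)  # tensor([[0., -0., -1.]])
```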
examples/datasets/utils.py (deleted, 100644 → 0)

The entire file was removed; its `Rays` namedtuple and `namedtuple_map` moved into `nerf_synthetic.py`. The deleted content:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
import collections
import math

import torch
import torch.nn.functional as F

Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
Cameras = collections.namedtuple(
    "Cameras", ("intrins", "extrins", "distorts", "width", "height")
)


def namedtuple_map(fn, tup):
    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
    return type(tup)(*(None if x is None else fn(x) for x in tup))


def homo(points: torch.Tensor) -> torch.Tensor:
    """Get the homogeneous coordinates."""
    return F.pad(points, (0, 1), value=1)


def transform_cameras(cameras: Cameras, resize_factor: float) -> torch.Tensor:
    intrins = cameras.intrins
    intrins[..., :2, :] = intrins[..., :2, :] * resize_factor
    width = int(cameras.width * resize_factor + 0.5)
    height = int(cameras.height * resize_factor + 0.5)
    return Cameras(
        intrins=intrins,
        extrins=cameras.extrins,
        distorts=cameras.distorts,
        width=width,
        height=height,
    )


def generate_rays(
    cameras: Cameras,
    opencv_format: bool = True,
    pixels_xy: torch.Tensor = None,
) -> Rays:
    """Generating rays for a single or multiple cameras.

    :params cameras [(n_cams,)]
    :returns: Rays
        [(n_cams,) height, width] if pixels_xy is None
        [(n_cams,) num_pixels] if pixels_xy is given
    """
    if pixels_xy is not None:
        K = cameras.intrins[..., None, :, :]
        c2w = cameras.extrins[..., None, :, :].inverse()
        x, y = pixels_xy[..., 0], pixels_xy[..., 1]
    else:
        K = cameras.intrins[..., None, None, :, :]
        c2w = cameras.extrins[..., None, None, :, :].inverse()
        x, y = torch.meshgrid(
            torch.arange(cameras.width, dtype=K.dtype),
            torch.arange(cameras.height, dtype=K.dtype),
            indexing="xy",
        )  # [height, width]

    camera_dirs = homo(
        torch.stack(
            [
                (x - K[..., 0, 2] + 0.5) / K[..., 0, 0],
                (y - K[..., 1, 2] + 0.5) / K[..., 1, 1],
            ],
            dim=-1,
        )
    )  # [n_cams, height, width, 3]
    if not opencv_format:
        camera_dirs[..., [1, 2]] *= -1

    # [n_cams, height, width, 3]
    directions = (camera_dirs[..., None, :] * c2w[..., :3, :3]).sum(dim=-1)
    origins = torch.broadcast_to(c2w[..., :3, -1], directions.shape)
    viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)

    rays = Rays(
        origins=origins,  # [n_cams, height, width, 3]
        viewdirs=viewdirs,  # [n_cams, height, width, 3]
    )
    return rays
```
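`namedtuple_map` survives the deletion (it now lives in `nerf_synthetic.py`, and `trainval.py` imports it from there, as the next diff shows). It applies a function fieldwise over a `Rays` tuple, which is how the examples move or slice a whole ray batch in one call. A small usage sketch:

```python
import collections

import torch

Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))


def namedtuple_map(fn, tup):
    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
    return type(tup)(*(None if x is None else fn(x) for x in tup))


rays = Rays(origins=torch.zeros(8, 3), viewdirs=torch.ones(8, 3))

# Slice the first half of every field at once:
half = namedtuple_map(lambda t: t[:4], rays)

# Or move every field to another device (assumes CUDA is available):
# rays_gpu = namedtuple_map(lambda t: t.to("cuda"), rays)
```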
examples/trainval.py

```diff
@@ -5,8 +5,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 import tqdm
-from datasets.nerf_synthetic import SubjectLoader
-from datasets.utils import namedtuple_map
+from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
 from radiance_fields.ngp import NGPradianceField

 from nerfacc import OccupancyField, volumetric_rendering
@@ -67,19 +66,26 @@ if __name__ == "__main__":
         split="train",
         num_rays=8192,
     )
+    # train_dataset.images = train_dataset.images.to(device)
+    # train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
+    # train_dataset.K = train_dataset.K.to(device)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
-        num_workers=1,
+        num_workers=4,
         batch_size=None,
         persistent_workers=True,
         shuffle=True,
     )
     test_dataset = SubjectLoader(
         subject_id="lego",
         root_fp="/home/ruilongli/data/nerf_synthetic/",
         split="test",
         num_rays=None,
     )
+    # test_dataset.images = test_dataset.images.to(device)
+    # test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
+    # test_dataset.K = test_dataset.K.to(device)
     test_dataloader = torch.utils.data.DataLoader(
         test_dataset,
         num_workers=4,
@@ -102,9 +108,9 @@ if __name__ == "__main__":
     )
     optimizer = torch.optim.Adam(
         radiance_field.parameters(), lr=3e-3, eps=1e-15
     )
-    # scheduler = torch.optim.lr_scheduler.MultiStepLR(
-    #     optimizer, milestones=[10000, 20000, 30000], gamma=0.1
-    # )
+    scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        optimizer, milestones=[20000, 30000], gamma=0.1
+    )

     # setup occupancy field with eval function
     def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
@@ -156,7 +162,7 @@ if __name__ == "__main__":
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-        # scheduler.step()
+        scheduler.step()
         if step % 50 == 0:
             elapsed_time = time.time() - tic
@@ -189,6 +195,18 @@ if __name__ == "__main__":
     # elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
     # evaluation: psnr_avg=33.305334663391115 (6.42 it/s)
+    # "train" batch_over_images=True
+    # elapsed_time=335.21s (data=68.99s) | step=30000 | loss=0.00028
+    # evaluation: psnr_avg=33.74970862388611 (6.23 it/s)
+    # "train" batch_over_images=True, schedule
+    # elapsed_time=296.30s (data=54.38s) | step=30000 | loss=0.00022
+    # evaluation: psnr_avg=34.3978275680542 (6.22 it/s)
     # "trainval"
     # elapsed_time=289.94s (data=51.99s) | step=30000 | loss=0.00021
     # evaluation: psnr_avg=34.44980221748352 (6.61 it/s)
+    # "trainval" batch_over_images=True, schedule
+    # elapsed_time=291.42s (data=52.82s) | step=30000 | loss=0.00020
+    # evaluation: psnr_avg=35.41630497932434 (6.40 it/s)
```
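The newly enabled schedule drops the learning rate by 10x at steps 20000 and 30000, which is where the `PSNR 33.75 -> 34.40` gain in the README tips comes from. A standalone check of what the optimizer's LR does over training; the dummy parameter is just scaffolding for the sketch:

```python
import torch

# Dummy parameter so the optimizer has something to hold.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=3e-3, eps=1e-15)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[20000, 30000], gamma=0.1
)

for step in range(1, 30001):
    optimizer.step()  # no-op here; gradients are never populated
    scheduler.step()
    if step in (19999, 20000, 30000):
        print(step, scheduler.get_last_lr())
# 19999 [0.003]   -> full LR until the first milestone
# 20000 [0.0003]  -> 10x decay at 20k
# 30000 [3e-05]   -> another 10x at 30k
```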