OpenDAS / nerfacc · Commits

Commit b5a2af68
authored Sep 09, 2022 by Ruilong Li

speedup data loading

parent bc7f7fff
Showing 5 changed files with 36 additions and 74 deletions:

README.md                              +2   -2
examples/datasets/base.py              +2   -7
examples/datasets/nerf_synthetic.py    +10  -25
examples/datasets/utils.py             +1   -27
examples/trainval.py                   +21  -13
README.md · View file @ b5a2af68

@@ -16,5 +16,5 @@ Tested with the default settings on the Lego test set.
 | - | - | - | - | - | - |
 | instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
 | torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
-| ours | train (30K steps) | 33.27 | 318 sec | 6.2 fps | TITAN RTX |
-| ours | trainval (30K steps) | 34.01 | 389 sec | 6.3 fps | TITAN RTX |
+| ours | train (30K steps) | 33.31 | 298 sec | 6.4 fps | TITAN RTX |
+| ours | trainval (30K steps) | 34.45 | 290 sec | 6.6 fps | TITAN RTX |
\ No newline at end of file
examples/datasets/base.py · View file @ b5a2af68

@@ -39,9 +39,8 @@ class CachedIterDataset(torch.utils.data.IterableDataset):
         iter_start = worker_id * per_worker
         iter_end = min(iter_start + per_worker, self.__len__())
         if self.training:
-            while True:
-                for index in iter_start + torch.randperm(iter_end - iter_start):
-                    yield self.__getitem__(index)
+            for index in iter_start + torch.randperm(iter_end - iter_start):
+                yield self.__getitem__(index)
         else:
             for index in range(iter_start, iter_end):
                 yield self.__getitem__(index)

@@ -59,7 +58,3 @@ class CachedIterDataset(torch.utils.data.IterableDataset):
             self._cache = data
             self._n_repeat = 1
         return self.preprocess(data)
-
-    @classmethod
-    def collate_fn(cls, batch):
-        return batch[0]
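For context, CachedIterDataset shards its index range across DataLoader workers via torch.utils.data.get_worker_info(); dropping the `while True:` turns the training iterator from an endless stream into one finite, shuffled pass per call, so epoch length is controlled by the outer loop in trainval.py instead. A minimal sketch of the same sharding pattern (simplified, not the repository's exact class):

import torch


class ShardedIterableDataset(torch.utils.data.IterableDataset):
    """Minimal sketch: split the index range [0, len) across DataLoader workers."""

    def __init__(self, num_items: int, training: bool = True):
        super().__init__()
        self.num_items = num_items
        self.training = training

    def __len__(self):
        return self.num_items

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:  # single-process loading: this worker owns everything
            iter_start, iter_end = 0, len(self)
        else:  # each worker gets a contiguous slice of the index range
            per_worker = (len(self) + info.num_workers - 1) // info.num_workers
            iter_start = info.id * per_worker
            iter_end = min(iter_start + per_worker, len(self))
        if self.training:
            # one shuffled pass per __iter__ call (no infinite loop)
            for index in iter_start + torch.randperm(iter_end - iter_start):
                yield int(index)
        else:
            for index in range(iter_start, iter_end):
                yield index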
examples/datasets/nerf_synthetic.py · View file @ b5a2af68

@@ -2,13 +2,11 @@
 import json
 import os
 
-import cv2
 import imageio.v2 as imageio
 import numpy as np
 import torch
 
-from .base import CachedIterDataset
-from .utils import Cameras, generate_rays
+from .utils import Cameras, generate_rays, transform_cameras
 
@@ -45,7 +43,7 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
     return images, camtoworlds, focal
 
 
-class SubjectLoader(CachedIterDataset):
+class SubjectLoader(torch.utils.data.Dataset):
     """Single subject data loader for training and evaluation."""
 
     SPLITS = ["train", "val", "trainval", "test"]

@@ -67,22 +65,20 @@ class SubjectLoader(CachedIterDataset):
         subject_id: str,
         root_fp: str,
         split: str,
-        resize_factor: float = 1.0,
         color_bkgd_aug: str = "white",
         num_rays: int = None,
-        cache_n_repeat: int = 0,
         near: float = None,
         far: float = None,
     ):
+        super().__init__()
         assert split in self.SPLITS, "%s" % split
         assert subject_id in self.SUBJECT_IDS, "%s" % subject_id
         assert color_bkgd_aug in ["white", "black", "random"]
-        self.resize_factor = resize_factor
        self.split = split
         self.num_rays = num_rays
         self.near = self.NEAR if near is None else near
         self.far = self.FAR if far is None else far
-        self.training = (num_rays is not None) and (split in ["train"])
+        self.training = (num_rays is not None) and (split in ["train", "trainval"])
         self.color_bkgd_aug = color_bkgd_aug
         if split == "trainval":
             _images_train, _camtoworlds_train, _focal_train = _load_renderings(

@@ -99,12 +95,15 @@ class SubjectLoader(CachedIterDataset):
             root_fp, subject_id, split
         )
         assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
-        super().__init__(self.training, cache_n_repeat)
 
     def __len__(self):
         return len(self.images)
 
+    # @profile
+    def __getitem__(self, index):
+        data = self.fetch_data(index)
+        data = self.preprocess(data)
+        return data
+
     def preprocess(self, data):
         """Process the fetched / cached data with randomness."""
         rgba, rays = data["rgba"], data["rays"]

@@ -144,18 +143,7 @@ class SubjectLoader(CachedIterDataset):
         rgba = self.images[camera_id]
 
         # create pixels
-        rgba = (
-            torch.from_numpy(
-                cv2.resize(
-                    rgba,
-                    (0, 0),
-                    fx=self.resize_factor,
-                    fy=self.resize_factor,
-                    interpolation=cv2.INTER_AREA,
-                )
-            ).float()
-            / 255.0
-        )
+        rgba = torch.from_numpy(rgba).float() / 255.0
 
         # create rays from camera
         cameras = Cameras(

@@ -165,7 +153,6 @@ class SubjectLoader(CachedIterDataset):
             width=self.WIDTH,
             height=self.HEIGHT,
         )
-        cameras = transform_cameras(cameras, self.resize_factor)
         if self.num_rays is not None:
             x = torch.randint(0, self.WIDTH, size=(self.num_rays,))

@@ -180,8 +167,6 @@ class SubjectLoader(CachedIterDataset):
         rays = generate_rays(
             cameras,
             opencv_format=False,
-            near=self.near,
-            far=self.far,
             pixels_xy=pixels_xy,
         )
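The net effect of this file's changes is that SubjectLoader becomes a plain map-style torch.utils.data.Dataset whose __getitem__ already returns a full ray batch (fetch_data followed by preprocess), so shuffling and multi-process prefetching are delegated to the DataLoader instead of the custom caching iterator. A minimal, hypothetical sketch of that pattern (the class name and bodies below are illustrative only, not the repository's loader):

import torch


class RayBatchDataset(torch.utils.data.Dataset):
    """Sketch: each item is a pre-sampled batch of rays/pixels for one image."""

    def __init__(self, images: torch.Tensor, num_rays: int = 8192):
        self.images = images  # [num_images, H, W, 4], uint8
        self.num_rays = num_rays

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        rgba = self.images[index].float() / 255.0  # no cv2 resize step any more
        h, w = rgba.shape[:2]
        # sample a random pixel batch; ray origins/directions would be derived
        # from the camera parameters here in the real loader.
        x = torch.randint(0, w, size=(self.num_rays,))
        y = torch.randint(0, h, size=(self.num_rays,))
        return {"pixels": rgba[y, x, :3], "xy": torch.stack([x, y], dim=-1)}


# The DataLoader handles shuffling and worker processes; batch_size=None keeps
# each __getitem__ result as-is, since it is already a batch:
# loader = torch.utils.data.DataLoader(
#     RayBatchDataset(images), num_workers=1, batch_size=None,
#     shuffle=True, persistent_workers=True,
# )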
examples/datasets/utils.py · View file @ b5a2af68

@@ -5,9 +5,7 @@ import math
 import torch
 import torch.nn.functional as F
 
-Rays = collections.namedtuple(
-    "Rays", ("origins", "directions", "viewdirs", "radii", "near", "far")
-)
+Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
 Cameras = collections.namedtuple(
     "Cameras", ("intrins", "extrins", "distorts", "width", "height")
 )

@@ -41,8 +39,6 @@ def transform_cameras(cameras: Cameras, resize_factor: float) -> torch.Tensor:
 def generate_rays(
     cameras: Cameras,
     opencv_format: bool = True,
-    near: float = None,
-    far: float = None,
     pixels_xy: torch.Tensor = None,
 ) -> Rays:
     """Generating rays for a single or multiple cameras.

@@ -82,30 +78,8 @@ def generate_rays(
     origins = torch.broadcast_to(c2w[..., :3, -1], directions.shape)
     viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)
 
-    if pixels_xy is None:
-        # Distance from each unit-norm direction vector to its x-axis neighbor.
-        dx = torch.sqrt(
-            torch.sum(
-                (directions[..., :-1, :, :] - directions[..., 1:, :, :]) ** 2,
-                dim=-1,
-            )
-        )
-        dx = torch.cat([dx, dx[..., -2:-1, :]], dim=-2)
-        radii = dx[..., None] * 2 / math.sqrt(12)  # [n_cams, height, width, 1]
-    else:
-        radii = None
-    if near is not None:
-        near = near * torch.ones_like(origins[..., 0:1])
-    if far is not None:
-        far = far * torch.ones_like(origins[..., 0:1])
-
     rays = Rays(
         origins=origins,  # [n_cams, height, width, 3]
-        directions=directions,  # [n_cams, height, width, 3]
         viewdirs=viewdirs,  # [n_cams, height, width, 3]
-        radii=radii,  # [n_cams, height, width, 1]
-        # near far is not needed when they are estimated by skeleton.
-        near=near,
-        far=far,
     )
     return rays
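After this change, Rays carries only origins and unit-norm viewdirs; per-ray radii and near/far planes are no longer produced in generate_rays. A minimal, self-contained sketch of pinhole ray generation with that slimmer tuple (the intrinsics layout and OpenGL-style convention are assumptions, not the repository's exact generate_rays):

import collections

import torch

Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))


def make_rays(K: torch.Tensor, c2w: torch.Tensor, width: int, height: int) -> Rays:
    """Sketch: rays for one pinhole camera (y up, camera looks down -z)."""
    x, y = torch.meshgrid(
        torch.arange(width, dtype=torch.float32),
        torch.arange(height, dtype=torch.float32),
        indexing="xy",
    )
    # camera-space directions from pixel centers and intrinsics K (3x3)
    dirs = torch.stack(
        [
            (x + 0.5 - K[0, 2]) / K[0, 0],
            -(y + 0.5 - K[1, 2]) / K[1, 1],
            -torch.ones_like(x),
        ],
        dim=-1,
    )  # [H, W, 3]
    # rotate into world space and normalize
    directions = (dirs[..., None, :] * c2w[:3, :3]).sum(dim=-1)  # [H, W, 3]
    viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdim=True)
    origins = torch.broadcast_to(c2w[:3, -1], directions.shape)
    return Rays(origins=origins, viewdirs=viewdirs)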
examples/trainval.py · View file @ b5a2af68

@@ -64,14 +64,15 @@ if __name__ == "__main__":
     train_dataset = SubjectLoader(
         subject_id="lego",
         root_fp="/home/ruilongli/data/nerf_synthetic/",
-        split="val",
+        split="train",
         num_rays=8192,
     )
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
-        num_workers=10,
-        batch_size=1,
-        collate_fn=getattr(train_dataset.__class__, "collate_fn"),
+        num_workers=1,
+        batch_size=None,
+        persistent_workers=True,
+        shuffle=True,
     )
     test_dataset = SubjectLoader(
         subject_id="lego",

@@ -81,9 +82,8 @@ if __name__ == "__main__":
     )
     test_dataloader = torch.utils.data.DataLoader(
         test_dataset,
-        num_workers=10,
-        batch_size=1,
-        collate_fn=getattr(train_dataset.__class__, "collate_fn"),
+        num_workers=4,
+        batch_size=None,
     )
 
     # setup the scene bounding box.

@@ -102,6 +102,9 @@ if __name__ == "__main__":
     )
     optimizer = torch.optim.Adam(radiance_field.parameters(), lr=3e-3, eps=1e-15)
+    # scheduler = torch.optim.lr_scheduler.MultiStepLR(
+    #     optimizer, milestones=[10000, 20000, 30000], gamma=0.1
+    # )
 
     # setup occupancy field with eval function
     def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:

@@ -125,8 +128,11 @@ if __name__ == "__main__":
     # training
     step = 0
     tic = time.time()
-    for epoch in range(200):
+    data_time = 0
+    tic_data = time.time()
+    for epoch in range(300):
         for data in train_dataloader:
+            data_time += time.time() - tic_data
             step += 1
             if step > 30_000:
                 print("training stops")

@@ -150,11 +156,12 @@ if __name__ == "__main__":
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
+            # scheduler.step()
 
             if step % 50 == 0:
                 elapsed_time = time.time() - tic
                 print(
-                    f"elapsed_time={elapsed_time:.2f}s | {step=} | loss={loss.item():.5f}"
+                    f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | loss={loss:.5f}"
                 )
             if step % 30_000 == 0 and step > 0:

@@ -176,11 +183,12 @@ if __name__ == "__main__":
                 psnrs.append(psnr.item())
                 psnr_avg = sum(psnrs) / len(psnrs)
                 print(f"evaluation: {psnr_avg=}")
+            tic_data = time.time()
 
 # "train"
-# elapsed_time=317.59s | step=30000 | loss=0.00028
-# evaluation: psnr_avg=33.27096959114075 (6.24 it/s)
+# elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
+# evaluation: psnr_avg=33.305334663391115 (6.42 it/s)
 
 # "trainval"
-# elapsed_time=389.08s | step=30000 | loss=0.00030
-# evaluation: psnr_avg=34.00573859214783 (6.26 it/s)
+# elapsed_time=289.94s (data=51.99s) | step=30000 | loss=0.00021
+# evaluation: psnr_avg=34.44980221748352 (6.61 it/s)
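For reference, the training-loop changes amount to standard DataLoader wiring plus a lightweight timer around data fetching, which is what the new "(data=...)" figure in the log line reports. A hypothetical, self-contained sketch of that instrumentation pattern (the loop body is a placeholder, not the repository's training step):

import time

import torch


def train(dataset, max_steps: int = 30_000):
    loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=1,
        batch_size=None,          # items are already full ray batches
        shuffle=True,
        persistent_workers=True,  # keep workers alive between epochs
    )
    step, data_time = 0, 0.0
    tic = time.time()
    tic_data = time.time()
    for epoch in range(300):
        for data in loader:
            data_time += time.time() - tic_data  # time spent waiting on data
            step += 1
            # ... forward / backward / optimizer.step() would go here ...
            if step % 50 == 0:
                elapsed = time.time() - tic
                print(f"elapsed={elapsed:.2f}s (data={data_time:.2f}s) | {step=}")
            if step >= max_steps:
                return
            tic_data = time.time()  # restart the data-wait timer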