Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
b5a2af68
"torchvision/vscode:/vscode.git/clone" did not exist on "08743385d93a0aa4145da6f6db49bfa00df148a3"
Commit
b5a2af68
authored
Sep 09, 2022
by
Ruilong Li
Browse files
speedup data loading
parent
bc7f7fff
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
36 additions
and
74 deletions
+36
-74
README.md
README.md
+2
-2
examples/datasets/base.py
examples/datasets/base.py
+2
-7
examples/datasets/nerf_synthetic.py
examples/datasets/nerf_synthetic.py
+10
-25
examples/datasets/utils.py
examples/datasets/utils.py
+1
-27
examples/trainval.py
examples/trainval.py
+21
-13
No files found.
README.md
View file @
b5a2af68
...
...
@@ -16,5 +16,5 @@ Tested with the default settings on the Lego test set.
| - | - | - | - | - | - |
| instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
| torch-ngp (
`-O`
) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
| ours | train (30K steps) | 33.27 | 318 sec | 6.2 fps | TITAN RTX |
| ours | trainval (30K steps) | 34.01 | 389 sec | 6.3 fps | TITAN RTX |
\ No newline at end of file
| ours | train (30K steps) | 33.31 | 298 sec | 6.4 fps | TITAN RTX |
| ours | trainval (30K steps) | 34.45 | 290 sec | 6.6 fps | TITAN RTX |
\ No newline at end of file
examples/datasets/base.py
View file @
b5a2af68
...
...
@@ -39,7 +39,6 @@ class CachedIterDataset(torch.utils.data.IterableDataset):
iter_start
=
worker_id
*
per_worker
iter_end
=
min
(
iter_start
+
per_worker
,
self
.
__len__
())
if
self
.
training
:
while
True
:
for
index
in
iter_start
+
torch
.
randperm
(
iter_end
-
iter_start
):
yield
self
.
__getitem__
(
index
)
else
:
...
...
@@ -59,7 +58,3 @@ class CachedIterDataset(torch.utils.data.IterableDataset):
self
.
_cache
=
data
self
.
_n_repeat
=
1
return
self
.
preprocess
(
data
)
@
classmethod
def
collate_fn
(
cls
,
batch
):
return
batch
[
0
]
examples/datasets/nerf_synthetic.py
View file @
b5a2af68
...
...
@@ -2,13 +2,11 @@
import
json
import
os
import
cv2
import
imageio.v2
as
imageio
import
numpy
as
np
import
torch
from
.base
import
CachedIterDataset
from
.utils
import
Cameras
,
generate_rays
,
transform_cameras
from
.utils
import
Cameras
,
generate_rays
def
_load_renderings
(
root_fp
:
str
,
subject_id
:
str
,
split
:
str
):
...
...
@@ -45,7 +43,7 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
return
images
,
camtoworlds
,
focal
class
SubjectLoader
(
CachedIter
Dataset
):
class
SubjectLoader
(
torch
.
utils
.
data
.
Dataset
):
"""Single subject data loader for training and evaluation."""
SPLITS
=
[
"train"
,
"val"
,
"trainval"
,
"test"
]
...
...
@@ -67,22 +65,20 @@ class SubjectLoader(CachedIterDataset):
subject_id
:
str
,
root_fp
:
str
,
split
:
str
,
resize_factor
:
float
=
1.0
,
color_bkgd_aug
:
str
=
"white"
,
num_rays
:
int
=
None
,
cache_n_repeat
:
int
=
0
,
near
:
float
=
None
,
far
:
float
=
None
,
):
super
().
__init__
()
assert
split
in
self
.
SPLITS
,
"%s"
%
split
assert
subject_id
in
self
.
SUBJECT_IDS
,
"%s"
%
subject_id
assert
color_bkgd_aug
in
[
"white"
,
"black"
,
"random"
]
self
.
resize_factor
=
resize_factor
self
.
split
=
split
self
.
num_rays
=
num_rays
self
.
near
=
self
.
NEAR
if
near
is
None
else
near
self
.
far
=
self
.
FAR
if
far
is
None
else
far
self
.
training
=
(
num_rays
is
not
None
)
and
(
split
in
[
"train"
])
self
.
training
=
(
num_rays
is
not
None
)
and
(
split
in
[
"train"
,
"trainval"
])
self
.
color_bkgd_aug
=
color_bkgd_aug
if
split
==
"trainval"
:
_images_train
,
_camtoworlds_train
,
_focal_train
=
_load_renderings
(
...
...
@@ -99,12 +95,15 @@ class SubjectLoader(CachedIterDataset):
root_fp
,
subject_id
,
split
)
assert
self
.
images
.
shape
[
1
:
3
]
==
(
self
.
HEIGHT
,
self
.
WIDTH
)
super
().
__init__
(
self
.
training
,
cache_n_repeat
)
def
__len__
(
self
):
return
len
(
self
.
images
)
# @profile
def
__getitem__
(
self
,
index
):
data
=
self
.
fetch_data
(
index
)
data
=
self
.
preprocess
(
data
)
return
data
def
preprocess
(
self
,
data
):
"""Process the fetched / cached data with randomness."""
rgba
,
rays
=
data
[
"rgba"
],
data
[
"rays"
]
...
...
@@ -144,18 +143,7 @@ class SubjectLoader(CachedIterDataset):
rgba
=
self
.
images
[
camera_id
]
# create pixels
rgba
=
(
torch
.
from_numpy
(
cv2
.
resize
(
rgba
,
(
0
,
0
),
fx
=
self
.
resize_factor
,
fy
=
self
.
resize_factor
,
interpolation
=
cv2
.
INTER_AREA
,
)
).
float
()
/
255.0
)
rgba
=
torch
.
from_numpy
(
rgba
).
float
()
/
255.0
# create rays from camera
cameras
=
Cameras
(
...
...
@@ -165,7 +153,6 @@ class SubjectLoader(CachedIterDataset):
width
=
self
.
WIDTH
,
height
=
self
.
HEIGHT
,
)
cameras
=
transform_cameras
(
cameras
,
self
.
resize_factor
)
if
self
.
num_rays
is
not
None
:
x
=
torch
.
randint
(
0
,
self
.
WIDTH
,
size
=
(
self
.
num_rays
,))
...
...
@@ -180,8 +167,6 @@ class SubjectLoader(CachedIterDataset):
rays
=
generate_rays
(
cameras
,
opencv_format
=
False
,
near
=
self
.
near
,
far
=
self
.
far
,
pixels_xy
=
pixels_xy
,
)
...
...
examples/datasets/utils.py
View file @
b5a2af68
...
...
@@ -5,9 +5,7 @@ import math
import
torch
import
torch.nn.functional
as
F
Rays
=
collections
.
namedtuple
(
"Rays"
,
(
"origins"
,
"directions"
,
"viewdirs"
,
"radii"
,
"near"
,
"far"
)
)
Rays
=
collections
.
namedtuple
(
"Rays"
,
(
"origins"
,
"viewdirs"
))
Cameras
=
collections
.
namedtuple
(
"Cameras"
,
(
"intrins"
,
"extrins"
,
"distorts"
,
"width"
,
"height"
)
...
...
@@ -41,8 +39,6 @@ def transform_cameras(cameras: Cameras, resize_factor: float) -> torch.Tensor:
def
generate_rays
(
cameras
:
Cameras
,
opencv_format
:
bool
=
True
,
near
:
float
=
None
,
far
:
float
=
None
,
pixels_xy
:
torch
.
Tensor
=
None
,
)
->
Rays
:
"""Generating rays for a single or multiple cameras.
...
...
@@ -82,30 +78,8 @@ def generate_rays(
origins
=
torch
.
broadcast_to
(
c2w
[...,
:
3
,
-
1
],
directions
.
shape
)
viewdirs
=
directions
/
torch
.
linalg
.
norm
(
directions
,
dim
=-
1
,
keepdims
=
True
)
if
pixels_xy
is
None
:
# Distance from each unit-norm direction vector to its x-axis neighbor.
dx
=
torch
.
sqrt
(
torch
.
sum
(
(
directions
[...,
:
-
1
,
:,
:]
-
directions
[...,
1
:,
:,
:])
**
2
,
dim
=-
1
,
)
)
dx
=
torch
.
cat
([
dx
,
dx
[...,
-
2
:
-
1
,
:]],
dim
=-
2
)
radii
=
dx
[...,
None
]
*
2
/
math
.
sqrt
(
12
)
# [n_cams, height, width, 1]
else
:
radii
=
None
if
near
is
not
None
:
near
=
near
*
torch
.
ones_like
(
origins
[...,
0
:
1
])
if
far
is
not
None
:
far
=
far
*
torch
.
ones_like
(
origins
[...,
0
:
1
])
rays
=
Rays
(
origins
=
origins
,
# [n_cams, height, width, 3]
directions
=
directions
,
# [n_cams, height, width, 3]
viewdirs
=
viewdirs
,
# [n_cams, height, width, 3]
radii
=
radii
,
# [n_cams, height, width, 1]
# near far is not needed when they are estimated by skeleton.
near
=
near
,
far
=
far
,
)
return
rays
examples/trainval.py
View file @
b5a2af68
...
...
@@ -64,14 +64,15 @@ if __name__ == "__main__":
train_dataset
=
SubjectLoader
(
subject_id
=
"lego"
,
root_fp
=
"/home/ruilongli/data/nerf_synthetic/"
,
split
=
"
val
"
,
split
=
"
train
"
,
num_rays
=
8192
,
)
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
train_dataset
,
num_workers
=
10
,
batch_size
=
1
,
collate_fn
=
getattr
(
train_dataset
.
__class__
,
"collate_fn"
),
num_workers
=
1
,
batch_size
=
None
,
persistent_workers
=
True
,
shuffle
=
True
,
)
test_dataset
=
SubjectLoader
(
subject_id
=
"lego"
,
...
...
@@ -81,9 +82,8 @@ if __name__ == "__main__":
)
test_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
test_dataset
,
num_workers
=
10
,
batch_size
=
1
,
collate_fn
=
getattr
(
train_dataset
.
__class__
,
"collate_fn"
),
num_workers
=
4
,
batch_size
=
None
,
)
# setup the scene bounding box.
...
...
@@ -102,6 +102,9 @@ if __name__ == "__main__":
)
optimizer
=
torch
.
optim
.
Adam
(
radiance_field
.
parameters
(),
lr
=
3e-3
,
eps
=
1e-15
)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(
# optimizer, milestones=[10000, 20000, 30000], gamma=0.1
# )
# setup occupancy field with eval function
def
occ_eval_fn
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -125,8 +128,11 @@ if __name__ == "__main__":
# training
step
=
0
tic
=
time
.
time
()
for
epoch
in
range
(
200
):
data_time
=
0
tic_data
=
time
.
time
()
for
epoch
in
range
(
300
):
for
data
in
train_dataloader
:
data_time
+=
time
.
time
()
-
tic_data
step
+=
1
if
step
>
30_000
:
print
(
"training stops"
)
...
...
@@ -150,11 +156,12 @@ if __name__ == "__main__":
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
# scheduler.step()
if
step
%
50
==
0
:
elapsed_time
=
time
.
time
()
-
tic
print
(
f
"elapsed_time=
{
elapsed_time
:.
2
f
}
s |
{
step
=
}
| loss=
{
loss
.
item
():
.
5
f
}
"
f
"elapsed_time=
{
elapsed_time
:.
2
f
}
s
(data=
{
data_time
:.
2
f
}
s)
|
{
step
=
}
| loss=
{
loss
:
.
5
f
}
"
)
if
step
%
30_000
==
0
and
step
>
0
:
...
...
@@ -176,11 +183,12 @@ if __name__ == "__main__":
psnrs
.
append
(
psnr
.
item
())
psnr_avg
=
sum
(
psnrs
)
/
len
(
psnrs
)
print
(
f
"evaluation:
{
psnr_avg
=
}
"
)
tic_data
=
time
.
time
()
# "train"
# elapsed_time=
317.59s
| step=30000 | loss=
0.0002
8
# evaluation: psnr_avg=33.
2709695911407
5 (6.
2
4 it/s)
# elapsed_time=
298.27s (data=60.08s)
| step=30000 | loss=0.0002
6
# evaluation: psnr_avg=33.
30533466339111
5 (6.4
2
it/s)
# "trainval"
# elapsed_time=
3
89.
08s
| step=30000 | loss=
0.000
30
# evaluation: psnr_avg=34.
00573859214783
(6.
2
6 it/s)
# elapsed_time=
2
89.
94s (data=51.99s)
| step=30000 | loss=0.000
21
# evaluation: psnr_avg=34.
44980221748352
(6.6
1
it/s)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment