Commit b5a2af68 authored by Ruilong Li
Browse files

speedup data loading

parent bc7f7fff
...@@ -16,5 +16,5 @@ Tested with the default settings on the Lego test set. ...@@ -16,5 +16,5 @@ Tested with the default settings on the Lego test set.
| - | - | - | - | - | - | | - | - | - | - | - | - |
| instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 | | instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
| torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 | | torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
| ours | train (30K steps) | 33.27 | 318 sec | 6.2 fps | TITAN RTX | | ours | train (30K steps) | 33.31 | 298 sec | 6.4 fps | TITAN RTX |
| ours | trainval (30K steps) | 34.01 | 389 sec | 6.3 fps | TITAN RTX | | ours | trainval (30K steps) | 34.45 | 290 sec | 6.6 fps | TITAN RTX |
\ No newline at end of file \ No newline at end of file
...@@ -39,9 +39,8 @@ class CachedIterDataset(torch.utils.data.IterableDataset): ...@@ -39,9 +39,8 @@ class CachedIterDataset(torch.utils.data.IterableDataset):
iter_start = worker_id * per_worker iter_start = worker_id * per_worker
iter_end = min(iter_start + per_worker, self.__len__()) iter_end = min(iter_start + per_worker, self.__len__())
if self.training: if self.training:
while True: for index in iter_start + torch.randperm(iter_end - iter_start):
for index in iter_start + torch.randperm(iter_end - iter_start): yield self.__getitem__(index)
yield self.__getitem__(index)
else: else:
for index in range(iter_start, iter_end): for index in range(iter_start, iter_end):
yield self.__getitem__(index) yield self.__getitem__(index)
...@@ -59,7 +58,3 @@ class CachedIterDataset(torch.utils.data.IterableDataset): ...@@ -59,7 +58,3 @@ class CachedIterDataset(torch.utils.data.IterableDataset):
self._cache = data self._cache = data
self._n_repeat = 1 self._n_repeat = 1
return self.preprocess(data) return self.preprocess(data)
@classmethod
def collate_fn(cls, batch):
return batch[0]
...@@ -2,13 +2,11 @@ ...@@ -2,13 +2,11 @@
import json import json
import os import os
import cv2
import imageio.v2 as imageio import imageio.v2 as imageio
import numpy as np import numpy as np
import torch import torch
from .base import CachedIterDataset from .utils import Cameras, generate_rays
from .utils import Cameras, generate_rays, transform_cameras
def _load_renderings(root_fp: str, subject_id: str, split: str): def _load_renderings(root_fp: str, subject_id: str, split: str):
...@@ -45,7 +43,7 @@ def _load_renderings(root_fp: str, subject_id: str, split: str): ...@@ -45,7 +43,7 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
return images, camtoworlds, focal return images, camtoworlds, focal
class SubjectLoader(CachedIterDataset): class SubjectLoader(torch.utils.data.Dataset):
"""Single subject data loader for training and evaluation.""" """Single subject data loader for training and evaluation."""
SPLITS = ["train", "val", "trainval", "test"] SPLITS = ["train", "val", "trainval", "test"]
...@@ -67,22 +65,20 @@ class SubjectLoader(CachedIterDataset): ...@@ -67,22 +65,20 @@ class SubjectLoader(CachedIterDataset):
subject_id: str, subject_id: str,
root_fp: str, root_fp: str,
split: str, split: str,
resize_factor: float = 1.0,
color_bkgd_aug: str = "white", color_bkgd_aug: str = "white",
num_rays: int = None, num_rays: int = None,
cache_n_repeat: int = 0,
near: float = None, near: float = None,
far: float = None, far: float = None,
): ):
super().__init__()
assert split in self.SPLITS, "%s" % split assert split in self.SPLITS, "%s" % split
assert subject_id in self.SUBJECT_IDS, "%s" % subject_id assert subject_id in self.SUBJECT_IDS, "%s" % subject_id
assert color_bkgd_aug in ["white", "black", "random"] assert color_bkgd_aug in ["white", "black", "random"]
self.resize_factor = resize_factor
self.split = split self.split = split
self.num_rays = num_rays self.num_rays = num_rays
self.near = self.NEAR if near is None else near self.near = self.NEAR if near is None else near
self.far = self.FAR if far is None else far self.far = self.FAR if far is None else far
self.training = (num_rays is not None) and (split in ["train"]) self.training = (num_rays is not None) and (split in ["train", "trainval"])
self.color_bkgd_aug = color_bkgd_aug self.color_bkgd_aug = color_bkgd_aug
if split == "trainval": if split == "trainval":
_images_train, _camtoworlds_train, _focal_train = _load_renderings( _images_train, _camtoworlds_train, _focal_train = _load_renderings(
...@@ -99,12 +95,15 @@ class SubjectLoader(CachedIterDataset): ...@@ -99,12 +95,15 @@ class SubjectLoader(CachedIterDataset):
root_fp, subject_id, split root_fp, subject_id, split
) )
assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH) assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
super().__init__(self.training, cache_n_repeat)
def __len__(self): def __len__(self):
return len(self.images) return len(self.images)
# @profile def __getitem__(self, index):
data = self.fetch_data(index)
data = self.preprocess(data)
return data
def preprocess(self, data): def preprocess(self, data):
"""Process the fetched / cached data with randomness.""" """Process the fetched / cached data with randomness."""
rgba, rays = data["rgba"], data["rays"] rgba, rays = data["rgba"], data["rays"]
...@@ -144,18 +143,7 @@ class SubjectLoader(CachedIterDataset): ...@@ -144,18 +143,7 @@ class SubjectLoader(CachedIterDataset):
rgba = self.images[camera_id] rgba = self.images[camera_id]
# create pixels # create pixels
rgba = ( rgba = torch.from_numpy(rgba).float() / 255.0
torch.from_numpy(
cv2.resize(
rgba,
(0, 0),
fx=self.resize_factor,
fy=self.resize_factor,
interpolation=cv2.INTER_AREA,
)
).float()
/ 255.0
)
# create rays from camera # create rays from camera
cameras = Cameras( cameras = Cameras(
...@@ -165,7 +153,6 @@ class SubjectLoader(CachedIterDataset): ...@@ -165,7 +153,6 @@ class SubjectLoader(CachedIterDataset):
width=self.WIDTH, width=self.WIDTH,
height=self.HEIGHT, height=self.HEIGHT,
) )
cameras = transform_cameras(cameras, self.resize_factor)
if self.num_rays is not None: if self.num_rays is not None:
x = torch.randint(0, self.WIDTH, size=(self.num_rays,)) x = torch.randint(0, self.WIDTH, size=(self.num_rays,))
...@@ -180,8 +167,6 @@ class SubjectLoader(CachedIterDataset): ...@@ -180,8 +167,6 @@ class SubjectLoader(CachedIterDataset):
rays = generate_rays( rays = generate_rays(
cameras, cameras,
opencv_format=False, opencv_format=False,
near=self.near,
far=self.far,
pixels_xy=pixels_xy, pixels_xy=pixels_xy,
) )
......
...@@ -5,9 +5,7 @@ import math ...@@ -5,9 +5,7 @@ import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
Rays = collections.namedtuple( Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
"Rays", ("origins", "directions", "viewdirs", "radii", "near", "far")
)
Cameras = collections.namedtuple( Cameras = collections.namedtuple(
"Cameras", ("intrins", "extrins", "distorts", "width", "height") "Cameras", ("intrins", "extrins", "distorts", "width", "height")
...@@ -41,8 +39,6 @@ def transform_cameras(cameras: Cameras, resize_factor: float) -> torch.Tensor: ...@@ -41,8 +39,6 @@ def transform_cameras(cameras: Cameras, resize_factor: float) -> torch.Tensor:
def generate_rays( def generate_rays(
cameras: Cameras, cameras: Cameras,
opencv_format: bool = True, opencv_format: bool = True,
near: float = None,
far: float = None,
pixels_xy: torch.Tensor = None, pixels_xy: torch.Tensor = None,
) -> Rays: ) -> Rays:
"""Generating rays for a single or multiple cameras. """Generating rays for a single or multiple cameras.
...@@ -82,30 +78,8 @@ def generate_rays( ...@@ -82,30 +78,8 @@ def generate_rays(
origins = torch.broadcast_to(c2w[..., :3, -1], directions.shape) origins = torch.broadcast_to(c2w[..., :3, -1], directions.shape)
viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True) viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)
if pixels_xy is None:
# Distance from each unit-norm direction vector to its x-axis neighbor.
dx = torch.sqrt(
torch.sum(
(directions[..., :-1, :, :] - directions[..., 1:, :, :]) ** 2,
dim=-1,
)
)
dx = torch.cat([dx, dx[..., -2:-1, :]], dim=-2)
radii = dx[..., None] * 2 / math.sqrt(12) # [n_cams, height, width, 1]
else:
radii = None
if near is not None:
near = near * torch.ones_like(origins[..., 0:1])
if far is not None:
far = far * torch.ones_like(origins[..., 0:1])
rays = Rays( rays = Rays(
origins=origins, # [n_cams, height, width, 3] origins=origins, # [n_cams, height, width, 3]
directions=directions, # [n_cams, height, width, 3]
viewdirs=viewdirs, # [n_cams, height, width, 3] viewdirs=viewdirs, # [n_cams, height, width, 3]
radii=radii, # [n_cams, height, width, 1]
# near far is not needed when they are estimated by skeleton.
near=near,
far=far,
) )
return rays return rays
...@@ -64,14 +64,15 @@ if __name__ == "__main__": ...@@ -64,14 +64,15 @@ if __name__ == "__main__":
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
subject_id="lego", subject_id="lego",
root_fp="/home/ruilongli/data/nerf_synthetic/", root_fp="/home/ruilongli/data/nerf_synthetic/",
split="val", split="train",
num_rays=8192, num_rays=8192,
) )
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_dataset, train_dataset,
num_workers=10, num_workers=1,
batch_size=1, batch_size=None,
collate_fn=getattr(train_dataset.__class__, "collate_fn"), persistent_workers=True,
shuffle=True,
) )
test_dataset = SubjectLoader( test_dataset = SubjectLoader(
subject_id="lego", subject_id="lego",
...@@ -81,9 +82,8 @@ if __name__ == "__main__": ...@@ -81,9 +82,8 @@ if __name__ == "__main__":
) )
test_dataloader = torch.utils.data.DataLoader( test_dataloader = torch.utils.data.DataLoader(
test_dataset, test_dataset,
num_workers=10, num_workers=4,
batch_size=1, batch_size=None,
collate_fn=getattr(train_dataset.__class__, "collate_fn"),
) )
# setup the scene bounding box. # setup the scene bounding box.
...@@ -102,6 +102,9 @@ if __name__ == "__main__": ...@@ -102,6 +102,9 @@ if __name__ == "__main__":
) )
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=3e-3, eps=1e-15) optimizer = torch.optim.Adam(radiance_field.parameters(), lr=3e-3, eps=1e-15)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(
# optimizer, milestones=[10000, 20000, 30000], gamma=0.1
# )
# setup occupancy field with eval function # setup occupancy field with eval function
def occ_eval_fn(x: torch.Tensor) -> torch.Tensor: def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
...@@ -125,8 +128,11 @@ if __name__ == "__main__": ...@@ -125,8 +128,11 @@ if __name__ == "__main__":
# training # training
step = 0 step = 0
tic = time.time() tic = time.time()
for epoch in range(200): data_time = 0
tic_data = time.time()
for epoch in range(300):
for data in train_dataloader: for data in train_dataloader:
data_time += time.time() - tic_data
step += 1 step += 1
if step > 30_000: if step > 30_000:
print("training stops") print("training stops")
...@@ -150,11 +156,12 @@ if __name__ == "__main__": ...@@ -150,11 +156,12 @@ if __name__ == "__main__":
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# scheduler.step()
if step % 50 == 0: if step % 50 == 0:
elapsed_time = time.time() - tic elapsed_time = time.time() - tic
print( print(
f"elapsed_time={elapsed_time:.2f}s | {step=} | loss={loss.item(): .5f}" f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | loss={loss:.5f}"
) )
if step % 30_000 == 0 and step > 0: if step % 30_000 == 0 and step > 0:
...@@ -176,11 +183,12 @@ if __name__ == "__main__": ...@@ -176,11 +183,12 @@ if __name__ == "__main__":
psnrs.append(psnr.item()) psnrs.append(psnr.item())
psnr_avg = sum(psnrs) / len(psnrs) psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: {psnr_avg=}") print(f"evaluation: {psnr_avg=}")
tic_data = time.time()
# "train" # "train"
# elapsed_time=317.59s | step=30000 | loss= 0.00028 # elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
# evaluation: psnr_avg=33.27096959114075 (6.24 it/s) # evaluation: psnr_avg=33.305334663391115 (6.42 it/s)
# "trainval" # "trainval"
# elapsed_time=389.08s | step=30000 | loss= 0.00030 # elapsed_time=289.94s (data=51.99s) | step=30000 | loss=0.00021
# evaluation: psnr_avg=34.00573859214783 (6.26 it/s) # evaluation: psnr_avg=34.44980221748352 (6.61 it/s)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment