"torchvision/vscode:/vscode.git/clone" did not exist on "08743385d93a0aa4145da6f6db49bfa00df148a3"
Commit b5a2af68 authored by Ruilong Li
Browse files

speedup data loading

parent bc7f7fff
......@@ -16,5 +16,5 @@ Tested with the default settings on the Lego test set.
| - | - | - | - | - | - |
| instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
| torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
| ours | train (30K steps) | 33.27 | 318 sec | 6.2 fps | TITAN RTX |
| ours | trainval (30K steps) | 34.01 | 389 sec | 6.3 fps | TITAN RTX |
\ No newline at end of file
| ours | train (30K steps) | 33.31 | 298 sec | 6.4 fps | TITAN RTX |
| ours | trainval (30K steps) | 34.45 | 290 sec | 6.6 fps | TITAN RTX |
\ No newline at end of file
......@@ -39,7 +39,6 @@ class CachedIterDataset(torch.utils.data.IterableDataset):
iter_start = worker_id * per_worker
iter_end = min(iter_start + per_worker, self.__len__())
if self.training:
while True:
for index in iter_start + torch.randperm(iter_end - iter_start):
yield self.__getitem__(index)
else:
......@@ -59,7 +58,3 @@ class CachedIterDataset(torch.utils.data.IterableDataset):
self._cache = data
self._n_repeat = 1
return self.preprocess(data)
@classmethod
def collate_fn(cls, batch):
return batch[0]
......@@ -2,13 +2,11 @@
import json
import os
import cv2
import imageio.v2 as imageio
import numpy as np
import torch
from .base import CachedIterDataset
from .utils import Cameras, generate_rays, transform_cameras
from .utils import Cameras, generate_rays
def _load_renderings(root_fp: str, subject_id: str, split: str):
......@@ -45,7 +43,7 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
return images, camtoworlds, focal
class SubjectLoader(CachedIterDataset):
class SubjectLoader(torch.utils.data.Dataset):
"""Single subject data loader for training and evaluation."""
SPLITS = ["train", "val", "trainval", "test"]
......@@ -67,22 +65,20 @@ class SubjectLoader(CachedIterDataset):
subject_id: str,
root_fp: str,
split: str,
resize_factor: float = 1.0,
color_bkgd_aug: str = "white",
num_rays: int = None,
cache_n_repeat: int = 0,
near: float = None,
far: float = None,
):
super().__init__()
assert split in self.SPLITS, "%s" % split
assert subject_id in self.SUBJECT_IDS, "%s" % subject_id
assert color_bkgd_aug in ["white", "black", "random"]
self.resize_factor = resize_factor
self.split = split
self.num_rays = num_rays
self.near = self.NEAR if near is None else near
self.far = self.FAR if far is None else far
self.training = (num_rays is not None) and (split in ["train"])
self.training = (num_rays is not None) and (split in ["train", "trainval"])
self.color_bkgd_aug = color_bkgd_aug
if split == "trainval":
_images_train, _camtoworlds_train, _focal_train = _load_renderings(
......@@ -99,12 +95,15 @@ class SubjectLoader(CachedIterDataset):
root_fp, subject_id, split
)
assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
super().__init__(self.training, cache_n_repeat)
def __len__(self):
return len(self.images)
# @profile
def __getitem__(self, index):
data = self.fetch_data(index)
data = self.preprocess(data)
return data
def preprocess(self, data):
"""Process the fetched / cached data with randomness."""
rgba, rays = data["rgba"], data["rays"]
......@@ -144,18 +143,7 @@ class SubjectLoader(CachedIterDataset):
rgba = self.images[camera_id]
# create pixels
rgba = (
torch.from_numpy(
cv2.resize(
rgba,
(0, 0),
fx=self.resize_factor,
fy=self.resize_factor,
interpolation=cv2.INTER_AREA,
)
).float()
/ 255.0
)
rgba = torch.from_numpy(rgba).float() / 255.0
# create rays from camera
cameras = Cameras(
......@@ -165,7 +153,6 @@ class SubjectLoader(CachedIterDataset):
width=self.WIDTH,
height=self.HEIGHT,
)
cameras = transform_cameras(cameras, self.resize_factor)
if self.num_rays is not None:
x = torch.randint(0, self.WIDTH, size=(self.num_rays,))
......@@ -180,8 +167,6 @@ class SubjectLoader(CachedIterDataset):
rays = generate_rays(
cameras,
opencv_format=False,
near=self.near,
far=self.far,
pixels_xy=pixels_xy,
)
......
......@@ -5,9 +5,7 @@ import math
import torch
import torch.nn.functional as F
Rays = collections.namedtuple(
"Rays", ("origins", "directions", "viewdirs", "radii", "near", "far")
)
Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
Cameras = collections.namedtuple(
"Cameras", ("intrins", "extrins", "distorts", "width", "height")
......@@ -41,8 +39,6 @@ def transform_cameras(cameras: Cameras, resize_factor: float) -> torch.Tensor:
def generate_rays(
cameras: Cameras,
opencv_format: bool = True,
near: float = None,
far: float = None,
pixels_xy: torch.Tensor = None,
) -> Rays:
"""Generating rays for a single or multiple cameras.
......@@ -82,30 +78,8 @@ def generate_rays(
origins = torch.broadcast_to(c2w[..., :3, -1], directions.shape)
viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)
if pixels_xy is None:
# Distance from each unit-norm direction vector to its x-axis neighbor.
dx = torch.sqrt(
torch.sum(
(directions[..., :-1, :, :] - directions[..., 1:, :, :]) ** 2,
dim=-1,
)
)
dx = torch.cat([dx, dx[..., -2:-1, :]], dim=-2)
radii = dx[..., None] * 2 / math.sqrt(12) # [n_cams, height, width, 1]
else:
radii = None
if near is not None:
near = near * torch.ones_like(origins[..., 0:1])
if far is not None:
far = far * torch.ones_like(origins[..., 0:1])
rays = Rays(
origins=origins, # [n_cams, height, width, 3]
directions=directions, # [n_cams, height, width, 3]
viewdirs=viewdirs, # [n_cams, height, width, 3]
radii=radii, # [n_cams, height, width, 1]
# near far is not needed when they are estimated by skeleton.
near=near,
far=far,
)
return rays
......@@ -64,14 +64,15 @@ if __name__ == "__main__":
train_dataset = SubjectLoader(
subject_id="lego",
root_fp="/home/ruilongli/data/nerf_synthetic/",
split="val",
split="train",
num_rays=8192,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
num_workers=10,
batch_size=1,
collate_fn=getattr(train_dataset.__class__, "collate_fn"),
num_workers=1,
batch_size=None,
persistent_workers=True,
shuffle=True,
)
test_dataset = SubjectLoader(
subject_id="lego",
......@@ -81,9 +82,8 @@ if __name__ == "__main__":
)
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
num_workers=10,
batch_size=1,
collate_fn=getattr(train_dataset.__class__, "collate_fn"),
num_workers=4,
batch_size=None,
)
# setup the scene bounding box.
......@@ -102,6 +102,9 @@ if __name__ == "__main__":
)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=3e-3, eps=1e-15)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(
# optimizer, milestones=[10000, 20000, 30000], gamma=0.1
# )
# setup occupancy field with eval function
def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
......@@ -125,8 +128,11 @@ if __name__ == "__main__":
# training
step = 0
tic = time.time()
for epoch in range(200):
data_time = 0
tic_data = time.time()
for epoch in range(300):
for data in train_dataloader:
data_time += time.time() - tic_data
step += 1
if step > 30_000:
print("training stops")
......@@ -150,11 +156,12 @@ if __name__ == "__main__":
optimizer.zero_grad()
loss.backward()
optimizer.step()
# scheduler.step()
if step % 50 == 0:
elapsed_time = time.time() - tic
print(
f"elapsed_time={elapsed_time:.2f}s | {step=} | loss={loss.item(): .5f}"
f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | loss={loss:.5f}"
)
if step % 30_000 == 0 and step > 0:
......@@ -176,11 +183,12 @@ if __name__ == "__main__":
psnrs.append(psnr.item())
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: {psnr_avg=}")
tic_data = time.time()
# "train"
# elapsed_time=317.59s | step=30000 | loss= 0.00028
# evaluation: psnr_avg=33.27096959114075 (6.24 it/s)
# elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
# evaluation: psnr_avg=33.305334663391115 (6.42 it/s)
# "trainval"
# elapsed_time=389.08s | step=30000 | loss= 0.00030
# evaluation: psnr_avg=34.00573859214783 (6.26 it/s)
# elapsed_time=289.94s (data=51.99s) | step=30000 | loss=0.00021
# evaluation: psnr_avg=34.44980221748352 (6.61 it/s)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment