Unverified Commit 82fd69c7 authored by Jingchen Ye's avatar Jingchen Ye Committed by GitHub
Browse files

Updates on examples (#174)

* move torch data transfer into dataloader

* Update README

* use args.data_root

* Remove redundant check

* Fix isort

* Fix black
parent 15330b4c
...@@ -132,7 +132,7 @@ optimizer.step() ...@@ -132,7 +132,7 @@ optimizer.step()
## Examples: ## Examples:
Before running those example scripts, please check the script about which dataset it is needed, and download the dataset first. Before running those example scripts, please check the script about which dataset is needed, and download the dataset first. You could use `--data_root` to specify the path.
```bash ```bash
# clone the repo with submodules. # clone the repo with submodules.
......
...@@ -86,6 +86,7 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -86,6 +86,7 @@ class SubjectLoader(torch.utils.data.Dataset):
near: float = None, near: float = None,
far: float = None, far: float = None,
batch_over_images: bool = True, batch_over_images: bool = True,
device: str = "cuda:0",
): ):
super().__init__() super().__init__()
assert split in self.SPLITS, "%s" % split assert split in self.SPLITS, "%s" % split
...@@ -106,11 +107,15 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -106,11 +107,15 @@ class SubjectLoader(torch.utils.data.Dataset):
self.focal, self.focal,
self.timestamps, self.timestamps,
) = _load_renderings(root_fp, subject_id, split) ) = _load_renderings(root_fp, subject_id, split)
self.images = torch.from_numpy(self.images).to(torch.uint8) self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32) self.camtoworlds = (
self.timestamps = torch.from_numpy(self.timestamps).to(torch.float32)[ torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
:, None )
] self.timestamps = (
torch.from_numpy(self.timestamps)
.to(device)
.to(torch.float32)[:, None]
)
self.K = torch.tensor( self.K = torch.tensor(
[ [
[self.focal, 0, self.WIDTH / 2.0], [self.focal, 0, self.WIDTH / 2.0],
...@@ -118,6 +123,7 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -118,6 +123,7 @@ class SubjectLoader(torch.utils.data.Dataset):
[0, 0, 1], [0, 0, 1],
], ],
dtype=torch.float32, dtype=torch.float32,
device=device,
) # (3, 3) ) # (3, 3)
assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH) assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
......
...@@ -169,6 +169,7 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -169,6 +169,7 @@ class SubjectLoader(torch.utils.data.Dataset):
far: float = None, far: float = None,
batch_over_images: bool = True, batch_over_images: bool = True,
factor: int = 1, factor: int = 1,
device: str = "cuda:0",
): ):
super().__init__() super().__init__()
assert split in self.SPLITS, "%s" % split assert split in self.SPLITS, "%s" % split
...@@ -186,9 +187,11 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -186,9 +187,11 @@ class SubjectLoader(torch.utils.data.Dataset):
self.images, self.camtoworlds, self.K = _load_colmap( self.images, self.camtoworlds, self.K = _load_colmap(
root_fp, subject_id, split, factor root_fp, subject_id, split, factor
) )
self.images = torch.from_numpy(self.images).to(torch.uint8) self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32) self.camtoworlds = (
self.K = torch.tensor(self.K).to(torch.float32) torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
)
self.K = torch.tensor(self.K).to(device).to(torch.float32)
self.height, self.width = self.images.shape[1:3] self.height, self.width = self.images.shape[1:3]
def __len__(self): def __len__(self):
......
...@@ -79,6 +79,7 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -79,6 +79,7 @@ class SubjectLoader(torch.utils.data.Dataset):
near: float = None, near: float = None,
far: float = None, far: float = None,
batch_over_images: bool = True, batch_over_images: bool = True,
device: str = "cuda:0",
): ):
super().__init__() super().__init__()
assert split in self.SPLITS, "%s" % split assert split in self.SPLITS, "%s" % split
...@@ -109,8 +110,10 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -109,8 +110,10 @@ class SubjectLoader(torch.utils.data.Dataset):
self.images, self.camtoworlds, self.focal = _load_renderings( self.images, self.camtoworlds, self.focal = _load_renderings(
root_fp, subject_id, split root_fp, subject_id, split
) )
self.images = torch.from_numpy(self.images).to(torch.uint8) self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32) self.camtoworlds = (
torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
)
self.K = torch.tensor( self.K = torch.tensor(
[ [
[self.focal, 0, self.WIDTH / 2.0], [self.focal, 0, self.WIDTH / 2.0],
...@@ -118,6 +121,7 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -118,6 +121,7 @@ class SubjectLoader(torch.utils.data.Dataset):
[0, 0, 1], [0, 0, 1],
], ],
dtype=torch.float32, dtype=torch.float32,
device=device,
) # (3, 3) ) # (3, 3)
assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH) assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
......
...@@ -4,7 +4,7 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley. ...@@ -4,7 +4,7 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley.
import argparse import argparse
import math import math
import os import pathlib
import time import time
import imageio import imageio
...@@ -24,6 +24,12 @@ if __name__ == "__main__": ...@@ -24,6 +24,12 @@ if __name__ == "__main__":
set_random_seed(42) set_random_seed(42)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data/dnerf"),
help="the root dir of the dataset",
)
parser.add_argument( parser.add_argument(
"--train_split", "--train_split",
type=str, type=str,
...@@ -91,31 +97,22 @@ if __name__ == "__main__": ...@@ -91,31 +97,22 @@ if __name__ == "__main__":
gamma=0.33, gamma=0.33,
) )
# setup the dataset # setup the dataset
data_root_fp = "/home/ruilongli/data/dnerf/"
target_sample_batch_size = 1 << 16 target_sample_batch_size = 1 << 16
grid_resolution = 128 grid_resolution = 128
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=data_root_fp, root_fp=args.data_root,
split=args.train_split, split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples, num_rays=target_sample_batch_size // render_n_samples,
) )
train_dataset.images = train_dataset.images.to(device)
train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
train_dataset.K = train_dataset.K.to(device)
train_dataset.timestamps = train_dataset.timestamps.to(device)
test_dataset = SubjectLoader( test_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=data_root_fp, root_fp=args.data_root,
split="test", split="test",
num_rays=None, num_rays=None,
) )
test_dataset.images = test_dataset.images.to(device)
test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
test_dataset.K = test_dataset.K.to(device)
test_dataset.timestamps = test_dataset.timestamps.to(device)
occupancy_grid = OccupancyGrid( occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb, roi_aabb=args.aabb,
...@@ -191,7 +188,7 @@ if __name__ == "__main__": ...@@ -191,7 +188,7 @@ if __name__ == "__main__":
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |" f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
) )
if step >= 0 and step % max_steps == 0 and step > 0: if step > 0 and step % max_steps == 0:
# evaluation # evaluation
radiance_field.eval() radiance_field.eval()
......
...@@ -4,7 +4,7 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley. ...@@ -4,7 +4,7 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley.
import argparse import argparse
import math import math
import os import pathlib
import time import time
import imageio import imageio
...@@ -23,6 +23,12 @@ if __name__ == "__main__": ...@@ -23,6 +23,12 @@ if __name__ == "__main__":
set_random_seed(42) set_random_seed(42)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument( parser.add_argument(
"--train_split", "--train_split",
type=str, type=str,
...@@ -112,7 +118,6 @@ if __name__ == "__main__": ...@@ -112,7 +118,6 @@ if __name__ == "__main__":
if args.scene == "garden": if args.scene == "garden":
from datasets.nerf_360_v2 import SubjectLoader from datasets.nerf_360_v2 import SubjectLoader
data_root_fp = "/home/ruilongli/data/360_v2/"
target_sample_batch_size = 1 << 16 target_sample_batch_size = 1 << 16
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4} train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4} test_dataset_kwargs = {"factor": 4}
...@@ -120,32 +125,24 @@ if __name__ == "__main__": ...@@ -120,32 +125,24 @@ if __name__ == "__main__":
else: else:
from datasets.nerf_synthetic import SubjectLoader from datasets.nerf_synthetic import SubjectLoader
data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
target_sample_batch_size = 1 << 16 target_sample_batch_size = 1 << 16
grid_resolution = 128 grid_resolution = 128
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=data_root_fp, root_fp=args.data_root,
split=args.train_split, split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples, num_rays=target_sample_batch_size // render_n_samples,
**train_dataset_kwargs, **train_dataset_kwargs,
) )
train_dataset.images = train_dataset.images.to(device)
train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
train_dataset.K = train_dataset.K.to(device)
test_dataset = SubjectLoader( test_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=data_root_fp, root_fp=args.data_root,
split="test", split="test",
num_rays=None, num_rays=None,
**test_dataset_kwargs, **test_dataset_kwargs,
) )
test_dataset.images = test_dataset.images.to(device)
test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
test_dataset.K = test_dataset.K.to(device)
occupancy_grid = OccupancyGrid( occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb, roi_aabb=args.aabb,
...@@ -217,7 +214,7 @@ if __name__ == "__main__": ...@@ -217,7 +214,7 @@ if __name__ == "__main__":
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |" f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
) )
if step >= 0 and step % max_steps == 0 and step > 0: if step > 0 and step % max_steps == 0:
# evaluation # evaluation
radiance_field.eval() radiance_field.eval()
......
...@@ -4,7 +4,7 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley. ...@@ -4,7 +4,7 @@ Copyright (c) 2022 Ruilong Li, UC Berkeley.
import argparse import argparse
import math import math
import os import pathlib
import time import time
import imageio import imageio
...@@ -23,6 +23,12 @@ if __name__ == "__main__": ...@@ -23,6 +23,12 @@ if __name__ == "__main__":
set_random_seed(42) set_random_seed(42)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
default=str(pathlib.Path.cwd() / "data"),
help="the root dir of the dataset",
)
parser.add_argument( parser.add_argument(
"--train_split", "--train_split",
type=str, type=str,
...@@ -87,7 +93,6 @@ if __name__ == "__main__": ...@@ -87,7 +93,6 @@ if __name__ == "__main__":
if args.unbounded: if args.unbounded:
from datasets.nerf_360_v2 import SubjectLoader from datasets.nerf_360_v2 import SubjectLoader
data_root_fp = "/home/ruilongli/data/360_v2/"
target_sample_batch_size = 1 << 20 target_sample_batch_size = 1 << 20
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4} train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4} test_dataset_kwargs = {"factor": 4}
...@@ -95,32 +100,24 @@ if __name__ == "__main__": ...@@ -95,32 +100,24 @@ if __name__ == "__main__":
else: else:
from datasets.nerf_synthetic import SubjectLoader from datasets.nerf_synthetic import SubjectLoader
data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
target_sample_batch_size = 1 << 18 target_sample_batch_size = 1 << 18
grid_resolution = 128 grid_resolution = 128
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=data_root_fp, root_fp=args.data_root,
split=args.train_split, split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples, num_rays=target_sample_batch_size // render_n_samples,
**train_dataset_kwargs, **train_dataset_kwargs,
) )
train_dataset.images = train_dataset.images.to(device)
train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
train_dataset.K = train_dataset.K.to(device)
test_dataset = SubjectLoader( test_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=data_root_fp, root_fp=args.data_root,
split="test", split="test",
num_rays=None, num_rays=None,
**test_dataset_kwargs, **test_dataset_kwargs,
) )
test_dataset.images = test_dataset.images.to(device)
test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
test_dataset.K = test_dataset.K.to(device)
if args.auto_aabb: if args.auto_aabb:
camera_locs = torch.cat( camera_locs = torch.cat(
...@@ -260,7 +257,7 @@ if __name__ == "__main__": ...@@ -260,7 +257,7 @@ if __name__ == "__main__":
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |" f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
) )
if step >= 0 and step % max_steps == 0 and step > 0: if step > 0 and step % max_steps == 0:
# evaluation # evaluation
radiance_field.eval() radiance_field.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment