Commit bc7f7fff authored by Ruilong Li's avatar Ruilong Li
Browse files

train+val set

parent 1e8f3427
...@@ -17,3 +17,4 @@ Tested with the default settings on the Lego test set.
| instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
| torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
| ours | train (30K steps) | 33.27 | 318 sec | 6.2 fps | TITAN RTX |
| ours | trainval (30K steps) | 34.01 | 389 sec | 6.3 fps | TITAN RTX |
\ No newline at end of file
...@@ -48,7 +48,7 @@ def _load_renderings(root_fp: str, subject_id: str, split: str): ...@@ -48,7 +48,7 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
class SubjectLoader(CachedIterDataset): class SubjectLoader(CachedIterDataset):
"""Single subject data loader for training and evaluation.""" """Single subject data loader for training and evaluation."""
SPLITS = ["train", "val", "test"] SPLITS = ["train", "val", "trainval", "test"]
SUBJECT_IDS = [ SUBJECT_IDS = [
"chair", "chair",
"drums", "drums",
...@@ -84,6 +84,17 @@ class SubjectLoader(CachedIterDataset): ...@@ -84,6 +84,17 @@ class SubjectLoader(CachedIterDataset):
self.far = self.FAR if far is None else far self.far = self.FAR if far is None else far
self.training = (num_rays is not None) and (split in ["train"]) self.training = (num_rays is not None) and (split in ["train"])
self.color_bkgd_aug = color_bkgd_aug self.color_bkgd_aug = color_bkgd_aug
if split == "trainval":
_images_train, _camtoworlds_train, _focal_train = _load_renderings(
root_fp, subject_id, "train"
)
_images_val, _camtoworlds_val, _focal_val = _load_renderings(
root_fp, subject_id, "val"
)
self.images = np.concatenate([_images_train, _images_val])
self.camtoworlds = np.concatenate([_camtoworlds_train, _camtoworlds_val])
self.focal = _focal_train
else:
self.images, self.camtoworlds, self.focal = _load_renderings( self.images, self.camtoworlds, self.focal = _load_renderings(
root_fp, subject_id, split root_fp, subject_id, split
) )
......
...@@ -64,7 +64,7 @@ if __name__ == "__main__": ...@@ -64,7 +64,7 @@ if __name__ == "__main__":
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
subject_id="lego", subject_id="lego",
root_fp="/home/ruilongli/data/nerf_synthetic/", root_fp="/home/ruilongli/data/nerf_synthetic/",
split="train", split="val",
num_rays=8192, num_rays=8192,
) )
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
...@@ -73,14 +73,14 @@ if __name__ == "__main__": ...@@ -73,14 +73,14 @@ if __name__ == "__main__":
batch_size=1, batch_size=1,
collate_fn=getattr(train_dataset.__class__, "collate_fn"), collate_fn=getattr(train_dataset.__class__, "collate_fn"),
) )
val_dataset = SubjectLoader( test_dataset = SubjectLoader(
subject_id="lego", subject_id="lego",
root_fp="/home/ruilongli/data/nerf_synthetic/", root_fp="/home/ruilongli/data/nerf_synthetic/",
split="test", split="test",
num_rays=None, num_rays=None,
) )
val_dataloader = torch.utils.data.DataLoader( test_dataloader = torch.utils.data.DataLoader(
val_dataset, test_dataset,
num_workers=10, num_workers=10,
batch_size=1, batch_size=1,
collate_fn=getattr(train_dataset.__class__, "collate_fn"), collate_fn=getattr(train_dataset.__class__, "collate_fn"),
...@@ -125,7 +125,7 @@ if __name__ == "__main__": ...@@ -125,7 +125,7 @@ if __name__ == "__main__":
# training # training
step = 0 step = 0
tic = time.time() tic = time.time()
for epoch in range(100): for epoch in range(200):
for data in train_dataloader: for data in train_dataloader:
step += 1 step += 1
if step > 30_000: if step > 30_000:
...@@ -162,7 +162,7 @@ if __name__ == "__main__": ...@@ -162,7 +162,7 @@ if __name__ == "__main__":
radiance_field.eval() radiance_field.eval()
psnrs = [] psnrs = []
with torch.no_grad(): with torch.no_grad():
for data in tqdm.tqdm(val_dataloader): for data in tqdm.tqdm(test_dataloader):
# generate rays from data and the gt pixel color # generate rays from data and the gt pixel color
rays = namedtuple_map(lambda x: x.to(device), data["rays"]) rays = namedtuple_map(lambda x: x.to(device), data["rays"])
pixels = data["pixels"].to(device) pixels = data["pixels"].to(device)
...@@ -177,5 +177,10 @@ if __name__ == "__main__": ...@@ -177,5 +177,10 @@ if __name__ == "__main__":
psnr_avg = sum(psnrs) / len(psnrs) psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: {psnr_avg=}") print(f"evaluation: {psnr_avg=}")
# "train"
# elapsed_time=317.59s | step=30000 | loss= 0.00028 # elapsed_time=317.59s | step=30000 | loss= 0.00028
# evaluation: psnr_avg=33.27096959114075 (6.24 it/s) # evaluation: psnr_avg=33.27096959114075 (6.24 it/s)
# "trainval"
# elapsed_time=389.08s | step=30000 | loss= 0.00030
# evaluation: psnr_avg=34.00573859214783 (6.26 it/s)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment