"docs/vscode:/vscode.git/clone" did not exist on "840f236654e3ef736bdd53bd6b47d7dc5ee07423"
Commit bc7f7fff authored by Ruilong Li's avatar Ruilong Li
Browse files

train+val set

parent 1e8f3427
......@@ -17,3 +17,4 @@ Tested with the default settings on the Lego test set.
| instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
| torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
| ours | train (30K steps) | 33.27 | 318 sec | 6.2 fps | TITAN RTX |
| ours | trainval (30K steps) | 34.01 | 389 sec | 6.3 fps | TITAN RTX |
\ No newline at end of file
......@@ -48,7 +48,7 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
class SubjectLoader(CachedIterDataset):
"""Single subject data loader for training and evaluation."""
SPLITS = ["train", "val", "test"]
SPLITS = ["train", "val", "trainval", "test"]
SUBJECT_IDS = [
"chair",
"drums",
......@@ -84,6 +84,17 @@ class SubjectLoader(CachedIterDataset):
self.far = self.FAR if far is None else far
self.training = (num_rays is not None) and (split in ["train"])
self.color_bkgd_aug = color_bkgd_aug
if split == "trainval":
_images_train, _camtoworlds_train, _focal_train = _load_renderings(
root_fp, subject_id, "train"
)
_images_val, _camtoworlds_val, _focal_val = _load_renderings(
root_fp, subject_id, "val"
)
self.images = np.concatenate([_images_train, _images_val])
self.camtoworlds = np.concatenate([_camtoworlds_train, _camtoworlds_val])
self.focal = _focal_train
else:
self.images, self.camtoworlds, self.focal = _load_renderings(
root_fp, subject_id, split
)
......
......@@ -64,7 +64,7 @@ if __name__ == "__main__":
train_dataset = SubjectLoader(
subject_id="lego",
root_fp="/home/ruilongli/data/nerf_synthetic/",
split="train",
split="val",
num_rays=8192,
)
train_dataloader = torch.utils.data.DataLoader(
......@@ -73,14 +73,14 @@ if __name__ == "__main__":
batch_size=1,
collate_fn=getattr(train_dataset.__class__, "collate_fn"),
)
val_dataset = SubjectLoader(
test_dataset = SubjectLoader(
subject_id="lego",
root_fp="/home/ruilongli/data/nerf_synthetic/",
split="test",
num_rays=None,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset,
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
num_workers=10,
batch_size=1,
collate_fn=getattr(train_dataset.__class__, "collate_fn"),
......@@ -125,7 +125,7 @@ if __name__ == "__main__":
# training
step = 0
tic = time.time()
for epoch in range(100):
for epoch in range(200):
for data in train_dataloader:
step += 1
if step > 30_000:
......@@ -162,7 +162,7 @@ if __name__ == "__main__":
radiance_field.eval()
psnrs = []
with torch.no_grad():
for data in tqdm.tqdm(val_dataloader):
for data in tqdm.tqdm(test_dataloader):
# generate rays from data and the gt pixel color
rays = namedtuple_map(lambda x: x.to(device), data["rays"])
pixels = data["pixels"].to(device)
......@@ -177,5 +177,10 @@ if __name__ == "__main__":
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: {psnr_avg=}")
# "train"
# elapsed_time=317.59s | step=30000 | loss= 0.00028
# evaluation: psnr_avg=33.27096959114075 (6.24 it/s)
# "trainval"
# elapsed_time=389.08s | step=30000 | loss= 0.00030
# evaluation: psnr_avg=34.00573859214783 (6.26 it/s)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment