Unverified Commit 9a1c0ab5 authored by Jingchen Ye, committed by GitHub

Fix train_mlp_nerf and save the model at the end of training (#177)

* Fix train_mlp_nerf

* Fix black and isort
parent 17e28de7
@@ -13,7 +13,12 @@ import torch
 import torch.nn.functional as F
 import tqdm

 from radiance_fields.mlp import VanillaNeRFRadianceField
-from utils import render_image, set_random_seed
+from utils import (
+    MIPNERF360_UNBOUNDED_SCENES,
+    NERF_SYNTHETIC_SCENES,
+    render_image,
+    set_random_seed,
+)

 from nerfacc import ContractionType, OccupancyGrid
@@ -34,23 +39,17 @@ parser.add_argument(
     choices=["train", "trainval"],
     help="which train split to use",
 )
+parser.add_argument(
+    "--model_path",
+    type=str,
+    default=None,
+    help="the path of the pretrained model",
+)
 parser.add_argument(
     "--scene",
     type=str,
     default="lego",
-    choices=[
-        # nerf synthetic
-        "chair",
-        "drums",
-        "ficus",
-        "hotdog",
-        "lego",
-        "materials",
-        "mic",
-        "ship",
-        # mipnerf360 unbounded
-        "garden",
-    ],
+    choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
     help="which scene to use",
 )
 parser.add_argument(
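The inline `choices` list is replaced by two constants imported from `utils`. Their exact contents live in `utils.py`, which this diff does not show; the sketch below reconstructs them from the list removed above (the actual `MIPNERF360_UNBOUNDED_SCENES` may list more scenes than `garden`):

```python
# Plausible reconstruction from the choices list this commit removes;
# utils.py is the source of truth and may include additional scenes.
NERF_SYNTHETIC_SCENES = [
    "chair", "drums", "ficus", "hotdog",
    "lego", "materials", "mic", "ship",
]
MIPNERF360_UNBOUNDED_SCENES = ["garden"]
```

The new `--model_path` flag defaults to `None`, so existing invocations such as `python train_mlp_nerf.py --scene lego` keep their behavior; passing a checkpoint path opts into the resume logic added further down.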
@@ -74,11 +73,47 @@ args = parser.parse_args()

 render_n_samples = 1024

-# setup the scene bounding box.
+# setup the dataset
+train_dataset_kwargs = {}
+test_dataset_kwargs = {}
+if args.scene in MIPNERF360_UNBOUNDED_SCENES:
+    from datasets.nerf_360_v2 import SubjectLoader
+
+    print("Using unbounded rendering")
+    target_sample_batch_size = 1 << 16
+    train_dataset_kwargs["color_bkgd_aug"] = "random"
+    train_dataset_kwargs["factor"] = 4
+    test_dataset_kwargs["factor"] = 4
+    grid_resolution = 128
+elif args.scene in NERF_SYNTHETIC_SCENES:
+    from datasets.nerf_synthetic import SubjectLoader
+
+    target_sample_batch_size = 1 << 16
+    grid_resolution = 128
+
+train_dataset = SubjectLoader(
+    subject_id=args.scene,
+    root_fp=args.data_root,
+    split=args.train_split,
+    num_rays=target_sample_batch_size // render_n_samples,
+    device=device,
+    **train_dataset_kwargs,
+)
+
+test_dataset = SubjectLoader(
+    subject_id=args.scene,
+    root_fp=args.data_root,
+    split="test",
+    num_rays=None,
+    device=device,
+    **test_dataset_kwargs,
+)
+
+# setup the scene bounding box.
 if args.unbounded:
     print("Using unbounded rendering")
     contraction_type = ContractionType.UN_BOUNDED_SPHERE
-    # contraction_type = ContractionType.UN_BOUNDED_TANH
     scene_aabb = None
     near_plane = 0.2
     far_plane = 1e4
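The dataset setup now dispatches on membership in the scene lists instead of the hard-coded `args.scene == "garden"` check, and passes `device=device` so tensors are created on the GPU up front. The per-batch ray budget follows directly from the constants in this hunk:

```python
# Rays per training batch implied by the values in this commit.
target_sample_batch_size = 1 << 16  # 65536 target samples per batch
render_n_samples = 1024             # samples taken along each ray
num_rays = target_sample_batch_size // render_n_samples
print(num_rays)  # 64 rays per batch
```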
@@ -110,44 +145,22 @@ scheduler = torch.optim.lr_scheduler.MultiStepLR(
     gamma=0.33,
 )

-# setup the dataset
-train_dataset_kwargs = {}
-test_dataset_kwargs = {}
-if args.scene == "garden":
-    from datasets.nerf_360_v2 import SubjectLoader
-
-    target_sample_batch_size = 1 << 16
-    train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
-    test_dataset_kwargs = {"factor": 4}
-    grid_resolution = 128
-else:
-    from datasets.nerf_synthetic import SubjectLoader
-
-    target_sample_batch_size = 1 << 16
-    grid_resolution = 128
-
-train_dataset = SubjectLoader(
-    subject_id=args.scene,
-    root_fp=args.data_root,
-    split=args.train_split,
-    num_rays=target_sample_batch_size // render_n_samples,
-    **train_dataset_kwargs,
-)
-
-test_dataset = SubjectLoader(
-    subject_id=args.scene,
-    root_fp=args.data_root,
-    split="test",
-    num_rays=None,
-    **test_dataset_kwargs,
-)
-
 occupancy_grid = OccupancyGrid(
     roi_aabb=args.aabb,
     resolution=grid_resolution,
     contraction_type=contraction_type,
 ).to(device)

+if args.model_path is not None:
+    checkpoint = torch.load(args.model_path)
+    radiance_field.load_state_dict(checkpoint["radiance_field_state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+    occupancy_grid.load_state_dict(checkpoint["occupancy_grid_state_dict"])
+    step = checkpoint["step"]
+else:
+    step = 0
+
 # training
 step = 0
 tic = time.time()
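The new block restores model, optimizer, scheduler, and occupancy-grid state when `--model_path` is given. Note that the unchanged `step = 0` under `# training` still runs afterwards, so as committed the `step` restored from the checkpoint appears to be overwritten and a resumed run restarts its step counter. A minimal sketch of the intended round trip, assuming the checkpoint layout saved later in this commit (`load_checkpoint` is a hypothetical helper, not part of the commit):

```python
import torch

def load_checkpoint(path, radiance_field, optimizer, scheduler, occupancy_grid):
    # Keys mirror the dict passed to torch.save() further down in this diff.
    checkpoint = torch.load(path)
    radiance_field.load_state_dict(checkpoint["radiance_field_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    occupancy_grid.load_state_dict(checkpoint["occupancy_grid_state_dict"])
    return checkpoint["step"]

# For the restored step to take effect, nothing after this may reassign step.
step = (load_checkpoint(args.model_path, radiance_field, optimizer,
                        scheduler, occupancy_grid)
        if args.model_path is not None else 0)
```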
@@ -204,14 +217,28 @@ for epoch in range(10000000):
         if step % 5000 == 0:
             elapsed_time = time.time() - tic
             loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
+            psnr = -10.0 * torch.log(loss) / np.log(10.0)
             print(
                 f"elapsed_time={elapsed_time:.2f}s | step={step} | "
                 f"loss={loss:.5f} | "
                 f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
-                f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
+                f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
+                f"psnr={psnr:.2f}"
             )

         if step > 0 and step % max_steps == 0:
+            model_save_path = str(pathlib.Path.cwd() / f"mlp_nerf_{step}")
+            torch.save(
+                {
+                    "step": step,
+                    "radiance_field_state_dict": radiance_field.state_dict(),
+                    "optimizer_state_dict": optimizer.state_dict(),
+                    "scheduler_state_dict": scheduler.state_dict(),
+                    "occupancy_grid_state_dict": occupancy_grid.state_dict(),
+                },
+                model_save_path,
+            )
+
             # evaluation
             radiance_field.eval()
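The logging change derives PSNR directly from the MSE loss: `torch.log` is the natural logarithm, so dividing by `np.log(10.0)` converts it to base 10, giving the standard PSNR = -10 * log10(MSE) for pixel values in [0, 1]. The companion change saves a checkpoint named `mlp_nerf_{step}` in the working directory once `step` reaches a multiple of `max_steps`; `pathlib` is presumably imported in a part of the script this diff does not show. A quick sanity check of the formula:

```python
import numpy as np
import torch

mse = torch.tensor(1e-3)
psnr = -10.0 * torch.log(mse) / np.log(10.0)  # ln(mse) / ln(10) == log10(mse)
print(f"{psnr.item():.2f}")  # 30.00 dB for an MSE of 1e-3
```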
@@ -230,8 +257,8 @@ for epoch in range(10000000):
                         rays,
                         scene_aabb,
                         # rendering options
-                        near_plane=None,
-                        far_plane=None,
+                        near_plane=near_plane,
+                        far_plane=far_plane,
                         render_step_size=render_step_size,
                         render_bkgd=render_bkgd,
                         cone_angle=args.cone_angle,
@@ -246,7 +273,7 @@ for epoch in range(10000000):
                     #     ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
                     # )
                     # imageio.imwrite(
-                    #     "rgb_test.png",
+                    #     f"rgb_test_{i}.png",
                     #     (rgb.cpu().numpy() * 255).astype(np.uint8),
                     # )
                     # break
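In the evaluation hunk above, `render_image` now receives the `near_plane` and `far_plane` chosen at setup time instead of `None`, so test rendering uses the same clipping planes as training (0.2 and 1e4 in the unbounded branch). The last change makes the commented-out debug dump write one file per test view rather than overwriting a single `rgb_test.png`. If uncommented, it would behave as in this sketch (variable names taken from the surrounding evaluation loop, where `i` is the test-view index and `rgb` the rendered image tensor in [0, 1]):

```python
import imageio
import numpy as np

# One PNG per test view instead of a single repeatedly overwritten file.
imageio.imwrite(
    f"rgb_test_{i}.png",
    (rgb.cpu().numpy() * 255).astype(np.uint8),
)
```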