Unverified Commit 9a1c0ab5 authored by Jingchen Ye's avatar Jingchen Ye Committed by GitHub
Browse files

Fix train_mlp_nerf and save the model at the end of training (#177)

* Fix train_mlp_nerf

* Fix black and isort
parent 17e28de7
......@@ -13,7 +13,12 @@ import torch
import torch.nn.functional as F
import tqdm
from radiance_fields.mlp import VanillaNeRFRadianceField
from utils import render_image, set_random_seed
from utils import (
MIPNERF360_UNBOUNDED_SCENES,
NERF_SYNTHETIC_SCENES,
render_image,
set_random_seed,
)
from nerfacc import ContractionType, OccupancyGrid
......@@ -34,23 +39,17 @@ parser.add_argument(
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--model_path",
type=str,
default=None,
help="the path of the pretrained model",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
# nerf synthetic
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
"ship",
# mipnerf360 unbounded
"garden",
],
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
parser.add_argument(
......@@ -74,11 +73,47 @@ args = parser.parse_args()
render_n_samples = 1024
# setup the scene bounding box.
# setup the dataset
train_dataset_kwargs = {}
test_dataset_kwargs = {}
if args.scene in MIPNERF360_UNBOUNDED_SCENES:
from datasets.nerf_360_v2 import SubjectLoader
print("Using unbounded rendering")
target_sample_batch_size = 1 << 16
train_dataset_kwargs["color_bkgd_aug"] = "random"
train_dataset_kwargs["factor"] = 4
test_dataset_kwargs["factor"] = 4
grid_resolution = 128
elif args.scene in NERF_SYNTHETIC_SCENES:
from datasets.nerf_synthetic import SubjectLoader
target_sample_batch_size = 1 << 16
grid_resolution = 128
train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples,
device=device,
**train_dataset_kwargs,
)
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
device=device,
**test_dataset_kwargs,
)
if args.unbounded:
print("Using unbounded rendering")
contraction_type = ContractionType.UN_BOUNDED_SPHERE
# contraction_type = ContractionType.UN_BOUNDED_TANH
scene_aabb = None
near_plane = 0.2
far_plane = 1e4
......@@ -110,44 +145,22 @@ scheduler = torch.optim.lr_scheduler.MultiStepLR(
gamma=0.33,
)
# setup the dataset
train_dataset_kwargs = {}
test_dataset_kwargs = {}
if args.scene == "garden":
from datasets.nerf_360_v2 import SubjectLoader
target_sample_batch_size = 1 << 16
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
grid_resolution = 128
else:
from datasets.nerf_synthetic import SubjectLoader
target_sample_batch_size = 1 << 16
grid_resolution = 128
train_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split=args.train_split,
num_rays=target_sample_batch_size // render_n_samples,
**train_dataset_kwargs,
)
test_dataset = SubjectLoader(
subject_id=args.scene,
root_fp=args.data_root,
split="test",
num_rays=None,
**test_dataset_kwargs,
)
occupancy_grid = OccupancyGrid(
roi_aabb=args.aabb,
resolution=grid_resolution,
contraction_type=contraction_type,
).to(device)
if args.model_path is not None:
checkpoint = torch.load(args.model_path)
radiance_field.load_state_dict(checkpoint["radiance_field_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
occupancy_grid.load_state_dict(checkpoint["occupancy_grid_state_dict"])
step = checkpoint["step"]
else:
step = 0
# training
step = 0
tic = time.time()
......@@ -204,14 +217,28 @@ for epoch in range(10000000):
if step % 5000 == 0:
elapsed_time = time.time() - tic
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
psnr = -10.0 * torch.log(loss) / np.log(10.0)
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
f"psnr={psnr:.2f}"
)
if step > 0 and step % max_steps == 0:
model_save_path = str(pathlib.Path.cwd() / f"mlp_nerf_{step}")
torch.save(
{
"step": step,
"radiance_field_state_dict": radiance_field.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"occupancy_grid_state_dict": occupancy_grid.state_dict(),
},
model_save_path,
)
# evaluation
radiance_field.eval()
......@@ -230,8 +257,8 @@ for epoch in range(10000000):
rays,
scene_aabb,
# rendering options
near_plane=None,
far_plane=None,
near_plane=near_plane,
far_plane=far_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=args.cone_angle,
......@@ -246,7 +273,7 @@ for epoch in range(10000000):
# ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_test.png",
# f"rgb_test_{i}.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment