Commit 3eef9208 authored by Ruilong Li

format

parent 88a6aec6
@@ -276,7 +276,9 @@ class SubjectLoader(torch.utils.data.Dataset):
         if self.training:
             if self.color_bkgd_aug == "random":
-                color_bkgd = torch.rand(3, device=self.images.device, generator=self.g)
+                color_bkgd = torch.rand(
+                    3, device=self.images.device, generator=self.g
+                )
             elif self.color_bkgd_aug == "white":
                 color_bkgd = torch.ones(3, device=self.images.device)
             elif self.color_bkgd_aug == "black":
@@ -311,10 +313,18 @@ class SubjectLoader(torch.utils.data.Dataset):
             else:
                 image_id = [index] * num_rays
             x = torch.randint(
-                0, self.width, size=(num_rays,), device=self.images.device, generator=self.g
+                0,
+                self.width,
+                size=(num_rays,),
+                device=self.images.device,
+                generator=self.g,
             )
             y = torch.randint(
-                0, self.height, size=(num_rays,), device=self.images.device, generator=self.g
+                0,
+                self.height,
+                size=(num_rays,),
+                device=self.images.device,
+                generator=self.g,
             )
         else:
             image_id = [index]
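The SubjectLoader hunks above (and the matching ones in the next file) are pure line-wrapping changes around the random background colour and random pixel sampling. As a rough, self-contained sketch of what those calls do (the sizes and seed here are placeholders, not values from the repo):

import torch

# Placeholders standing in for the dataset attributes used in the diff.
height, width, num_rays = 800, 800, 1024
g = torch.Generator().manual_seed(42)

# One random pixel coordinate per ray, mirroring the torch.randint calls above.
x = torch.randint(0, width, size=(num_rays,), generator=g)
y = torch.randint(0, height, size=(num_rays,), generator=g)

# A random RGB background colour, mirroring the "random" color_bkgd_aug branch.
color_bkgd = torch.rand(3, generator=g)
print(x.shape, y.shape, color_bkgd)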
@@ -143,7 +143,9 @@ class SubjectLoader(torch.utils.data.Dataset):
         if self.training:
             if self.color_bkgd_aug == "random":
-                color_bkgd = torch.rand(3, device=self.images.device, generator=self.g)
+                color_bkgd = torch.rand(
+                    3, device=self.images.device, generator=self.g
+                )
             elif self.color_bkgd_aug == "white":
                 color_bkgd = torch.ones(3, device=self.images.device)
             elif self.color_bkgd_aug == "black":
@@ -179,10 +181,18 @@ class SubjectLoader(torch.utils.data.Dataset):
             else:
                 image_id = [index] * num_rays
             x = torch.randint(
-                0, self.WIDTH, size=(num_rays,), device=self.images.device, generator=self.g
+                0,
+                self.WIDTH,
+                size=(num_rays,),
+                device=self.images.device,
+                generator=self.g,
             )
             y = torch.randint(
-                0, self.HEIGHT, size=(num_rays,), device=self.images.device, generator=self.g
+                0,
+                self.HEIGHT,
+                size=(num_rays,),
+                device=self.images.device,
+                generator=self.g,
             )
         else:
             image_id = [index]
@@ -24,6 +24,7 @@ from examples.utils import (
 )
 from nerfacc.estimators.occ_grid import OccGridEstimator
 
+
 def run(args):
     device = "cuda:0"
     set_random_seed(42)
@@ -102,7 +103,10 @@ def run(args):
     grad_scaler = torch.cuda.amp.GradScaler(2**10)
     radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device)
     optimizer = torch.optim.Adam(
-        radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay
+        radiance_field.parameters(),
+        lr=1e-2,
+        eps=1e-15,
+        weight_decay=weight_decay,
     )
     scheduler = torch.optim.lr_scheduler.ChainedScheduler(
         [
@@ -167,7 +171,8 @@ def run(args):
             # dynamic batch size for rays to keep sample batch size constant.
             num_rays = len(pixels)
             num_rays = int(
-                num_rays * (target_sample_batch_size / float(n_rendering_samples))
+                num_rays
+                * (target_sample_batch_size / float(n_rendering_samples))
             )
             train_dataset.update_num_rays(num_rays)
 
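The hunk above only re-wraps the dynamic batch-size expression across two lines. For context, a short sketch of that heuristic with made-up numbers (the variable names follow the diff; the values are illustrative):

# Scale the ray count so the number of rendered samples stays near a fixed target.
target_sample_batch_size = 1 << 18    # illustrative target, not taken from the diff
num_rays = 4096                       # rays used in the previous iteration (illustrative)
n_rendering_samples = 300_000         # samples those rays produced (illustrative)

num_rays = int(
    num_rays * (target_sample_batch_size / float(n_rendering_samples))
)
print(num_rays)  # ray count for the next iteration (3579 here)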
@@ -249,6 +254,7 @@ def run(args):
             lpips_avg = sum(lpips) / len(lpips)
             print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -2,13 +2,14 @@ import math
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
+from torch import Tensor
+
 from ..grid import _enlarge_aabb
 from ..volrend import (
     render_visibility_from_alpha,
     render_visibility_from_density,
 )
 from .base import AbstractEstimator
-from torch import Tensor
 
 try:
     import svox
@@ -43,7 +44,9 @@ class N3TreeEstimator(AbstractEstimator):
         )
 
         # check the resolution is legal
-        assert isinstance(resolution, int), "N3Tree only supports uniform resolution!"
+        assert isinstance(
+            resolution, int
+        ), "N3Tree only supports uniform resolution!"
 
         # check the roi_aabb is legal
         if isinstance(roi_aabb, (list, tuple)):
@@ -148,16 +151,18 @@ class N3TreeEstimator(AbstractEstimator):
         """
 
-        assert t_min is None and t_max is None, (
-            "Do not supported per-ray min max. Please use near_plane and far_plane instead."
-        )
+        assert (
+            t_min is None and t_max is None
+        ), "Do not supported per-ray min max. Please use near_plane and far_plane instead."
 
         if stratified:
             near_plane += torch.rand(()).item() * render_step_size
 
         t_starts, t_ends, packed_info, ray_indices = svox.volume_sample(
             self.tree,
             thresh=self.thresh,
-            rays=svox.Rays(rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous()),
+            rays=svox.Rays(
+                rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous()
+            ),
             step_size=render_step_size,
             cone_angle=cone_angle,
             near_plane=near_plane,
@@ -253,10 +258,16 @@
     @torch.no_grad()
     def _sample_uniform_and_occupied_cells(self, n: int) -> List[Tensor]:
         """Samples both n uniform and occupied cells."""
-        uniform_indices = torch.randint(len(self.tree), (n,), device=self.device)
-        occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[:, 0]
+        uniform_indices = torch.randint(
+            len(self.tree), (n,), device=self.device
+        )
+        occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[
+            :, 0
+        ]
         if n < len(occupied_indices):
-            selector = torch.randint(len(occupied_indices), (n,), device=self.device)
+            selector = torch.randint(
+                len(occupied_indices), (n,), device=self.device
+            )
             occupied_indices = occupied_indices[selector]
         indices = torch.cat([uniform_indices, occupied_indices], dim=0)
         return indices
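The hunk above only re-wraps the cell-sampling calls. As a rough, self-contained illustration of the same idea with plain tensors (occ_values and thresh are made-up stand-ins for the N3Tree state, not names from the repo):

import torch

# Mix uniformly sampled cells with currently occupied cells.
n, thresh = 8, 0.5
occ_values = torch.rand(64)  # per-cell occupancy stand-in

uniform_indices = torch.randint(len(occ_values), (n,))
occupied_indices = torch.nonzero(occ_values >= thresh)[:, 0]
if n < len(occupied_indices):
    selector = torch.randint(len(occupied_indices), (n,))
    occupied_indices = occupied_indices[selector]
indices = torch.cat([uniform_indices, occupied_indices], dim=0)
print(indices.shape)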
@@ -275,7 +286,9 @@
             x = self.tree.sample(1).squeeze(1)
             occ = occ_eval_fn(x).squeeze(-1)
             sel = (*self.tree._all_leaves().T,)
-            self.tree.data.data[sel] = torch.maximum(self.tree.data.data[sel] * ema_decay, occ[:, None])
+            self.tree.data.data[sel] = torch.maximum(
+                self.tree.data.data[sel] * ema_decay, occ[:, None]
+            )
         else:
             N = len(self.tree) // 4
             indices = self._sample_uniform_and_occupied_cells(N)
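The last hunk only re-wraps the EMA-style occupancy update. A tiny sketch of the same torch.maximum pattern with stand-in tensors (the shapes are made up; in the diff the old values come from self.tree.data.data[sel]):

import torch

ema_decay = 0.95
old_occ = torch.rand(16, 1)  # stand-in for self.tree.data.data[sel]
new_occ = torch.rand(16)     # stand-in for occ_eval_fn(x).squeeze(-1)

# Keep the larger of the decayed old value and the freshly evaluated occupancy.
updated = torch.maximum(old_occ * ema_decay, new_occ[:, None])
print(updated.shape)  # torch.Size([16, 1])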