Commit 3eef9208 authored by Ruilong Li's avatar Ruilong Li
Browse files

format

parent 88a6aec6
...@@ -276,7 +276,9 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -276,7 +276,9 @@ class SubjectLoader(torch.utils.data.Dataset):
if self.training: if self.training:
if self.color_bkgd_aug == "random": if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device, generator=self.g) color_bkgd = torch.rand(
3, device=self.images.device, generator=self.g
)
elif self.color_bkgd_aug == "white": elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device) color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black": elif self.color_bkgd_aug == "black":
...@@ -311,10 +313,18 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -311,10 +313,18 @@ class SubjectLoader(torch.utils.data.Dataset):
else: else:
image_id = [index] * num_rays image_id = [index] * num_rays
x = torch.randint( x = torch.randint(
0, self.width, size=(num_rays,), device=self.images.device, generator=self.g 0,
self.width,
size=(num_rays,),
device=self.images.device,
generator=self.g,
) )
y = torch.randint( y = torch.randint(
0, self.height, size=(num_rays,), device=self.images.device, generator=self.g 0,
self.height,
size=(num_rays,),
device=self.images.device,
generator=self.g,
) )
else: else:
image_id = [index] image_id = [index]
......
...@@ -143,7 +143,9 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -143,7 +143,9 @@ class SubjectLoader(torch.utils.data.Dataset):
if self.training: if self.training:
if self.color_bkgd_aug == "random": if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device, generator=self.g) color_bkgd = torch.rand(
3, device=self.images.device, generator=self.g
)
elif self.color_bkgd_aug == "white": elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device) color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black": elif self.color_bkgd_aug == "black":
...@@ -179,10 +181,18 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -179,10 +181,18 @@ class SubjectLoader(torch.utils.data.Dataset):
else: else:
image_id = [index] * num_rays image_id = [index] * num_rays
x = torch.randint( x = torch.randint(
0, self.WIDTH, size=(num_rays,), device=self.images.device, generator=self.g 0,
self.WIDTH,
size=(num_rays,),
device=self.images.device,
generator=self.g,
) )
y = torch.randint( y = torch.randint(
0, self.HEIGHT, size=(num_rays,), device=self.images.device, generator=self.g 0,
self.HEIGHT,
size=(num_rays,),
device=self.images.device,
generator=self.g,
) )
else: else:
image_id = [index] image_id = [index]
......
...@@ -24,6 +24,7 @@ from examples.utils import ( ...@@ -24,6 +24,7 @@ from examples.utils import (
) )
from nerfacc.estimators.occ_grid import OccGridEstimator from nerfacc.estimators.occ_grid import OccGridEstimator
def run(args): def run(args):
device = "cuda:0" device = "cuda:0"
set_random_seed(42) set_random_seed(42)
...@@ -102,7 +103,10 @@ def run(args): ...@@ -102,7 +103,10 @@ def run(args):
grad_scaler = torch.cuda.amp.GradScaler(2**10) grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device) radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device)
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay radiance_field.parameters(),
lr=1e-2,
eps=1e-15,
weight_decay=weight_decay,
) )
scheduler = torch.optim.lr_scheduler.ChainedScheduler( scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[ [
...@@ -167,7 +171,8 @@ def run(args): ...@@ -167,7 +171,8 @@ def run(args):
# dynamic batch size for rays to keep sample batch size constant. # dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels) num_rays = len(pixels)
num_rays = int( num_rays = int(
num_rays * (target_sample_batch_size / float(n_rendering_samples)) num_rays
* (target_sample_batch_size / float(n_rendering_samples))
) )
train_dataset.update_num_rays(num_rays) train_dataset.update_num_rays(num_rays)
...@@ -249,6 +254,7 @@ def run(args): ...@@ -249,6 +254,7 @@ def run(args):
lpips_avg = sum(lpips) / len(lpips) lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}") print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
...@@ -274,4 +280,4 @@ if __name__ == "__main__": ...@@ -274,4 +280,4 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
run(args) run(args)
\ No newline at end of file
...@@ -2,13 +2,14 @@ import math ...@@ -2,13 +2,14 @@ import math
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from torch import Tensor
from ..grid import _enlarge_aabb from ..grid import _enlarge_aabb
from ..volrend import ( from ..volrend import (
render_visibility_from_alpha, render_visibility_from_alpha,
render_visibility_from_density, render_visibility_from_density,
) )
from .base import AbstractEstimator from .base import AbstractEstimator
from torch import Tensor
try: try:
import svox import svox
...@@ -21,7 +22,7 @@ except ImportError: ...@@ -21,7 +22,7 @@ except ImportError:
class N3TreeEstimator(AbstractEstimator): class N3TreeEstimator(AbstractEstimator):
"""Use N3Tree to implement Occupancy Grid. """Use N3Tree to implement Occupancy Grid.
This allows more flexible topologies than the cascaded grid. However, it is This allows more flexible topologies than the cascaded grid. However, it is
slower to create samples from the tree than the cascaded grid. By default, slower to create samples from the tree than the cascaded grid. By default,
it has the same topology as the cascaded grid but `self.tree` can be it has the same topology as the cascaded grid but `self.tree` can be
...@@ -43,7 +44,9 @@ class N3TreeEstimator(AbstractEstimator): ...@@ -43,7 +44,9 @@ class N3TreeEstimator(AbstractEstimator):
) )
# check the resolution is legal # check the resolution is legal
assert isinstance(resolution, int), "N3Tree only supports uniform resolution!" assert isinstance(
resolution, int
), "N3Tree only supports uniform resolution!"
# check the roi_aabb is legal # check the roi_aabb is legal
if isinstance(roi_aabb, (list, tuple)): if isinstance(roi_aabb, (list, tuple)):
...@@ -148,16 +151,18 @@ class N3TreeEstimator(AbstractEstimator): ...@@ -148,16 +151,18 @@ class N3TreeEstimator(AbstractEstimator):
""" """
assert t_min is None and t_max is None, ( assert (
"Do not supported per-ray min max. Please use near_plane and far_plane instead." t_min is None and t_max is None
) ), "Do not supported per-ray min max. Please use near_plane and far_plane instead."
if stratified: if stratified:
near_plane += torch.rand(()).item() * render_step_size near_plane += torch.rand(()).item() * render_step_size
t_starts, t_ends, packed_info, ray_indices = svox.volume_sample( t_starts, t_ends, packed_info, ray_indices = svox.volume_sample(
self.tree, self.tree,
thresh=self.thresh, thresh=self.thresh,
rays=svox.Rays(rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous()), rays=svox.Rays(
rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous()
),
step_size=render_step_size, step_size=render_step_size,
cone_angle=cone_angle, cone_angle=cone_angle,
near_plane=near_plane, near_plane=near_plane,
...@@ -253,10 +258,16 @@ class N3TreeEstimator(AbstractEstimator): ...@@ -253,10 +258,16 @@ class N3TreeEstimator(AbstractEstimator):
@torch.no_grad() @torch.no_grad()
def _sample_uniform_and_occupied_cells(self, n: int) -> List[Tensor]: def _sample_uniform_and_occupied_cells(self, n: int) -> List[Tensor]:
"""Samples both n uniform and occupied cells.""" """Samples both n uniform and occupied cells."""
uniform_indices = torch.randint(len(self.tree), (n,), device=self.device) uniform_indices = torch.randint(
occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[:, 0] len(self.tree), (n,), device=self.device
)
occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[
:, 0
]
if n < len(occupied_indices): if n < len(occupied_indices):
selector = torch.randint(len(occupied_indices), (n,), device=self.device) selector = torch.randint(
len(occupied_indices), (n,), device=self.device
)
occupied_indices = occupied_indices[selector] occupied_indices = occupied_indices[selector]
indices = torch.cat([uniform_indices, occupied_indices], dim=0) indices = torch.cat([uniform_indices, occupied_indices], dim=0)
return indices return indices
...@@ -275,7 +286,9 @@ class N3TreeEstimator(AbstractEstimator): ...@@ -275,7 +286,9 @@ class N3TreeEstimator(AbstractEstimator):
x = self.tree.sample(1).squeeze(1) x = self.tree.sample(1).squeeze(1)
occ = occ_eval_fn(x).squeeze(-1) occ = occ_eval_fn(x).squeeze(-1)
sel = (*self.tree._all_leaves().T,) sel = (*self.tree._all_leaves().T,)
self.tree.data.data[sel] = torch.maximum(self.tree.data.data[sel] * ema_decay, occ[:, None]) self.tree.data.data[sel] = torch.maximum(
self.tree.data.data[sel] * ema_decay, occ[:, None]
)
else: else:
N = len(self.tree) // 4 N = len(self.tree) // 4
indices = self._sample_uniform_and_occupied_cells(N) indices = self._sample_uniform_and_occupied_cells(N)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment