Commit 3eef9208 authored by Ruilong Li's avatar Ruilong Li
Browse files

format

parent 88a6aec6
...@@ -276,7 +276,9 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -276,7 +276,9 @@ class SubjectLoader(torch.utils.data.Dataset):
if self.training: if self.training:
if self.color_bkgd_aug == "random": if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device, generator=self.g) color_bkgd = torch.rand(
3, device=self.images.device, generator=self.g
)
elif self.color_bkgd_aug == "white": elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device) color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black": elif self.color_bkgd_aug == "black":
...@@ -311,10 +313,18 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -311,10 +313,18 @@ class SubjectLoader(torch.utils.data.Dataset):
else: else:
image_id = [index] * num_rays image_id = [index] * num_rays
x = torch.randint( x = torch.randint(
0, self.width, size=(num_rays,), device=self.images.device, generator=self.g 0,
self.width,
size=(num_rays,),
device=self.images.device,
generator=self.g,
) )
y = torch.randint( y = torch.randint(
0, self.height, size=(num_rays,), device=self.images.device, generator=self.g 0,
self.height,
size=(num_rays,),
device=self.images.device,
generator=self.g,
) )
else: else:
image_id = [index] image_id = [index]
......
...@@ -143,7 +143,9 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -143,7 +143,9 @@ class SubjectLoader(torch.utils.data.Dataset):
if self.training: if self.training:
if self.color_bkgd_aug == "random": if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device, generator=self.g) color_bkgd = torch.rand(
3, device=self.images.device, generator=self.g
)
elif self.color_bkgd_aug == "white": elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device) color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black": elif self.color_bkgd_aug == "black":
...@@ -179,10 +181,18 @@ class SubjectLoader(torch.utils.data.Dataset): ...@@ -179,10 +181,18 @@ class SubjectLoader(torch.utils.data.Dataset):
else: else:
image_id = [index] * num_rays image_id = [index] * num_rays
x = torch.randint( x = torch.randint(
0, self.WIDTH, size=(num_rays,), device=self.images.device, generator=self.g 0,
self.WIDTH,
size=(num_rays,),
device=self.images.device,
generator=self.g,
) )
y = torch.randint( y = torch.randint(
0, self.HEIGHT, size=(num_rays,), device=self.images.device, generator=self.g 0,
self.HEIGHT,
size=(num_rays,),
device=self.images.device,
generator=self.g,
) )
else: else:
image_id = [index] image_id = [index]
......
...@@ -24,6 +24,7 @@ from examples.utils import ( ...@@ -24,6 +24,7 @@ from examples.utils import (
) )
from nerfacc.estimators.occ_grid import OccGridEstimator from nerfacc.estimators.occ_grid import OccGridEstimator
def run(args): def run(args):
device = "cuda:0" device = "cuda:0"
set_random_seed(42) set_random_seed(42)
...@@ -102,7 +103,10 @@ def run(args): ...@@ -102,7 +103,10 @@ def run(args):
grad_scaler = torch.cuda.amp.GradScaler(2**10) grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device) radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device)
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay radiance_field.parameters(),
lr=1e-2,
eps=1e-15,
weight_decay=weight_decay,
) )
scheduler = torch.optim.lr_scheduler.ChainedScheduler( scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[ [
...@@ -167,7 +171,8 @@ def run(args): ...@@ -167,7 +171,8 @@ def run(args):
# dynamic batch size for rays to keep sample batch size constant. # dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels) num_rays = len(pixels)
num_rays = int( num_rays = int(
num_rays * (target_sample_batch_size / float(n_rendering_samples)) num_rays
* (target_sample_batch_size / float(n_rendering_samples))
) )
train_dataset.update_num_rays(num_rays) train_dataset.update_num_rays(num_rays)
...@@ -249,6 +254,7 @@ def run(args): ...@@ -249,6 +254,7 @@ def run(args):
lpips_avg = sum(lpips) / len(lpips) lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}") print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
......
...@@ -2,13 +2,14 @@ import math ...@@ -2,13 +2,14 @@ import math
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from torch import Tensor
from ..grid import _enlarge_aabb from ..grid import _enlarge_aabb
from ..volrend import ( from ..volrend import (
render_visibility_from_alpha, render_visibility_from_alpha,
render_visibility_from_density, render_visibility_from_density,
) )
from .base import AbstractEstimator from .base import AbstractEstimator
from torch import Tensor
try: try:
import svox import svox
...@@ -43,7 +44,9 @@ class N3TreeEstimator(AbstractEstimator): ...@@ -43,7 +44,9 @@ class N3TreeEstimator(AbstractEstimator):
) )
# check the resolution is legal # check the resolution is legal
assert isinstance(resolution, int), "N3Tree only supports uniform resolution!" assert isinstance(
resolution, int
), "N3Tree only supports uniform resolution!"
# check the roi_aabb is legal # check the roi_aabb is legal
if isinstance(roi_aabb, (list, tuple)): if isinstance(roi_aabb, (list, tuple)):
...@@ -148,16 +151,18 @@ class N3TreeEstimator(AbstractEstimator): ...@@ -148,16 +151,18 @@ class N3TreeEstimator(AbstractEstimator):
""" """
assert t_min is None and t_max is None, ( assert (
"Do not supported per-ray min max. Please use near_plane and far_plane instead." t_min is None and t_max is None
) ), "Do not supported per-ray min max. Please use near_plane and far_plane instead."
if stratified: if stratified:
near_plane += torch.rand(()).item() * render_step_size near_plane += torch.rand(()).item() * render_step_size
t_starts, t_ends, packed_info, ray_indices = svox.volume_sample( t_starts, t_ends, packed_info, ray_indices = svox.volume_sample(
self.tree, self.tree,
thresh=self.thresh, thresh=self.thresh,
rays=svox.Rays(rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous()), rays=svox.Rays(
rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous()
),
step_size=render_step_size, step_size=render_step_size,
cone_angle=cone_angle, cone_angle=cone_angle,
near_plane=near_plane, near_plane=near_plane,
...@@ -253,10 +258,16 @@ class N3TreeEstimator(AbstractEstimator): ...@@ -253,10 +258,16 @@ class N3TreeEstimator(AbstractEstimator):
@torch.no_grad() @torch.no_grad()
def _sample_uniform_and_occupied_cells(self, n: int) -> List[Tensor]: def _sample_uniform_and_occupied_cells(self, n: int) -> List[Tensor]:
"""Samples both n uniform and occupied cells.""" """Samples both n uniform and occupied cells."""
uniform_indices = torch.randint(len(self.tree), (n,), device=self.device) uniform_indices = torch.randint(
occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[:, 0] len(self.tree), (n,), device=self.device
)
occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[
:, 0
]
if n < len(occupied_indices): if n < len(occupied_indices):
selector = torch.randint(len(occupied_indices), (n,), device=self.device) selector = torch.randint(
len(occupied_indices), (n,), device=self.device
)
occupied_indices = occupied_indices[selector] occupied_indices = occupied_indices[selector]
indices = torch.cat([uniform_indices, occupied_indices], dim=0) indices = torch.cat([uniform_indices, occupied_indices], dim=0)
return indices return indices
...@@ -275,7 +286,9 @@ class N3TreeEstimator(AbstractEstimator): ...@@ -275,7 +286,9 @@ class N3TreeEstimator(AbstractEstimator):
x = self.tree.sample(1).squeeze(1) x = self.tree.sample(1).squeeze(1)
occ = occ_eval_fn(x).squeeze(-1) occ = occ_eval_fn(x).squeeze(-1)
sel = (*self.tree._all_leaves().T,) sel = (*self.tree._all_leaves().T,)
self.tree.data.data[sel] = torch.maximum(self.tree.data.data[sel] * ema_decay, occ[:, None]) self.tree.data.data[sel] = torch.maximum(
self.tree.data.data[sel] * ema_decay, occ[:, None]
)
else: else:
N = len(self.tree) // 4 N = len(self.tree) // 4
indices = self._sample_uniform_and_occupied_cells(N) indices = self._sample_uniform_and_occupied_cells(N)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment