Commit 3eef9208 authored by Ruilong Li's avatar Ruilong Li
Browse files

format

parent 88a6aec6
......@@ -276,7 +276,9 @@ class SubjectLoader(torch.utils.data.Dataset):
if self.training:
if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device, generator=self.g)
color_bkgd = torch.rand(
3, device=self.images.device, generator=self.g
)
elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black":
......@@ -311,10 +313,18 @@ class SubjectLoader(torch.utils.data.Dataset):
else:
image_id = [index] * num_rays
x = torch.randint(
0, self.width, size=(num_rays,), device=self.images.device, generator=self.g
0,
self.width,
size=(num_rays,),
device=self.images.device,
generator=self.g,
)
y = torch.randint(
0, self.height, size=(num_rays,), device=self.images.device, generator=self.g
0,
self.height,
size=(num_rays,),
device=self.images.device,
generator=self.g,
)
else:
image_id = [index]
......
......@@ -143,7 +143,9 @@ class SubjectLoader(torch.utils.data.Dataset):
if self.training:
if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device, generator=self.g)
color_bkgd = torch.rand(
3, device=self.images.device, generator=self.g
)
elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black":
......@@ -179,10 +181,18 @@ class SubjectLoader(torch.utils.data.Dataset):
else:
image_id = [index] * num_rays
x = torch.randint(
0, self.WIDTH, size=(num_rays,), device=self.images.device, generator=self.g
0,
self.WIDTH,
size=(num_rays,),
device=self.images.device,
generator=self.g,
)
y = torch.randint(
0, self.HEIGHT, size=(num_rays,), device=self.images.device, generator=self.g
0,
self.HEIGHT,
size=(num_rays,),
device=self.images.device,
generator=self.g,
)
else:
image_id = [index]
......
......@@ -24,6 +24,7 @@ from examples.utils import (
)
from nerfacc.estimators.occ_grid import OccGridEstimator
def run(args):
device = "cuda:0"
set_random_seed(42)
......@@ -102,7 +103,10 @@ def run(args):
grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device)
optimizer = torch.optim.Adam(
radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay
radiance_field.parameters(),
lr=1e-2,
eps=1e-15,
weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[
......@@ -167,7 +171,8 @@ def run(args):
# dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels)
num_rays = int(
num_rays * (target_sample_batch_size / float(n_rendering_samples))
num_rays
* (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)
......@@ -249,6 +254,7 @@ def run(args):
lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
......@@ -274,4 +280,4 @@ if __name__ == "__main__":
)
args = parser.parse_args()
run(args)
\ No newline at end of file
run(args)
......@@ -2,13 +2,14 @@ import math
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor
from ..grid import _enlarge_aabb
from ..volrend import (
render_visibility_from_alpha,
render_visibility_from_density,
)
from .base import AbstractEstimator
from torch import Tensor
try:
import svox
......@@ -21,7 +22,7 @@ except ImportError:
class N3TreeEstimator(AbstractEstimator):
"""Use N3Tree to implement Occupancy Grid.
This allows more flexible topologies than the cascaded grid. However, it is
slower to create samples from the tree than the cascaded grid. By default,
it has the same topology as the cascaded grid but `self.tree` can be
......@@ -43,7 +44,9 @@ class N3TreeEstimator(AbstractEstimator):
)
# check the resolution is legal
assert isinstance(resolution, int), "N3Tree only supports uniform resolution!"
assert isinstance(
resolution, int
), "N3Tree only supports uniform resolution!"
# check the roi_aabb is legal
if isinstance(roi_aabb, (list, tuple)):
......@@ -148,16 +151,18 @@ class N3TreeEstimator(AbstractEstimator):
"""
assert t_min is None and t_max is None, (
"Do not supported per-ray min max. Please use near_plane and far_plane instead."
)
assert (
t_min is None and t_max is None
), "Do not supported per-ray min max. Please use near_plane and far_plane instead."
if stratified:
near_plane += torch.rand(()).item() * render_step_size
t_starts, t_ends, packed_info, ray_indices = svox.volume_sample(
self.tree,
thresh=self.thresh,
rays=svox.Rays(rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous()),
rays=svox.Rays(
rays_o.contiguous(), rays_d.contiguous(), rays_d.contiguous()
),
step_size=render_step_size,
cone_angle=cone_angle,
near_plane=near_plane,
......@@ -253,10 +258,16 @@ class N3TreeEstimator(AbstractEstimator):
@torch.no_grad()
def _sample_uniform_and_occupied_cells(self, n: int) -> List[Tensor]:
"""Samples both n uniform and occupied cells."""
uniform_indices = torch.randint(len(self.tree), (n,), device=self.device)
occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[:, 0]
uniform_indices = torch.randint(
len(self.tree), (n,), device=self.device
)
occupied_indices = torch.nonzero(self.tree[:].values >= self.thresh)[
:, 0
]
if n < len(occupied_indices):
selector = torch.randint(len(occupied_indices), (n,), device=self.device)
selector = torch.randint(
len(occupied_indices), (n,), device=self.device
)
occupied_indices = occupied_indices[selector]
indices = torch.cat([uniform_indices, occupied_indices], dim=0)
return indices
......@@ -275,7 +286,9 @@ class N3TreeEstimator(AbstractEstimator):
x = self.tree.sample(1).squeeze(1)
occ = occ_eval_fn(x).squeeze(-1)
sel = (*self.tree._all_leaves().T,)
self.tree.data.data[sel] = torch.maximum(self.tree.data.data[sel] * ema_decay, occ[:, None])
self.tree.data.data[sel] = torch.maximum(
self.tree.data.data[sel] * ema_decay, occ[:, None]
)
else:
N = len(self.tree) // 4
indices = self._sample_uniform_and_occupied_cells(N)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment