Commit 477ff3df authored by Ruilong Li's avatar Ruilong Li
Browse files

benchmark: 7k, 345s, 34db, 56k rays, loss 0.0011

parent b42d62cc
...@@ -201,7 +201,7 @@ if __name__ == "__main__": ...@@ -201,7 +201,7 @@ if __name__ == "__main__":
from datasets.nerf_synthetic import SubjectLoader from datasets.nerf_synthetic import SubjectLoader
data_root_fp = "/home/ruilongli/data/nerf_synthetic/" data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
target_sample_batch_size = 1 << 20 target_sample_batch_size = 1 << 18
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
...@@ -255,7 +255,7 @@ if __name__ == "__main__": ...@@ -255,7 +255,7 @@ if __name__ == "__main__":
* math.sqrt(3) * math.sqrt(3)
/ render_n_samples / render_n_samples
).item() ).item()
alpha_thre = 0.0 alpha_thre = 1e-2
proposal_nets = torch.nn.ModuleList( proposal_nets = torch.nn.ModuleList(
[ [
...@@ -330,7 +330,7 @@ if __name__ == "__main__": ...@@ -330,7 +330,7 @@ if __name__ == "__main__":
render_step_size=render_step_size, render_step_size=render_step_size,
render_bkgd=render_bkgd, render_bkgd=render_bkgd,
cone_angle=args.cone_angle, cone_angle=args.cone_angle,
alpha_thre=alpha_thre, alpha_thre=min(alpha_thre, alpha_thre * step / 1000),
) )
if n_rendering_samples == 0: if n_rendering_samples == 0:
continue continue
...@@ -374,7 +374,7 @@ if __name__ == "__main__": ...@@ -374,7 +374,7 @@ if __name__ == "__main__":
torch.clamp(proposal_weights_gt - proposal_weights, min=0) torch.clamp(proposal_weights_gt - proposal_weights, min=0)
) ** 2 / (proposal_weights + torch.finfo(torch.float32).eps) ) ** 2 / (proposal_weights + torch.finfo(torch.float32).eps)
loss_interval = loss_interval.mean() loss_interval = loss_interval.mean()
loss += loss_interval * 1.0 loss += loss_interval * 0.1
optimizer.zero_grad() optimizer.zero_grad()
# do not unscale it because we are using Adam. # do not unscale it because we are using Adam.
......
...@@ -196,7 +196,7 @@ def ray_marching( ...@@ -196,7 +196,7 @@ def ray_marching(
proposal_sample_list = [] proposal_sample_list = []
if proposal_nets is not None: if proposal_nets is not None:
# resample with proposal nets # resample with proposal nets
for net, num_samples in zip(proposal_nets, [64]): for net, num_samples in zip(proposal_nets, [32]):
with torch.no_grad(): with torch.no_grad():
# skip invisible space # skip invisible space
if sigma_fn is not None or alpha_fn is not None: if sigma_fn is not None or alpha_fn is not None:
...@@ -226,7 +226,7 @@ def ray_marching( ...@@ -226,7 +226,7 @@ def ray_marching(
alphas, alphas,
ray_indices=ray_indices, ray_indices=ray_indices,
early_stop_eps=early_stop_eps, early_stop_eps=early_stop_eps,
alpha_thre=min(alphas.mean().item(), 1e-1), alpha_thre=alpha_thre,
n_rays=rays_o.shape[0], n_rays=rays_o.shape[0],
) )
ray_indices, t_starts, t_ends = ( ray_indices, t_starts, t_ends = (
...@@ -255,6 +255,40 @@ def ray_marching( ...@@ -255,6 +255,40 @@ def ray_marching(
) )
ray_indices = unpack_info(packed_info, n_samples=t_starts.shape[0]) ray_indices = unpack_info(packed_info, n_samples=t_starts.shape[0])
with torch.no_grad():
# skip invisible space
if sigma_fn is not None or alpha_fn is not None:
# Query sigma without gradients
if sigma_fn is not None:
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
assert (
sigmas.shape == t_starts.shape
), "sigmas must have shape of (N, 1)! Got {}".format(
sigmas.shape
)
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
elif alpha_fn is not None:
alphas = alpha_fn(t_starts, t_ends, ray_indices.long())
assert (
alphas.shape == t_starts.shape
), "alphas must have shape of (N, 1)! Got {}".format(
alphas.shape
)
# Compute visibility of the samples, and filter out invisible samples
masks = render_visibility(
alphas,
ray_indices=ray_indices,
early_stop_eps=early_stop_eps,
alpha_thre=alpha_thre,
n_rays=rays_o.shape[0],
)
ray_indices, t_starts, t_ends = (
ray_indices[masks],
t_starts[masks],
t_ends[masks],
)
if proposal_nets is not None: if proposal_nets is not None:
return ray_indices, t_starts, t_ends, proposal_sample_list return ray_indices, t_starts, t_ends, proposal_sample_list
else: else:
......
...@@ -142,7 +142,7 @@ def accumulate_along_rays( ...@@ -142,7 +142,7 @@ def accumulate_along_rays(
Args: Args:
weights: Volumetric rendering weights for those samples. Tensor with shape \ weights: Volumetric rendering weights for those samples. Tensor with shape \
(n_samples,). (n_samples, 1).
ray_indices: Ray index of each sample. IntTensor with shape (n_samples). ray_indices: Ray index of each sample. IntTensor with shape (n_samples).
values: The values to be accmulated. Tensor with shape (n_samples, D). If \ values: The values to be accmulated. Tensor with shape (n_samples, D). If \
None, the accumulated values are just weights. Default is None. None, the accumulated values are just weights. Default is None.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment