Commit 477ff3df authored by Ruilong Li

benchmark: 7k, 345s, 34db, 56k rays, loss 0.0011

parent b42d62cc
@@ -201,7 +201,7 @@ if __name__ == "__main__":
     from datasets.nerf_synthetic import SubjectLoader
     data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
-    target_sample_batch_size = 1 << 20
+    target_sample_batch_size = 1 << 18
     train_dataset = SubjectLoader(
         subject_id=args.scene,
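
The new target of 1 << 18 is 262,144 samples per training step (down from 1,048,576). In the nerfacc example scripts this budget drives dynamic ray batching; below is a minimal sketch of that rule with made-up numbers (the variable names and the exact rescaling line are assumptions, not quoted from the script).

# Sketch of dynamic ray batching toward a fixed sample budget (illustrative values).
target_sample_batch_size = 1 << 18   # 262,144 samples per step (was 1 << 20)
num_rays = 4096                      # rays traced this step
n_rendering_samples = 180_000        # samples those rays actually produced
# Rescale the ray count so the next step lands near the sample budget.
num_rays = int(num_rays * (target_sample_batch_size / float(n_rendering_samples)))
print(num_rays)  # 5965
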
@@ -255,7 +255,7 @@ if __name__ == "__main__":
         * math.sqrt(3)
         / render_n_samples
     ).item()
-    alpha_thre = 0.0
+    alpha_thre = 1e-2
     proposal_nets = torch.nn.ModuleList(
         [
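
For orientation, render_step_size here is roughly the scene diagonal divided by the per-ray sample count. A back-of-the-envelope with assumed values (a cubic aabb of side 3 and render_n_samples = 1024, neither taken from the script):

import math

aabb_side = 3.0            # assumed scene extent, e.g. [-1.5, 1.5]^3
render_n_samples = 1024    # assumed per-ray sample count
render_step_size = aabb_side * math.sqrt(3) / render_n_samples
print(render_step_size)    # ~0.00507, i.e. roughly 1024 steps across the diagonal
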
@@ -330,7 +330,7 @@ if __name__ == "__main__":
             render_step_size=render_step_size,
             render_bkgd=render_bkgd,
             cone_angle=args.cone_angle,
-            alpha_thre=alpha_thre,
+            alpha_thre=min(alpha_thre, alpha_thre * step / 1000),
         )
         if n_rendering_samples == 0:
             continue
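
The rendering call now ramps the alpha threshold instead of applying it from step 0. The snippet below just evaluates the expression from the patch at a few steps, using the alpha_thre = 1e-2 value set above: a linear warm-up over the first 1000 steps, then constant.

alpha_thre = 1e-2
for step in (0, 250, 500, 1000, 5000):
    print(step, min(alpha_thre, alpha_thre * step / 1000))
# 0 0.0 | 250 0.0025 | 500 0.005 | 1000 0.01 | 5000 0.01
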
@@ -374,7 +374,7 @@ if __name__ == "__main__":
             torch.clamp(proposal_weights_gt - proposal_weights, min=0)
         ) ** 2 / (proposal_weights + torch.finfo(torch.float32).eps)
         loss_interval = loss_interval.mean()
-        loss += loss_interval * 1.0
+        loss += loss_interval * 0.1
         optimizer.zero_grad()
         # do not unscale it because we are using Adam.
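
The interval loss itself is unchanged; only its weight drops from 1.0 to 0.1. Below is a self-contained toy evaluation of the same formula; the tensor values are placeholders, only the expression comes from the patch.

import torch

proposal_weights_gt = torch.tensor([[0.30], [0.10], [0.05]])  # placeholder values
proposal_weights = torch.tensor([[0.20], [0.12], [0.01]])     # placeholder values

eps = torch.finfo(torch.float32).eps
loss_interval = (
    torch.clamp(proposal_weights_gt - proposal_weights, min=0) ** 2
    / (proposal_weights + eps)
)
loss_interval = loss_interval.mean()
loss = loss_interval * 0.1  # the new, down-weighted contribution to the total loss
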
@@ -196,7 +196,7 @@ def ray_marching(
     proposal_sample_list = []
     if proposal_nets is not None:
         # resample with proposal nets
-        for net, num_samples in zip(proposal_nets, [64]):
+        for net, num_samples in zip(proposal_nets, [32]):
             with torch.no_grad():
                 # skip invisible space
                 if sigma_fn is not None or alpha_fn is not None:
@@ -226,7 +226,7 @@ def ray_marching(
                         alphas,
                         ray_indices=ray_indices,
                         early_stop_eps=early_stop_eps,
-                        alpha_thre=min(alphas.mean().item(), 1e-1),
+                        alpha_thre=alpha_thre,
                         n_rays=rays_o.shape[0],
                     )
                     ray_indices, t_starts, t_ends = (
@@ -255,6 +255,40 @@ def ray_marching(
     )
     ray_indices = unpack_info(packed_info, n_samples=t_starts.shape[0])
+    with torch.no_grad():
+        # skip invisible space
+        if sigma_fn is not None or alpha_fn is not None:
+            # Query sigma without gradients
+            if sigma_fn is not None:
+                sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
+                assert (
+                    sigmas.shape == t_starts.shape
+                ), "sigmas must have shape of (N, 1)! Got {}".format(
+                    sigmas.shape
+                )
+                alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
+            elif alpha_fn is not None:
+                alphas = alpha_fn(t_starts, t_ends, ray_indices.long())
+                assert (
+                    alphas.shape == t_starts.shape
+                ), "alphas must have shape of (N, 1)! Got {}".format(
+                    alphas.shape
+                )
+            # Compute visibility of the samples, and filter out invisible samples
+            masks = render_visibility(
+                alphas,
+                ray_indices=ray_indices,
+                early_stop_eps=early_stop_eps,
+                alpha_thre=alpha_thre,
+                n_rays=rays_o.shape[0],
+            )
+            ray_indices, t_starts, t_ends = (
+                ray_indices[masks],
+                t_starts[masks],
+                t_ends[masks],
+            )
     if proposal_nets is not None:
         return ray_indices, t_starts, t_ends, proposal_sample_list
     else:
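
The added block applies the same skip-invisible-space filtering after the main marching pass: query densities without gradients, convert them to alphas, and drop samples that would not contribute. Below is a simplified, pure-PyTorch sketch of that idea for a single ray; the real render_visibility() is a fused op over packed samples, and the exact masking rule here is only an approximation of it.

import torch

def visibility_mask(alphas, early_stop_eps=1e-4, alpha_thre=1e-2):
    # Transmittance in front of each sample along a single ray.
    trans = torch.cumprod(
        torch.cat([torch.ones_like(alphas[:1]), 1.0 - alphas[:-1]]), dim=0
    )
    # Keep samples whose alpha is above the threshold and that are reached
    # before the ray is (almost) fully occluded.
    return (alphas >= alpha_thre) & (trans > early_stop_eps)

sigmas = torch.tensor([0.0, 0.5, 5.0, 50.0, 50.0])
dt = 0.01
alphas = 1.0 - torch.exp(-sigmas * dt)  # same sigma-to-alpha conversion as in the patch
print(visibility_mask(alphas))          # tensor([False, False,  True,  True,  True])
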
@@ -142,7 +142,7 @@ def accumulate_along_rays(
     Args:
         weights: Volumetric rendering weights for those samples. Tensor with shape \
-            (n_samples,).
+            (n_samples, 1).
         ray_indices: Ray index of each sample. IntTensor with shape (n_samples).
         values: The values to be accumulated. Tensor with shape (n_samples, D). If \
             None, the accumulated values are just weights. Default is None.
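
For reference, here is a pure-PyTorch sketch of what an accumulation with these documented shapes computes — weights (n_samples, 1), ray_indices (n_samples,), values (n_samples, D) — not the library's actual implementation.

import torch

def accumulate_along_rays_ref(weights, ray_indices, values=None, n_rays=None):
    # Reference sketch only: scatter-add weighted per-sample values onto their rays.
    if values is None:
        values = torch.ones_like(weights)        # accumulate the weights themselves
    if n_rays is None:
        n_rays = int(ray_indices.max().item()) + 1
    outputs = torch.zeros(n_rays, values.shape[-1], dtype=values.dtype)
    outputs.index_add_(0, ray_indices.long(), weights * values)
    return outputs

weights = torch.tensor([[0.5], [0.3], [0.8]])    # (n_samples, 1)
ray_indices = torch.tensor([0, 0, 1])            # (n_samples,)
rgb = torch.rand(3, 3)                           # (n_samples, D)
print(accumulate_along_rays_ref(weights, ray_indices, rgb, n_rays=2).shape)  # (2, 3)
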