"src/runtime/vscode:/vscode.git/clone" did not exist on "166b273becae0438d0114daf5346b80ebf1a4333"
Commit ad98221d authored by Ruilong Li's avatar Ruilong Li
Browse files

distortion loss

parent 03a8a782
...@@ -5,6 +5,7 @@ from .data_specs import RayIntervals, RaySamples ...@@ -5,6 +5,7 @@ from .data_specs import RayIntervals, RaySamples
from .estimators.occ_grid import OccGridEstimator from .estimators.occ_grid import OccGridEstimator
from .estimators.prop_net import PropNetEstimator from .estimators.prop_net import PropNetEstimator
from .grid import ray_aabb_intersect, traverse_grids from .grid import ray_aabb_intersect, traverse_grids
from .losses import distortion
from .pack import pack_info from .pack import pack_info
from .pdf import importance_sampling, searchsorted from .pdf import importance_sampling, searchsorted
from .scan import exclusive_prod, exclusive_sum, inclusive_prod, inclusive_sum from .scan import exclusive_prod, exclusive_sum, inclusive_prod, inclusive_sum
...@@ -47,4 +48,5 @@ __all__ = [ ...@@ -47,4 +48,5 @@ __all__ = [
"traverse_grids", "traverse_grids",
"OccGridEstimator", "OccGridEstimator",
"PropNetEstimator", "PropNetEstimator",
"distortion",
] ]
from torch import Tensor
from .scan import inclusive_sum
from .volrend import accumulate_along_rays
def distortion(
    weights: Tensor,
    t_starts: Tensor,
    t_ends: Tensor,
    ray_indices: Tensor,
    n_rays: int,
) -> Tensor:
    """Distortion Regularization proposed in Mip-NeRF 360 (on a single GPU).

    All per-sample tensors are flat and aligned: sample ``i`` belongs to ray
    ``ray_indices[i]`` and covers the interval ``[t_starts[i], t_ends[i]]``
    with blending weight ``weights[i]``.

    Args:
        weights: [n_samples,] The weights of the samples.
        t_starts: [n_samples,] The start points of the samples.
        t_ends: [n_samples,] The end points of the samples.
        ray_indices: [n_samples,] The ray indices of the samples.
        n_rays: The total number of rays.

    Returns:
        The per-ray distortion loss with the shape of [n_rays, 1].
    """
    assert (
        weights.shape == t_starts.shape == t_ends.shape == ray_indices.shape
    ), (
        f"the shape of the inputs are not the same: "
        f"weights {weights.shape}, t_starts {t_starts.shape}, "
        f"t_ends {t_ends.shape}, ray_indices {ray_indices.shape}"
    )
    # Interval midpoints and lengths per sample.
    midpoints = 0.5 * (t_starts + t_ends)
    deltas = t_ends - t_starts
    # Intra-sample ("uniform") term: (1/3) * w_i^2 * delta_i.
    uniform_term = (1 / 3) * (deltas * weights**2)
    # Pairwise ("bi") term, computed in O(n) with per-ray inclusive prefix
    # sums instead of an explicit O(n^2) double sum over sample pairs.
    w_prefix = inclusive_sum(weights, indices=ray_indices)
    wm_prefix = inclusive_sum(weights * midpoints, indices=ray_indices)
    pairwise_term = 2 * (weights * midpoints * w_prefix - weights * wm_prefix)
    # Reduce the per-sample contributions into one value per ray.
    per_sample = uniform_term + pairwise_term
    return accumulate_along_rays(per_sample, None, ray_indices, n_rays)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment