Commit 3d011a91 authored by Emilien Garreau's avatar Emilien Garreau Committed by Facebook GitHub Bot
Browse files

Adapt RayPointRefiner and RayMarcher to support bins.

Summary:
## Context

Bins are used in mipnerf to allow to manipulate easily intervals. For example, by doing the following, `bins[..., :-1]` you will obtain all the left coordinates of your intervals, while doing `bins[..., 1:]` is equals to the right coordinates of your intervals.

We introduce here the support of bins like in MipNerf implementation.

## RayPointRefiner

Small changes have been made to modify RayPointRefiner.
- If bins is None

```
mids = torch.lerp(ray_bundle.lengths[..., 1:], ray_bundle.lengths[…, :-1], 0.5)
z_samples = sample_pdf(
		mids, # [..., npt]
		weights[..., 1:-1], # [..., npt - 1]
               ….
            )
```

- If bins is not None
In the MipNerf implementation the sampling is done on all the bins. It allows us to use the full weights tensor without slashing it.

```
z_samples = sample_pdf(
		ray_bundle.bins, # [..., npt + 1]
		weights, # [..., npt]
               ...
            )
```

## RayMarcher

Add a ray_deltas optional argument. If None, keep the same deltas computation from ray_lengths.

Reviewed By: shapovalov

Differential Revision: D46389092

fbshipit-source-id: d4f1963310065bd31c1c7fac1adfe11cbeaba606
parent 5910d81b
......@@ -157,9 +157,13 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
else 0.0
)
ray_deltas = (
None if ray_bundle.bins is None else torch.diff(ray_bundle.bins, dim=-1)
)
output = self.raymarcher(
*implicit_functions[0](ray_bundle=ray_bundle),
ray_lengths=ray_bundle.lengths,
ray_deltas=ray_deltas,
density_noise_std=density_noise_std,
)
output.prev_stage = prev_stage
......
......@@ -78,19 +78,28 @@ class RayPointRefiner(Configurable, torch.nn.Module):
"""
z_vals = input_ray_bundle.lengths
with torch.no_grad():
if self.blurpool_weights:
ray_weights = apply_blurpool_on_weights(ray_weights)
z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
n_pts_per_ray = self.n_pts_per_ray
ray_weights = ray_weights.view(-1, ray_weights.shape[-1])
if input_ray_bundle.bins is None:
z_vals: torch.Tensor = input_ray_bundle.lengths
ray_weights = ray_weights[..., 1:-1]
bins = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
else:
z_vals = input_ray_bundle.bins
n_pts_per_ray += 1
bins = z_vals
z_samples = sample_pdf(
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
self.n_pts_per_ray,
bins.view(-1, bins.shape[-1]),
ray_weights,
n_pts_per_ray,
det=not self.random_sampling,
eps=self.sample_pdf_eps,
).view(*z_vals.shape[:-1], self.n_pts_per_ray)
).view(*z_vals.shape[:-1], n_pts_per_ray)
if self.add_input_samples:
z_vals = torch.cat((z_vals, z_samples), dim=-1)
else:
......@@ -98,9 +107,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
# Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1)
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
new_bundle.lengths = z_vals
return new_bundle
kwargs_ray = dict(vars(input_ray_bundle))
if input_ray_bundle.bins is None:
kwargs_ray["lengths"] = z_vals
return ImplicitronRayBundle(**kwargs_ray)
kwargs_ray["bins"] = z_vals
del kwargs_ray["lengths"]
return ImplicitronRayBundle.from_bins(**kwargs_ray)
def apply_blurpool_on_weights(weights) -> torch.Tensor:
......
......@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable, Dict, Optional, Tuple
import torch
from pytorch3d.implicitron.models.renderer.base import RendererOutput
......@@ -119,6 +119,7 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
rays_features: torch.Tensor,
aux: Dict[str, Any],
ray_lengths: torch.Tensor,
ray_deltas: Optional[torch.Tensor] = None,
density_noise_std: float = 0.0,
**kwargs,
) -> RendererOutput:
......@@ -131,6 +132,9 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
aux: a dictionary with extra information.
ray_lengths: Per-ray depth values represented with a tensor
of shape `(..., n_points_per_ray, feature_dim)`.
ray_deltas: Optional differences between consecutive elements along the ray bundle
represented with a tensor of shape `(..., n_points_per_ray)`. If None,
these differences are computed from ray_lengths.
density_noise_std: the magnitude of the noise added to densities.
Returns:
......@@ -152,14 +156,17 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
density_1d=True,
)
ray_lengths_diffs = ray_lengths[..., 1:] - ray_lengths[..., :-1]
if self.replicate_last_interval:
last_interval = ray_lengths_diffs[..., -1:]
if ray_deltas is None:
ray_lengths_diffs = torch.diff(ray_lengths, dim=-1)
if self.replicate_last_interval:
last_interval = ray_lengths_diffs[..., -1:]
else:
last_interval = torch.full_like(
ray_lengths[..., :1], self.background_opacity
)
deltas = torch.cat((ray_lengths_diffs, last_interval), dim=-1)
else:
last_interval = torch.full_like(
ray_lengths[..., :1], self.background_opacity
)
deltas = torch.cat((ray_lengths_diffs, last_interval), dim=-1)
deltas = ray_deltas
rays_densities = rays_densities[..., 0]
......
......@@ -24,7 +24,7 @@ class HarmonicEmbedding(torch.nn.Module):
and the integrated position encoding in
`MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
During, the inference you can provide the extra argument `diag_cov`.
During the inference you can provide the extra argument `diag_cov`.
If `diag_cov is None`, it converts
rays parametrized with a `ray_bundle` to 3D points by
......
......@@ -70,6 +70,71 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
)
def test_simple_use_bins(self):
"""
Same spirit than test_simple but use bins in the ImplicitronRayBunle.
It has been duplicated to avoid cognitive overload while reading the
test (lot of if else).
"""
length = 15
n_pts_per_ray = 10
for add_input_samples, use_blurpool in product([False, True], [False, True]):
ray_point_refiner = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=False,
add_input_samples=add_input_samples,
)
bundle = ImplicitronRayBundle(
lengths=None,
bins=torch.arange(length + 1, dtype=torch.float32).expand(
3, 25, length + 1
),
origins=None,
directions=None,
xys=None,
camera_ids=None,
camera_counts=None,
)
weights = torch.ones(3, 25, length)
refined = ray_point_refiner(bundle, weights, blurpool_weights=use_blurpool)
self.assertIsNone(refined.directions)
self.assertIsNone(refined.origins)
self.assertIsNone(refined.xys)
expected_bins = torch.linspace(0, length, n_pts_per_ray + 1)
expected_bins = expected_bins.expand(3, 25, n_pts_per_ray + 1)
if add_input_samples:
expected_bins = torch.cat((bundle.bins, expected_bins), dim=-1).sort()[
0
]
full_expected = torch.lerp(
expected_bins[..., :-1], expected_bins[..., 1:], 0.5
)
self.assertClose(refined.lengths, full_expected)
ray_point_refiner_random = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=True,
add_input_samples=add_input_samples,
)
refined_random = ray_point_refiner_random(
bundle, weights, blurpool_weights=use_blurpool
)
lengths_random = refined_random.lengths
self.assertEqual(lengths_random.shape, full_expected.shape)
if not add_input_samples:
self.assertGreater(lengths_random.min().item(), 0)
self.assertLess(lengths_random.max().item(), length)
# Check sorted
self.assertTrue(
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
)
def test_apply_blurpool_on_weights(self):
weights = torch.tensor(
[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment