Commit 3eb42338 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

New raysamplers

Summary: New MultinomialRaysampler succeeds GridRaysampler bringing masking and subsampling. Correspondingly, NDCMultinomialRaysampler succeeds NDCGridRaysampler.

Reviewed By: nikhilaravi, shapovalov

Differential Revision: D33256897

fbshipit-source-id: cd80ec6f35b110d1d20a75c62f4e889ba8fa5d45
parent 174738c3
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
from typing import Tuple from typing import Tuple
import torch import torch
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points, HarmonicEmbedding from pytorch3d.renderer import HarmonicEmbedding, RayBundle, ray_bundle_to_ray_points
from .linear_with_repeat import LinearWithRepeat from .linear_with_repeat import LinearWithRepeat
......
...@@ -32,7 +32,9 @@ from .implicit import ( ...@@ -32,7 +32,9 @@ from .implicit import (
HarmonicEmbedding, HarmonicEmbedding,
ImplicitRenderer, ImplicitRenderer,
MonteCarloRaysampler, MonteCarloRaysampler,
MultinomialRaysampler,
NDCGridRaysampler, NDCGridRaysampler,
NDCMultinomialRaysampler,
RayBundle, RayBundle,
VolumeRenderer, VolumeRenderer,
VolumeSampler, VolumeSampler,
......
...@@ -6,7 +6,13 @@ ...@@ -6,7 +6,13 @@
from .harmonic_embedding import HarmonicEmbedding from .harmonic_embedding import HarmonicEmbedding
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler from .raysampling import (
GridRaysampler,
MonteCarloRaysampler,
MultinomialRaysampler,
NDCGridRaysampler,
NDCMultinomialRaysampler,
)
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
from .utils import ( from .utils import (
RayBundle, RayBundle,
...@@ -14,4 +20,5 @@ from .utils import ( ...@@ -14,4 +20,5 @@ from .utils import (
ray_bundle_variables_to_ray_points, ray_bundle_variables_to_ray_points,
) )
__all__ = [k for k in globals().keys() if not k.startswith("_")] __all__ = [k for k in globals().keys() if not k.startswith("_")]
This diff is collapsed.
...@@ -10,9 +10,9 @@ from fvcore.common.benchmark import benchmark ...@@ -10,9 +10,9 @@ from fvcore.common.benchmark import benchmark
from pytorch3d.renderer import ( from pytorch3d.renderer import (
FoVOrthographicCameras, FoVOrthographicCameras,
FoVPerspectiveCameras, FoVPerspectiveCameras,
GridRaysampler,
MonteCarloRaysampler, MonteCarloRaysampler,
NDCGridRaysampler, MultinomialRaysampler,
NDCMultinomialRaysampler,
OrthographicCameras, OrthographicCameras,
PerspectiveCameras, PerspectiveCameras,
) )
...@@ -21,7 +21,11 @@ from test_raysampling import TestRaysampling ...@@ -21,7 +21,11 @@ from test_raysampling import TestRaysampling
def bm_raysampling() -> None: def bm_raysampling() -> None:
case_grid = { case_grid = {
"raysampler_type": [GridRaysampler, NDCGridRaysampler, MonteCarloRaysampler], "raysampler_type": [
MultinomialRaysampler,
NDCMultinomialRaysampler,
MonteCarloRaysampler,
],
"camera_type": [ "camera_type": [
PerspectiveCameras, PerspectiveCameras,
OrthographicCameras, OrthographicCameras,
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import os import os
import unittest import unittest
from numbers import Real
from pathlib import Path from pathlib import Path
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
...@@ -190,3 +191,13 @@ class TestCaseMixin(unittest.TestCase): ...@@ -190,3 +191,13 @@ class TestCaseMixin(unittest.TestCase):
if msg is not None: if msg is not None:
self.fail(f"{msg} {err}") self.fail(f"{msg} {err}")
self.fail(err) self.fail(err)
def assertConstant(self, input: TensorOrArray, value: Real) -> None:
"""
Asserts input is entirely filled with value.
Args:
input: tensor or array
"""
self.assertEqual(input.min(), value)
self.assertEqual(input.max(), value)
...@@ -5,17 +5,27 @@ ...@@ -5,17 +5,27 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import unittest import unittest
from typing import Callable
import torch import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.ops import eyes from pytorch3d.ops import eyes
from pytorch3d.renderer import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler from pytorch3d.renderer import (
MonteCarloRaysampler,
MultinomialRaysampler,
NDCGridRaysampler,
NDCMultinomialRaysampler,
)
from pytorch3d.renderer.cameras import ( from pytorch3d.renderer.cameras import (
FoVOrthographicCameras, FoVOrthographicCameras,
FoVPerspectiveCameras, FoVPerspectiveCameras,
OrthographicCameras, OrthographicCameras,
PerspectiveCameras, PerspectiveCameras,
) )
from pytorch3d.renderer.implicit.raysampling import (
_jiggle_within_stratas,
_safe_multinomial,
)
from pytorch3d.renderer.implicit.utils import ( from pytorch3d.renderer.implicit.utils import (
ray_bundle_to_ray_points, ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points, ray_bundle_variables_to_ray_points,
...@@ -93,14 +103,16 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): ...@@ -93,14 +103,16 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
@staticmethod @staticmethod
def raysampler( def raysampler(
raysampler_type=GridRaysampler, raysampler_type,
camera_type=PerspectiveCameras, camera_type,
n_pts_per_ray=10, n_pts_per_ray: int,
batch_size=1, batch_size: int,
image_width=10, image_width: int,
image_height=20, image_height: int,
): ) -> Callable[[], None]:
"""
Used for benchmarks.
"""
device = torch.device("cuda") device = torch.device("cuda")
# init raysamplers # init raysamplers
...@@ -120,7 +132,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): ...@@ -120,7 +132,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
# init a batch of random cameras # init a batch of random cameras
cameras = init_random_cameras(camera_type, batch_size, random_z=True).to(device) cameras = init_random_cameras(camera_type, batch_size, random_z=True).to(device)
def run_raysampler(): def run_raysampler() -> None:
raysampler(cameras=cameras) raysampler(cameras=cameras)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -128,7 +140,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): ...@@ -128,7 +140,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
@staticmethod @staticmethod
def init_raysampler( def init_raysampler(
raysampler_type=GridRaysampler, raysampler_type,
min_x=-1.0, min_x=-1.0,
max_x=1.0, max_x=1.0,
min_y=-1.0, min_y=-1.0,
...@@ -149,7 +161,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): ...@@ -149,7 +161,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
"max_depth": max_depth, "max_depth": max_depth,
} }
if issubclass(raysampler_type, GridRaysampler): if issubclass(raysampler_type, MultinomialRaysampler):
raysampler_params.update( raysampler_params.update(
{"image_width": image_width, "image_height": image_height} {"image_width": image_width, "image_height": image_height}
) )
...@@ -158,7 +170,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): ...@@ -158,7 +170,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
else: else:
raise ValueError(str(raysampler_type)) raise ValueError(str(raysampler_type))
if issubclass(raysampler_type, NDCGridRaysampler): if issubclass(raysampler_type, NDCMultinomialRaysampler):
# NDCGridRaysampler does not use min/max_x/y # NDCGridRaysampler does not use min/max_x/y
for k in ("min_x", "max_x", "min_y", "max_y"): for k in ("min_x", "max_x", "min_y", "max_y"):
del raysampler_params[k] del raysampler_params[k]
...@@ -191,8 +203,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): ...@@ -191,8 +203,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
for raysampler_type in ( for raysampler_type in (
MonteCarloRaysampler, MonteCarloRaysampler,
GridRaysampler, MultinomialRaysampler,
NDCGridRaysampler, NDCMultinomialRaysampler,
): ):
raysampler = TestRaysampling.init_raysampler( raysampler = TestRaysampling.init_raysampler(
...@@ -208,7 +220,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): ...@@ -208,7 +220,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
n_pts_per_ray=n_pts_per_ray, n_pts_per_ray=n_pts_per_ray,
) )
if issubclass(raysampler_type, NDCGridRaysampler): if issubclass(raysampler_type, NDCMultinomialRaysampler):
# adjust the gt bounds for NDCGridRaysampler # adjust the gt bounds for NDCGridRaysampler
if image_width >= image_height: if image_width >= image_height:
range_x = image_width / image_height range_x = image_width / image_height
...@@ -297,7 +309,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): ...@@ -297,7 +309,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
Checks the shapes of raysampler outputs. Checks the shapes of raysampler outputs.
""" """
if isinstance(raysampler, GridRaysampler): if isinstance(raysampler, MultinomialRaysampler):
spatial_size = [image_height, image_width] spatial_size = [image_height, image_width]
elif isinstance(raysampler, MonteCarloRaysampler): elif isinstance(raysampler, MonteCarloRaysampler):
spatial_size = [image_height * image_width] spatial_size = [image_height * image_width]
...@@ -386,7 +398,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): ...@@ -386,7 +398,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
# check that projected world points' xy coordinates # check that projected world points' xy coordinates
# range correctly between [minx/y, max/y] # range correctly between [minx/y, max/y]
if isinstance(raysampler, GridRaysampler): if isinstance(raysampler, MultinomialRaysampler):
# get the expected coordinates along each grid axis # get the expected coordinates along each grid axis
ys, xs = [ ys, xs = [
torch.linspace( torch.linspace(
...@@ -518,3 +530,51 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): ...@@ -518,3 +530,51 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
) )
state = module1.state_dict() state = module1.state_dict()
module2.load_state_dict(state) module2.load_state_dict(state)
def test_jiggle(self):
# random data which is in ascending order along the last dimension
scale = 180
data = scale * torch.cumsum(torch.rand(8, 3, 4, 20), dim=-1)
out = _jiggle_within_stratas(data)
self.assertTupleEqual(out.shape, data.shape)
# Check `out` is in ascending order
self.assertGreater(torch.diff(out, dim=-1).min(), 0)
self.assertConstant(out[..., :-1] < data[..., 1:], True)
self.assertConstant(data[..., :-1] < out[..., 1:], True)
jiggles = out - data
# jiggles is random between -scale/2 and scale/2
self.assertLess(jiggles.min(), -0.4 * scale)
self.assertGreater(jiggles.min(), -0.5 * scale)
self.assertGreater(jiggles.max(), 0.4 * scale)
self.assertLess(jiggles.max(), 0.5 * scale)
def test_safe_multinomial(self):
mask = [
[1, 0, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 1, 0, 0],
[1, 1, 1, 1, 0],
]
tmask = torch.tensor(mask, dtype=torch.float32)
for _ in range(5):
random_scalar = torch.rand(1)
samples = _safe_multinomial(tmask * random_scalar, 3)
self.assertTupleEqual(samples.shape, (4, 3))
# samples[0] is exactly determined
self.assertConstant(samples[0], 0)
self.assertGreaterEqual(samples[1].min(), 0)
self.assertLessEqual(samples[1].max(), 1)
# samples[2] is exactly determined
self.assertSetEqual(set(samples[2].tolist()), {0, 1, 2})
# samples[3] has enough sources, so must contain 3 distinct values.
self.assertLessEqual(samples[3].max(), 3)
self.assertEqual(len(set(samples[3].tolist())), 3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment