Commit b19fe1de authored by Christoph Lassner's avatar Christoph Lassner Committed by Facebook GitHub Bot
Browse files

pulsar integration.

Summary:
This diff integrates the pulsar renderer source code into PyTorch3D as an alternative backend for the PyTorch3D point renderer. This diff is the first of a series of three diffs to complete that migration and focuses on the packaging and integration of the source code.

For more information about the pulsar backend, see the release notes and the paper (https://arxiv.org/abs/2004.07484). For information on how to use the backend, see the point cloud rendering notebook and the examples in the folder `docs/examples`.

Tasks addressed in the following diffs:
* Add the PyTorch3D interface,
* Add notebook examples and documentation (or adapt the existing ones to feature both interfaces).

Reviewed By: nikhilaravi

Differential Revision: D23947736

fbshipit-source-id: a5e77b53e6750334db22aefa89b4c079cda1b443
parent d5650323
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""Test number of channels."""
import logging
import sys
import unittest
from os import path
import torch
# fmt: off
# Make the mixin available.
sys.path.insert(0, path.join(path.dirname(__file__), ".."))
from common_testing import TestCaseMixin # isort:skip # noqa: E402
# fmt: on
sys.path.insert(0, path.join(path.dirname(__file__), "..", ".."))
devices = [torch.device("cuda"), torch.device("cpu")]
class TestChannels(TestCaseMixin, unittest.TestCase):
"""Test different numbers of channels."""
def test_basic(self):
"""Basic forward test."""
from pytorch3d.renderer.points.pulsar import Renderer
import torch
n_points = 10
width = 1_000
height = 1_000
renderer_1 = Renderer(width, height, n_points, n_channels=1)
renderer_3 = Renderer(width, height, n_points, n_channels=3)
renderer_8 = Renderer(width, height, n_points, n_channels=8)
# Generate sample data.
torch.manual_seed(1)
vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
vert_pos[:, 2] += 25.0
vert_pos[:, :2] -= 5.0
vert_col = torch.rand(n_points, 8, dtype=torch.float32)
vert_rad = torch.rand(n_points, dtype=torch.float32)
cam_params = torch.tensor(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32
)
for device in devices:
vert_pos = vert_pos.to(device)
vert_col = vert_col.to(device)
vert_rad = vert_rad.to(device)
cam_params = cam_params.to(device)
renderer_1 = renderer_1.to(device)
renderer_3 = renderer_3.to(device)
renderer_8 = renderer_8.to(device)
result_1 = (
renderer_1.forward(
vert_pos,
vert_col[:, :1],
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
)
.cpu()
.detach()
.numpy()
)
hits_1 = (
renderer_1.forward(
vert_pos,
vert_col[:, :1],
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
mode=1,
)
.cpu()
.detach()
.numpy()
)
result_3 = (
renderer_3.forward(
vert_pos,
vert_col[:, :3],
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
)
.cpu()
.detach()
.numpy()
)
hits_3 = (
renderer_3.forward(
vert_pos,
vert_col[:, :3],
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
mode=1,
)
.cpu()
.detach()
.numpy()
)
result_8 = (
renderer_8.forward(
vert_pos,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
)
.cpu()
.detach()
.numpy()
)
hits_8 = (
renderer_8.forward(
vert_pos,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
mode=1,
)
.cpu()
.detach()
.numpy()
)
self.assertClose(result_1, result_3[:, :, :1])
self.assertClose(result_3, result_8[:, :, :3])
self.assertClose(hits_1, hits_3)
self.assertClose(hits_8, hits_3)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
unittest.main()
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""Test the sorting of the closest spheres."""
import logging
import os
import sys
import unittest
from os import path
import imageio
import numpy as np
import torch
# fmt: off
# Make the mixin available.
sys.path.insert(0, path.join(path.dirname(__file__), ".."))
from common_testing import TestCaseMixin # isort:skip # noqa: E402
# fmt: on
# Making sure you can run this, even if pulsar hasn't been installed yet.
sys.path.insert(0, path.join(path.dirname(__file__), "..", ".."))
devices = [torch.device("cuda"), torch.device("cpu")]
IN_REF_FP = path.join(path.dirname(__file__), "reference", "nr0000-in.pth")
OUT_REF_FP = path.join(path.dirname(__file__), "reference", "nr0000-out.pth")
class TestDepth(TestCaseMixin, unittest.TestCase):
"""Test different numbers of channels."""
def test_basic(self):
from pytorch3d.renderer.points.pulsar import Renderer
for device in devices:
gamma = 1e-5
max_depth = 15.0
min_depth = 5.0
renderer = Renderer(
256,
256,
10000,
orthogonal_projection=True,
right_handed_system=False,
n_channels=1,
).to(device)
data = torch.load(IN_REF_FP, map_location="cpu")
# data["pos"] = torch.rand_like(data["pos"])
# data["pos"][:, 0] = data["pos"][:, 0] * 2. - 1.
# data["pos"][:, 1] = data["pos"][:, 1] * 2. - 1.
# data["pos"][:, 2] = data["pos"][:, 2] + 9.5
result, result_info = renderer.forward(
data["pos"].to(device),
data["col"].to(device),
data["rad"].to(device),
data["cam_params"].to(device),
gamma,
min_depth=min_depth,
max_depth=max_depth,
return_forward_info=True,
bg_col=torch.zeros(1, device=device, dtype=torch.float32),
percent_allowed_difference=0.01,
)
sphere_ids = Renderer.sphere_ids_from_result_info_nograd(result_info)
depth_map = Renderer.depth_map_from_result_info_nograd(result_info)
depth_vis = (depth_map - depth_map[depth_map > 0].min()) * 200 / (
depth_map.max() - depth_map[depth_map > 0.0].min()
) + 50
if not os.environ.get("FB_TEST", False):
imageio.imwrite(
path.join(
path.dirname(__file__),
"test_out",
"test_depth_test_basic_depth.png",
),
depth_vis.cpu().numpy().astype(np.uint8),
)
# torch.save(
# data, path.join(path.dirname(__file__), "reference", "nr0000-in.pth")
# )
# torch.save(
# {"sphere_ids": sphere_ids, "depth_map": depth_map},
# path.join(path.dirname(__file__), "reference", "nr0000-out.pth"),
# )
# sys.exit(0)
reference = torch.load(OUT_REF_FP, map_location="cpu")
self.assertTrue(
torch.sum(
reference["sphere_ids"][..., 0].to(device) == sphere_ids[..., 0]
)
> 65530
)
self.assertClose(reference["depth_map"].to(device), depth_map)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
unittest.main()
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""Basic rendering test."""
import logging
import os
import sys
import unittest
from os import path
import imageio
import numpy as np
import torch
# Making sure you can run this, even if pulsar hasn't been installed yet.
sys.path.insert(0, path.join(path.dirname(__file__), "..", ".."))
LOGGER = logging.getLogger(__name__)
devices = [torch.device("cuda"), torch.device("cpu")]
class TestForward(unittest.TestCase):
"""Rendering tests."""
def test_bg_weight(self):
"""Test background reweighting."""
from pytorch3d.renderer.points.pulsar import Renderer
LOGGER.info("Setting up rendering test for 3 channels...")
n_points = 1
width = 1_000
height = 1_000
renderer = Renderer(width, height, n_points, background_normalized_depth=0.999)
vert_pos = torch.tensor([[0.0, 0.0, 25.0]], dtype=torch.float32)
vert_col = torch.tensor([[0.3, 0.5, 0.7]], dtype=torch.float32)
vert_rad = torch.tensor([1.0], dtype=torch.float32)
cam_params = torch.tensor(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32
)
for device in devices:
vert_pos = vert_pos.to(device)
vert_col = vert_col.to(device)
vert_rad = vert_rad.to(device)
cam_params = cam_params.to(device)
renderer = renderer.to(device)
LOGGER.info("Rendering...")
# Measurements.
result = renderer.forward(
vert_pos, vert_col, vert_rad, cam_params, 1.0e-1, 45.0
)
hits = renderer.forward(
vert_pos,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
mode=1,
)
if not os.environ.get("FB_TEST", False):
imageio.imsave(
path.join(
path.dirname(__file__),
"test_out",
"test_forward_TestForward_test_bg_weight.png",
),
(result * 255.0).cpu().to(torch.uint8).numpy(),
)
imageio.imsave(
path.join(
path.dirname(__file__),
"test_out",
"test_forward_TestForward_test_bg_weight_hits.png",
),
(hits * 255.0).cpu().to(torch.uint8).numpy(),
)
self.assertEqual(hits[500, 500, 0].item(), 1.0)
self.assertTrue(
np.allclose(
result[500, 500, :].cpu().numpy(),
[1.0, 1.0, 1.0],
rtol=1e-2,
atol=1e-2,
)
)
def test_basic_3chan(self):
"""Test rendering one image with one sphere, 3 channels."""
from pytorch3d.renderer.points.pulsar import Renderer
LOGGER.info("Setting up rendering test for 3 channels...")
n_points = 1
width = 1_000
height = 1_000
renderer = Renderer(width, height, n_points)
vert_pos = torch.tensor([[0.0, 0.0, 25.0]], dtype=torch.float32)
vert_col = torch.tensor([[0.3, 0.5, 0.7]], dtype=torch.float32)
vert_rad = torch.tensor([1.0], dtype=torch.float32)
cam_params = torch.tensor(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32
)
for device in devices:
vert_pos = vert_pos.to(device)
vert_col = vert_col.to(device)
vert_rad = vert_rad.to(device)
cam_params = cam_params.to(device)
renderer = renderer.to(device)
LOGGER.info("Rendering...")
# Measurements.
result = renderer.forward(
vert_pos, vert_col, vert_rad, cam_params, 1.0e-1, 45.0
)
hits = renderer.forward(
vert_pos,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
mode=1,
)
if not os.environ.get("FB_TEST", False):
imageio.imsave(
path.join(
path.dirname(__file__),
"test_out",
"test_forward_TestForward_test_basic_3chan.png",
),
(result * 255.0).cpu().to(torch.uint8).numpy(),
)
imageio.imsave(
path.join(
path.dirname(__file__),
"test_out",
"test_forward_TestForward_test_basic_3chan_hits.png",
),
(hits * 255.0).cpu().to(torch.uint8).numpy(),
)
self.assertEqual(hits[500, 500, 0].item(), 1.0)
self.assertTrue(
np.allclose(
result[500, 500, :].cpu().numpy(),
[0.3, 0.5, 0.7],
rtol=1e-2,
atol=1e-2,
)
)
def test_basic_1chan(self):
"""Test rendering one image with one sphere, 1 channel."""
from pytorch3d.renderer.points.pulsar import Renderer
LOGGER.info("Setting up rendering test for 1 channel...")
n_points = 1
width = 1_000
height = 1_000
renderer = Renderer(width, height, n_points, n_channels=1)
vert_pos = torch.tensor([[0.0, 0.0, 25.0]], dtype=torch.float32)
vert_col = torch.tensor([[0.3]], dtype=torch.float32)
vert_rad = torch.tensor([1.0], dtype=torch.float32)
cam_params = torch.tensor(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32
)
for device in devices:
vert_pos = vert_pos.to(device)
vert_col = vert_col.to(device)
vert_rad = vert_rad.to(device)
cam_params = cam_params.to(device)
renderer = renderer.to(device)
LOGGER.info("Rendering...")
# Measurements.
result = renderer.forward(
vert_pos, vert_col, vert_rad, cam_params, 1.0e-1, 45.0
)
hits = renderer.forward(
vert_pos,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
mode=1,
)
if not os.environ.get("FB_TEST", False):
imageio.imsave(
path.join(
path.dirname(__file__),
"test_out",
"test_forward_TestForward_test_basic_1chan.png",
),
(result * 255.0).cpu().to(torch.uint8).numpy(),
)
imageio.imsave(
path.join(
path.dirname(__file__),
"test_out",
"test_forward_TestForward_test_basic_1chan_hits.png",
),
(hits * 255.0).cpu().to(torch.uint8).numpy(),
)
self.assertEqual(hits[500, 500, 0].item(), 1.0)
self.assertTrue(
np.allclose(
result[500, 500, :].cpu().numpy(), [0.3], rtol=1e-2, atol=1e-2
)
)
def test_basic_8chan(self):
"""Test rendering one image with one sphere, 8 channels."""
from pytorch3d.renderer.points.pulsar import Renderer
LOGGER.info("Setting up rendering test for 8 channels...")
n_points = 1
width = 1_000
height = 1_000
renderer = Renderer(width, height, n_points, n_channels=8)
vert_pos = torch.tensor([[0.0, 0.0, 25.0]], dtype=torch.float32)
vert_col = torch.tensor(
[[1.0, 1.0, 1.0, 1.0, 1.0, 0.3, 0.5, 0.7]], dtype=torch.float32
)
vert_rad = torch.tensor([1.0], dtype=torch.float32)
cam_params = torch.tensor(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32
)
for device in devices:
vert_pos = vert_pos.to(device)
vert_col = vert_col.to(device)
vert_rad = vert_rad.to(device)
cam_params = cam_params.to(device)
renderer = renderer.to(device)
LOGGER.info("Rendering...")
# Measurements.
result = renderer.forward(
vert_pos, vert_col, vert_rad, cam_params, 1.0e-1, 45.0
)
hits = renderer.forward(
vert_pos,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
mode=1,
)
if not os.environ.get("FB_TEST", False):
imageio.imsave(
path.join(
path.dirname(__file__),
"test_out",
"test_forward_TestForward_test_basic_8chan.png",
),
(result[:, :, 5:8] * 255.0).cpu().to(torch.uint8).numpy(),
)
imageio.imsave(
path.join(
path.dirname(__file__),
"test_out",
"test_forward_TestForward_test_basic_8chan_hits.png",
),
(hits * 255.0).cpu().to(torch.uint8).numpy(),
)
self.assertEqual(hits[500, 500, 0].item(), 1.0)
self.assertTrue(
np.allclose(
result[500, 500, 5:8].cpu().numpy(),
[0.3, 0.5, 0.7],
rtol=1e-2,
atol=1e-2,
)
)
self.assertTrue(
np.allclose(
result[500, 500, :5].cpu().numpy(), 1.0, rtol=1e-2, atol=1e-2
)
)
def test_principal_point(self):
"""Test shifting the principal point."""
from pytorch3d.renderer.points.pulsar import Renderer
LOGGER.info("Setting up rendering test for shifted principal point...")
n_points = 1
width = 1_000
height = 1_000
renderer = Renderer(width, height, n_points, n_channels=1)
vert_pos = torch.tensor([[0.0, 0.0, 25.0]], dtype=torch.float32)
vert_col = torch.tensor([[0.0]], dtype=torch.float32)
vert_rad = torch.tensor([1.0], dtype=torch.float32)
cam_params = torch.tensor(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0, 0.0, 0.0], dtype=torch.float32
)
for device in devices:
vert_pos = vert_pos.to(device)
vert_col = vert_col.to(device)
vert_rad = vert_rad.to(device)
cam_params = cam_params.to(device)
cam_params[-2] = -250.0
cam_params[-1] = -250.0
renderer = renderer.to(device)
LOGGER.info("Rendering...")
# Measurements.
result = renderer.forward(
vert_pos, vert_col, vert_rad, cam_params, 1.0e-1, 45.0
)
if not os.environ.get("FB_TEST", False):
imageio.imsave(
path.join(
path.dirname(__file__),
"test_out",
"test_forward_TestForward_test_principal_point.png",
),
(result * 255.0).cpu().to(torch.uint8).numpy(),
)
self.assertTrue(
np.allclose(
result[750, 750, :].cpu().numpy(), [0.0], rtol=1e-2, atol=1e-2
)
)
for device in devices:
vert_pos = vert_pos.to(device)
vert_col = vert_col.to(device)
vert_rad = vert_rad.to(device)
cam_params = cam_params.to(device)
cam_params[-2] = 250.0
cam_params[-1] = 250.0
renderer = renderer.to(device)
LOGGER.info("Rendering...")
# Measurements.
result = renderer.forward(
vert_pos, vert_col, vert_rad, cam_params, 1.0e-1, 45.0
)
if not os.environ.get("FB_TEST", False):
imageio.imsave(
path.join(
path.dirname(__file__),
"test_out",
"test_forward_TestForward_test_principal_point.png",
),
(result * 255.0).cpu().to(torch.uint8).numpy(),
)
self.assertTrue(
np.allclose(
result[250, 250, :].cpu().numpy(), [0.0], rtol=1e-2, atol=1e-2
)
)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
logging.getLogger("pulsar.renderer").setLevel(logging.WARN)
unittest.main()
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""Test right hand/left hand system compatibility."""
import logging
import sys
import unittest
from os import path
import torch
# fmt: off
# Make the mixin available.
sys.path.insert(0, path.join(path.dirname(__file__), ".."))
from common_testing import TestCaseMixin # isort:skip # noqa: E402
# fmt: on
# Making sure you can run this, even if pulsar hasn't been installed yet.
sys.path.insert(0, path.join(path.dirname(__file__), "..", ".."))
devices = [torch.device("cuda"), torch.device("cpu")]
class TestHands(TestCaseMixin, unittest.TestCase):
"""Test right hand/left hand system compatibility."""
def test_basic(self):
"""Basic forward test."""
from pytorch3d.renderer.points.pulsar import Renderer
n_points = 10
width = 1000
height = 1000
renderer_left = Renderer(width, height, n_points, right_handed_system=False)
renderer_right = Renderer(width, height, n_points, right_handed_system=True)
# Generate sample data.
torch.manual_seed(1)
vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
vert_pos[:, 2] += 25.0
vert_pos[:, :2] -= 5.0
vert_pos_neg = vert_pos.clone()
vert_pos_neg[:, 2] *= -1.0
vert_col = torch.rand(n_points, 3, dtype=torch.float32)
vert_rad = torch.rand(n_points, dtype=torch.float32)
cam_params = torch.tensor(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32
)
for device in devices:
vert_pos = vert_pos.to(device)
vert_pos_neg = vert_pos_neg.to(device)
vert_col = vert_col.to(device)
vert_rad = vert_rad.to(device)
cam_params = cam_params.to(device)
renderer_left = renderer_left.to(device)
renderer_right = renderer_right.to(device)
result_left = (
renderer_left.forward(
vert_pos,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
)
.cpu()
.detach()
.numpy()
)
hits_left = (
renderer_left.forward(
vert_pos,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
mode=1,
)
.cpu()
.detach()
.numpy()
)
result_right = (
renderer_right.forward(
vert_pos_neg,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
)
.cpu()
.detach()
.numpy()
)
hits_right = (
renderer_right.forward(
vert_pos_neg,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
mode=1,
)
.cpu()
.detach()
.numpy()
)
self.assertClose(result_left, result_right)
self.assertClose(hits_left, hits_right)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
logging.getLogger("pulsar.renderer").setLevel(logging.WARN)
unittest.main()
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""Tests for the orthogonal projection."""
import logging
import sys
import unittest
from os import path
import numpy as np
import torch
# Making sure you can run this, even if pulsar hasn't been installed yet.
sys.path.insert(0, path.join(path.dirname(__file__), ".."))
devices = [torch.device("cuda"), torch.device("cpu")]
class TestOrtho(unittest.TestCase):
"""Test the orthogonal projection."""
def test_basic(self):
"""Basic forward test of the orthogonal projection."""
from pytorch3d.renderer.points.pulsar import Renderer
n_points = 10
width = 1000
height = 1000
renderer_left = Renderer(
width,
height,
n_points,
right_handed_system=False,
orthogonal_projection=True,
)
renderer_right = Renderer(
width,
height,
n_points,
right_handed_system=True,
orthogonal_projection=True,
)
# Generate sample data.
torch.manual_seed(1)
vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
vert_pos[:, 2] += 25.0
vert_pos[:, :2] -= 5.0
vert_pos_neg = vert_pos.clone()
vert_pos_neg[:, 2] *= -1.0
vert_col = torch.rand(n_points, 3, dtype=torch.float32)
vert_rad = torch.rand(n_points, dtype=torch.float32)
cam_params = torch.tensor(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 20.0], dtype=torch.float32
)
for device in devices:
vert_pos = vert_pos.to(device)
vert_pos_neg = vert_pos_neg.to(device)
vert_col = vert_col.to(device)
vert_rad = vert_rad.to(device)
cam_params = cam_params.to(device)
renderer_left = renderer_left.to(device)
renderer_right = renderer_right.to(device)
result_left = (
renderer_left.forward(
vert_pos,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
)
.cpu()
.detach()
.numpy()
)
hits_left = (
renderer_left.forward(
vert_pos,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
mode=1,
)
.cpu()
.detach()
.numpy()
)
result_right = (
renderer_right.forward(
vert_pos_neg,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
)
.cpu()
.detach()
.numpy()
)
hits_right = (
renderer_right.forward(
vert_pos_neg,
vert_col,
vert_rad,
cam_params,
1.0e-1,
45.0,
percent_allowed_difference=0.01,
mode=1,
)
.cpu()
.detach()
.numpy()
)
self.assertTrue(np.allclose(result_left, result_right))
self.assertTrue(np.allclose(hits_left, hits_right))
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
logging.getLogger("pulsar.renderer").setLevel(logging.WARN)
unittest.main()
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""Test right hand/left hand system compatibility."""
import sys
import unittest
from os import path
import numpy as np
import torch
from torch import nn
sys.path.insert(0, path.join(path.dirname(__file__), ".."))
devices = [torch.device("cuda"), torch.device("cpu")]
n_points = 10
width = 1_000
height = 1_000
class SceneModel(nn.Module):
"""A simple model to demonstrate use in Modules."""
def __init__(self):
super(SceneModel, self).__init__()
from pytorch3d.renderer.points.pulsar import Renderer
self.gamma = 1.0
# Points.
torch.manual_seed(1)
vert_pos = torch.rand((1, n_points, 3), dtype=torch.float32) * 10.0
vert_pos[:, :, 2] += 25.0
vert_pos[:, :, :2] -= 5.0
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=False))
self.register_parameter(
"vert_col",
nn.Parameter(
torch.zeros(1, n_points, 3, dtype=torch.float32), requires_grad=True
),
)
self.register_parameter(
"vert_rad",
nn.Parameter(
torch.ones(1, n_points, dtype=torch.float32) * 0.001,
requires_grad=False,
),
)
self.register_parameter(
"vert_opy",
nn.Parameter(
torch.ones(1, n_points, dtype=torch.float32), requires_grad=False
),
)
self.register_buffer(
"cam_params",
torch.tensor(
[
[
np.sin(angle) * 35.0,
0.0,
30.0 - np.cos(angle) * 35.0,
0.0,
-angle,
0.0,
5.0,
2.0,
]
for angle in [-1.5, -0.8, -0.4, -0.1, 0.1, 0.4, 0.8, 1.5]
],
dtype=torch.float32,
),
)
self.renderer = Renderer(width, height, n_points)
def forward(self, cam=None):
if cam is None:
cam = self.cam_params
n_views = 8
else:
n_views = 1
return self.renderer.forward(
self.vert_pos.expand(n_views, -1, -1),
self.vert_col.expand(n_views, -1, -1),
self.vert_rad.expand(n_views, -1),
cam,
self.gamma,
45.0,
return_forward_info=True,
)
class TestSmallSpheres(unittest.TestCase):
"""Test small sphere rendering and gradients."""
def test_basic(self):
for device in devices:
# Set up model.
model = SceneModel().to(device)
angle = 0.0
for _ in range(50):
cam_control = torch.tensor(
[
[
np.sin(angle) * 35.0,
0.0,
30.0 - np.cos(angle) * 35.0,
0.0,
-angle,
0.0,
5.0,
2.0,
]
],
dtype=torch.float32,
).to(device)
result, forw_info = model(cam=cam_control)
sphere_ids = model.renderer.sphere_ids_from_result_info_nograd(
forw_info
)
# Assert all spheres are rendered.
for idx in range(n_points):
self.assertTrue(
(sphere_ids == idx).sum() > 0, "Sphere ID %d missing!" % (idx)
)
# Visualize.
# result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
# cv2.imshow("res", result_im[0, :, :, ::-1])
# cv2.waitKey(0)
# Back-propagate some dummy gradients.
loss = ((result - torch.ones_like(result)).abs()).sum()
loss.backward()
# Now check whether the gradient arrives at every sphere.
self.assertTrue(torch.all(model.vert_col.grad[:, :, 0].abs() > 0.0))
angle += 0.15
if __name__ == "__main__":
unittest.main()
......@@ -27,28 +27,6 @@ class TestBuild(unittest.TestCase):
for k, v in counter.items():
self.assertEqual(v, 1, f"Too many files with stem {k}.")
@unittest.skipIf(in_conda_build, "In conda build")
def test_deprecated_usage(self):
# Check certain expressions do not occur in the csrc code
test_dir = Path(__file__).resolve().parent
source_dir = test_dir.parent / "pytorch3d" / "csrc"
files = sorted(source_dir.glob("**/*.*"))
self.assertGreater(len(files), 4)
patterns = [".type()", ".data()"]
for file in files:
with open(file) as f:
text = f.read()
for pattern in patterns:
found = pattern in text
msg = (
f"{pattern} found in {file.name}"
+ ", this has been deprecated."
)
self.assertFalse(found, msg)
@unittest.skipIf(in_conda_build, "In conda build")
def test_copyright(self):
test_dir = Path(__file__).resolve().parent
......@@ -63,6 +41,13 @@ class TestBuild(unittest.TestCase):
for extension in extensions:
for i in root_dir.glob(f"**/*.{extension}"):
print(i)
if str(i).endswith(
"pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py"
):
continue
if str(i).endswith("pytorch3d/csrc/pulsar/include/fastermath.h"):
continue
with open(i) as f:
firstline = f.readline()
if firstline.startswith(("# -*-", "#!")):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment