Commit 34b1b4ab authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

defaulted grid_sizes in points2vols

Summary: Fix #873, that grid_sizes defaults to the wrong dtype in points2volumes code, and mask doesn't have a proper default.

Reviewed By: nikhilaravi

Differential Revision: D31503545

fbshipit-source-id: fa32a1a6074fc7ac7bdb362edfb5e5839866a472
parent 2f2466f4
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import math import math
from typing import Tuple, Optional from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
......
...@@ -364,7 +364,7 @@ def add_points_features_to_volume_densities_features( ...@@ -364,7 +364,7 @@ def add_points_features_to_volume_densities_features(
# grid sizes shape (minibatch, 3) # grid sizes shape (minibatch, 3)
grid_sizes = ( grid_sizes = (
torch.LongTensor(list(volume_densities.shape[2:])) torch.LongTensor(list(volume_densities.shape[2:]))
.to(volume_densities) .to(volume_densities.device)
.expand(volume_densities.shape[0], 3) .expand(volume_densities.shape[0], 3)
) )
...@@ -386,6 +386,10 @@ def add_points_features_to_volume_densities_features( ...@@ -386,6 +386,10 @@ def add_points_features_to_volume_densities_features(
splat = False splat = False
else: else:
raise ValueError('No such interpolation mode "%s"' % mode) raise ValueError('No such interpolation mode "%s"' % mode)
if mask is None:
mask = points_3d.new_ones(1).expand(points_3d.shape[:2])
volume_densities, volume_features = _points_to_volumes( volume_densities, volume_features = _points_to_volumes(
points_3d, points_3d,
points_features, points_features,
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
from itertools import product from itertools import product
from typing import Callable, Any from typing import Any, Callable
import torch import torch
from common_testing import get_random_cuda_device from common_testing import get_random_cuda_device
...@@ -14,6 +14,7 @@ from fvcore.common.benchmark import benchmark ...@@ -14,6 +14,7 @@ from fvcore.common.benchmark import benchmark
from pytorch3d.common.workaround import symeig3x3 from pytorch3d.common.workaround import symeig3x3
from test_symeig3x3 import TestSymEig3x3 from test_symeig3x3 import TestSymEig3x3
torch.set_num_threads(1) torch.set_num_threads(1)
CUDA_DEVICE = get_random_cuda_device() CUDA_DEVICE = get_random_cuda_device()
......
...@@ -16,6 +16,7 @@ from pytorch3d.io import save_obj ...@@ -16,6 +16,7 @@ from pytorch3d.io import save_obj
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap
from pytorch3d.transforms.rotation_conversions import random_rotation from pytorch3d.transforms.rotation_conversions import random_rotation
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3] OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
DATA_DIR = get_tests_dir() / "data" DATA_DIR = get_tests_dir() / "data"
DEBUG = False DEBUG = False
......
...@@ -12,7 +12,10 @@ from typing import Tuple ...@@ -12,7 +12,10 @@ from typing import Tuple
import numpy as np import numpy as np
import torch import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.ops import add_pointclouds_to_volumes from pytorch3d.ops import (
add_pointclouds_to_volumes,
add_points_features_to_volume_densities_features,
)
from pytorch3d.ops.points_to_volumes import _points_to_volumes from pytorch3d.ops.points_to_volumes import _points_to_volumes
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
...@@ -373,6 +376,17 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase): ...@@ -373,6 +376,17 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
else: else:
self.assertTrue(torch.isfinite(field.grad.data).all()) self.assertTrue(torch.isfinite(field.grad.data).all())
def test_defaulted_arguments(self):
points = torch.rand(30, 1000, 3)
features = torch.rand(30, 1000, 5)
_, densities = add_points_features_to_volume_densities_features(
points,
features,
torch.zeros(30, 1, 32, 32, 32),
torch.zeros(30, 5, 32, 32, 32),
)
self.assertClose(torch.sum(densities), torch.tensor(30 * 1000.0), atol=0.1)
def _check_volume_slice_color_density( def _check_volume_slice_color_density(
self, V, split_dim, interp_mode, clr_gt, slice_type, border=3 self, V, split_dim, interp_mode, clr_gt, slice_type, border=3
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment