Commit 24f5f4a3 authored by Darijan Gudelj's avatar Darijan Gudelj Committed by Facebook GitHub Bot
Browse files

VoxelGridModule

Summary: Simple wrapper around voxel grids to make them a module

Reviewed By: bottler

Differential Revision: D38829762

fbshipit-source-id: dfee85088fa3c65e396cc7d3bf7ebaaffaadb646
parent 6653f440
...@@ -8,14 +8,23 @@ ...@@ -8,14 +8,23 @@
This file contains classes that implement Voxel grids, both in their full resolution This file contains classes that implement Voxel grids, both in their full resolution
as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition
or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the
https://arxiv.org/abs/2203.09517. TensoRF (https://arxiv.org/abs/2203.09517) paper.
In addition, the module VoxelGridModule implements a trainable instance of one of
these classes.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Dict, Optional, Tuple, Type from typing import ClassVar, Dict, Optional, Tuple, Type
import torch import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase from pytorch3d.implicitron.tools.config import (
Configurable,
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.structures.volumes import VolumeLocator from pytorch3d.structures.volumes import VolumeLocator
from .utils import interpolate_line, interpolate_plane, interpolate_volume from .utils import interpolate_line, interpolate_plane, interpolate_volume
...@@ -426,3 +435,78 @@ class VMFactorizedVoxelGrid(VoxelGridBase): ...@@ -426,3 +435,78 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
) )
return shape_dict return shape_dict
class VoxelGridModule(Configurable, torch.nn.Module):
"""
A wrapper torch.nn.Module for the VoxelGrid classes, which
contains parameters that are needed to train the VoxelGrid classes.
Members:
voxel_grid_class_type: The name of the class to use for voxel_grid,
which must be available in the registry. Default FullResolutionVoxelGrid.
voxel_grid: An instance of `VoxelGridBase`. This is the object which
this class wraps.
extents: 3-tuple of a form (width, height, depth), denotes the size of the grid
in world units.
translation: 3-tuple of float. The center of the volume in world units as (x, y, z).
init_std: Parameters are initialized using the gaussian distribution
with mean=init_mean and std=init_std. Default 0.1
init_mean: Parameters are initialized using the gaussian distribution
with mean=init_mean and std=init_std. Default 0.
"""
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
voxel_grid: VoxelGridBase
extents: Tuple[float, float, float] = 1.0
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
init_std: float = 0.1
init_mean: float = 0
def __post_init__(self):
super().__init__()
run_auto_creation(self)
n_grids = 1 # Voxel grid objects are batched. We need only a single grid.
shapes = self.voxel_grid.get_shapes()
params = {
name: torch.normal(
mean=torch.zeros((n_grids, *shape)) + self.init_mean,
std=self.init_std,
)
for name, shape in shapes.items()
}
self.params = torch.nn.ParameterDict(params)
def forward(self, points: torch.Tensor) -> torch.Tensor:
"""
Evaluates points in the world coordinate frame on the voxel_grid.
Args:
points (torch.Tensor): tensor of points that you want to query
of a form (n_points, 3)
Returns:
torch.Tensor of shape (n_points, n_features)
"""
locator = VolumeLocator(
batch_size=1,
# The resolution of the voxel grid does not need to be known
# to the locator object. It is easiest to fix the resolution of the locator.
# In particular we fix it to (2,2,2) so that there is exactly one voxel of the
# desired size. The locator object uses (z, y, x) convention for the grid_size,
# and this module uses (x, y, z) convention so the order has to be reversed
# (irrelevant in this case since they are all equal).
# It is (2, 2, 2) because the VolumeLocator object behaves like
# align_corners=True, which means that the points are in the corners of
# the volume. So in the grid of (2, 2, 2) there is only one voxel.
grid_sizes=(2, 2, 2),
# The locator object uses (x, y, z) convention for the
# voxel size and translation.
voxel_size=self.extents,
volume_translation=self.translation,
device=next(self.params.values()).device,
)
grid_values = self.voxel_grid.values_type(**self.params)
# voxel grids operate with extra n_grids dimension, which we fix to one
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
...@@ -19,6 +19,7 @@ from pytorch3d.implicitron.models.implicit_function.voxel_grid import ( ...@@ -19,6 +19,7 @@ from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
CPFactorizedVoxelGrid, CPFactorizedVoxelGrid,
FullResolutionVoxelGrid, FullResolutionVoxelGrid,
VMFactorizedVoxelGrid, VMFactorizedVoxelGrid,
VoxelGridModule,
) )
from pytorch3d.implicitron.tools.config import expand_args_fields from pytorch3d.implicitron.tools.config import expand_args_fields
...@@ -198,6 +199,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): ...@@ -198,6 +199,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
expand_args_fields(FullResolutionVoxelGrid) expand_args_fields(FullResolutionVoxelGrid)
expand_args_fields(CPFactorizedVoxelGrid) expand_args_fields(CPFactorizedVoxelGrid)
expand_args_fields(VMFactorizedVoxelGrid) expand_args_fields(VMFactorizedVoxelGrid)
expand_args_fields(VoxelGridModule)
def _interpolate_1D( def _interpolate_1D(
self, points: torch.Tensor, vectors: torch.Tensor self, points: torch.Tensor, vectors: torch.Tensor
...@@ -585,3 +587,27 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): ...@@ -585,3 +587,27 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=10, n_features=10,
n_components=3, n_components=3,
) )
def test_voxel_grid_module_location(self, n_times=10):
"""
This checks the module uses locator correctly etc..
If we know that voxel grids work for (x, y, z) in local coordinates
to test if the VoxelGridModule does not have permuted dimensions we
create local coordinates, pass them through verified voxelgrids and
compare the result with the result that we get when we convert
coordinates to world and pass them through the VoxelGridModule
"""
for _ in range(n_times):
extents = tuple(torch.randint(1, 50, size=(3,)).tolist())
grid = VoxelGridModule(extents=extents)
local_point = torch.rand(1, 3) * 2 - 1
world_point = local_point * torch.tensor(extents) / 2
grid_values = grid.voxel_grid.values_type(**grid.params)
assert torch.allclose(
grid(world_point)[0, 0],
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
rtol=0.0001,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment