Commit f825f7e4 authored by Darijan Gudelj's avatar Darijan Gudelj Committed by Facebook GitHub Bot
Browse files

Split Volumes class to data and location part

Summary: Split the `Volumes` class into a data part and a location part so that the location part (`VolumeLocator`) can be reused in the planned VoxelGrid classes.

Reviewed By: bottler

Differential Revision: D38782015

fbshipit-source-id: 489da09c5c236f3b81961ce9b09edbd97afaa7c8
parent fdaaa299
...@@ -23,6 +23,7 @@ _VoxelSize = _ScalarOrVector ...@@ -23,6 +23,7 @@ _VoxelSize = _ScalarOrVector
_Translation = _Vector _Translation = _Vector
_TensorBatch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] _TensorBatch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
_ALL_CONTENT: slice = slice(0, None)
class Volumes: class Volumes:
...@@ -65,9 +66,9 @@ class Volumes: ...@@ -65,9 +66,9 @@ class Volumes:
VOLUME COORDINATES VOLUME COORDINATES
Additionally, the `Volumes` class keeps track of the locations of the Additionally, using the `VolumeLocator` class the `Volumes` class keeps track
centers of the volume cells in the local volume coordinates as well as in of the locations of the centers of the volume cells in the local volume
the world coordinates. coordinates as well as in the world coordinates.
Local coordinates: Local coordinates:
- Represent the locations of the volume cells in the local coordinate - Represent the locations of the volume cells in the local coordinate
...@@ -125,7 +126,7 @@ class Volumes: ...@@ -125,7 +126,7 @@ class Volumes:
appropriate `world_coordinates` argument. appropriate `world_coordinates` argument.
Internally, the mapping between `x_local` and `x_world` is represented Internally, the mapping between `x_local` and `x_world` is represented
as a `Transform3d` object `Volumes._local_to_world_transform`. as a `Transform3d` object `Volumes.VolumeLocator._local_to_world_transform`.
Users can access the relevant transformations with the Users can access the relevant transformations with the
`Volumes.get_world_to_local_coords_transform()` and `Volumes.get_world_to_local_coords_transform()` and
`Volumes.get_local_to_world_coords_transform()` `Volumes.get_local_to_world_coords_transform()`
...@@ -197,21 +198,24 @@ class Volumes: ...@@ -197,21 +198,24 @@ class Volumes:
# assign to the internal buffers # assign to the internal buffers
self._densities = densities_ self._densities = densities_
self._grid_sizes = grid_sizes
# assign a coordinate transformation member
self.locator = VolumeLocator(
batch_size=len(self),
grid_sizes=grid_sizes,
voxel_size=voxel_size,
volume_translation=volume_translation,
device=self.device,
)
# handle features # handle features
self._features = None self._features = None
if features is not None: if features is not None:
self._set_features(features) self._set_features(features)
# set the local_to_world transform
self._set_local_to_world_transform(
voxel_size=voxel_size, volume_translation=volume_translation
)
def _convert_densities_features_to_tensor( def _convert_densities_features_to_tensor(
self, x: _TensorBatch, var_name: str self, x: _TensorBatch, var_name: str
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.LongTensor]:
""" """
Handle the `densities` or `features` arguments to the constructor. Handle the `densities` or `features` arguments to the constructor.
""" """
...@@ -251,252 +255,9 @@ class Volumes: ...@@ -251,252 +255,9 @@ class Volumes:
f"{var_name} must be either a list or a tensor with " f"{var_name} must be either a list or a tensor with "
f"shape (batch_size, {var_name}_dim, H, W, D)." f"shape (batch_size, {var_name}_dim, H, W, D)."
) )
# pyre-ignore[7]
return x_tensor, x_shapes return x_tensor, x_shapes
def _voxel_size_translation_to_transform(
    self,
    voxel_size: torch.Tensor,
    volume_translation: torch.Tensor,
    batch_size: int,
) -> Transform3d:
    """
    Build the internal local-to-world `Transform3d` from the `voxel_size`
    and `volume_translation` constructor arguments.

    The returned transform implements

        x_world = x_local * (volume_size - 1) * 0.5 * voxel_size
                  - volume_translation

    so that its inverse is

        x_local = (x_world + volume_translation)
                  / (0.5 * voxel_size) / (volume_size - 1)

    `batch_size` is part of the signature but is not read by this method.
    """
    # grid sizes come in z, y, x order; reorder the columns to x, y, z
    sizes_xyz = self.get_grid_sizes().float()[:, [2, 1, 0]]
    half_extent = (sizes_xyz - 1) * voxel_size * 0.5
    scaling = Scale(half_extent, device=self.device)
    return scaling.translate(-volume_translation)
def _handle_voxel_size(
self, voxel_size: _VoxelSize, batch_size: int
) -> torch.Tensor:
"""
Handle the `voxel_size` argument to the `Volumes` constructor.
"""
err_msg = (
"voxel_size has to be either a 3-tuple of scalars, or a scalar, or"
" a torch.Tensor of shape (3,) or (1,) or (minibatch, 3) or (minibatch, 1)."
)
if isinstance(voxel_size, (float, int)):
# convert a scalar to a 3-element tensor
voxel_size = torch.full(
(1, 3), voxel_size, device=self.device, dtype=torch.float32
)
elif isinstance(voxel_size, torch.Tensor):
if voxel_size.numel() == 1:
# convert a single-element tensor to a 3-element one
voxel_size = voxel_size.view(-1).repeat(3)
elif len(voxel_size.shape) == 2 and (
voxel_size.shape[0] == batch_size and voxel_size.shape[1] == 1
):
voxel_size = voxel_size.repeat(1, 3)
return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
def _handle_volume_translation(
self, translation: _Translation, batch_size: int
) -> torch.Tensor:
"""
Handle the `volume_translation` argument to the `Volumes` constructor.
"""
err_msg = (
"`volume_translation` has to be either a 3-tuple of scalars, or"
" a Tensor of shape (1,3) or (minibatch, 3) or (3,)`."
)
return self._convert_volume_property_to_tensor(translation, batch_size, err_msg)
def _convert_volume_property_to_tensor(
self, x: _Vector, batch_size: int, err_msg: str
) -> torch.Tensor:
"""
Handle the `volume_translation` or `voxel_size` argument to
the Volumes constructor.
Return a tensor of shape (N, 3) where N is the batch_size.
"""
if isinstance(x, (list, tuple)):
if len(x) != 3:
raise ValueError(err_msg)
x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
x = x.repeat((batch_size, 1))
elif isinstance(x, torch.Tensor):
ok = (
(x.shape[0] == 1 and x.shape[1] == 3)
or (x.shape[0] == 3 and len(x.shape) == 1)
or (x.shape[0] == batch_size and x.shape[1] == 3)
)
if not ok:
raise ValueError(err_msg)
if x.device != self.device:
x = x.to(self.device)
if x.shape[0] == 3 and len(x.shape) == 1:
x = x[None]
if x.shape[0] == 1:
x = x.repeat((batch_size, 1))
else:
raise ValueError(err_msg)
return x
def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor:
    """
    Return the coordinates of the centers of all voxels of the volumetric
    grid, either in local or in world coordinates.

    Local coordinates are scaled s.t. the values along one side of the
    volume are in range [-1, 1].

    Args:
        **world_coordinates**: if `True`, the grid is returned in world
            coordinates, otherwise in local coordinates.

    Returns:
        **coordinate_grid**: The grid of coordinates of shape
            `(minibatch, depth, height, width, 3)`, where `minibatch`,
            `depth`, `height` and `width` are the batch size and the
            spatial sizes of the volume `features` or `densities`.
    """
    # TODO(dnovotny): Implement caching of the coordinate grid.
    coordinate_grid = self._calculate_coordinate_grid(
        world_coordinates=world_coordinates
    )
    return coordinate_grid
def _calculate_coordinate_grid(
    self, world_coordinates: bool = True
) -> torch.Tensor:
    """
    Calculate the 3D coordinate grid of the volumetric grid either
    in local (`world_coordinates=False`) or
    world coordinates (`world_coordinates=True`).

    The grid is first generated for the padded density tensor and then
    rescaled per batch element so that, for each volume, index 0 maps to -1
    and the volume's own last index maps to +1 along every axis.

    Returns:
        Tensor of shape `(minibatch, depth, height, width, 3)` whose last
        dimension holds (x, y, z) coordinates.
    """
    densities = self.densities()
    # spatial sizes of the (padded) density tensor
    ba, _, de, he, wi = densities.shape
    grid_sizes = self.get_grid_sizes()
    # generate coordinate axes
    vol_axes = [
        torch.linspace(-1.0, 1.0, r, dtype=torch.float32, device=self.device)
        for r in (de, he, wi)
    ]
    # generate per-coord meshgrids
    Z, Y, X = meshgrid_ij(vol_axes)
    # stack the coord grids ... this order matches the coordinate convention
    # of torch.nn.grid_sample
    vol_coords_local = torch.stack((X, Y, Z), dim=3)[None].repeat(ba, 1, 1, 1, 1)
    # get grid sizes relative to the maximal volume size
    # NOTE(review): a grid dimension of size 1 makes `grid_sizes - 1` zero
    # in this division - confirm such inputs cannot occur.
    grid_sizes_relative = (
        torch.tensor([[de, he, wi]], device=grid_sizes.device, dtype=torch.float32)
        - 1
    ) / (grid_sizes - 1).float()
    if (grid_sizes_relative != 1.0).any():
        # if any of the relative sizes != 1.0, adjust the grid
        # (columns are reordered from z, y, x to x, y, z and broadcast over
        # the three spatial dimensions)
        grid_sizes_relative_reshape = grid_sizes_relative[:, [2, 1, 0]][
            :, None, None, None
        ]
        # c -> rel * (c + 1) - 1: keeps -1 fixed and stretches the rest
        vol_coords_local *= grid_sizes_relative_reshape
        vol_coords_local += grid_sizes_relative_reshape - 1
    if world_coordinates:
        vol_coords = self.local_to_world_coords(vol_coords_local)
    else:
        vol_coords = vol_coords_local
    return vol_coords
def get_local_to_world_coords_transform(self) -> Transform3d:
    """
    Return the stored transform that maps points from the local coordinate
    frame of the volume to world coordinates.

    Local volume coordinates are scaled s.t. the coordinates along one
    side of the volume are in range [-1, 1].

    Returns:
        **local_to_world_transform**: A Transform3d object converting
            points from local coordinates to the world coordinates.
    """
    local_to_world_transform = self._local_to_world_transform
    return local_to_world_transform
def get_world_to_local_coords_transform(self) -> Transform3d:
    """
    Return the transform that maps points from world coordinates to the
    local coordinate frame of the volume. It is obtained by inverting the
    local-to-world transform.

    Local volume coordinates are scaled s.t. the coordinates along one
    side of the volume are in range [-1, 1].

    Returns:
        **world_to_local_transform**: A Transform3d object converting
            points from world coordinates to local coordinates.
    """
    local_to_world = self.get_local_to_world_coords_transform()
    return local_to_world.inverse()
def world_to_local_coords(self, points_3d_world: torch.Tensor) -> torch.Tensor:
    """
    Convert a batch of 3D point coordinates `points_3d_world` of shape
    (minibatch, ..., 3) in the world coordinates to the local coordinate
    frame of the volume. Local volume coordinates are scaled s.t. the
    coordinates along one side of the volume are in range [-1, 1].

    Args:
        **points_3d_world**: A tensor of shape `(minibatch, ..., 3)`
            containing the world coordinates of the points that will be
            converted to the local volume coordinates.

    Returns:
        **points_3d_local**: `points_3d_world` converted to the local
            volume coordinates of shape `(minibatch, ..., 3)`.
    """
    pts_shape = points_3d_world.shape
    # `reshape` behaves like `view` for contiguous inputs but also accepts
    # non-contiguous tensors (e.g. permuted ones), for which `view` raises.
    flat_points = points_3d_world.reshape(pts_shape[0], -1, 3)
    points_3d_local = self.get_world_to_local_coords_transform().transform_points(
        flat_points
    )
    return points_3d_local.reshape(pts_shape)
def local_to_world_coords(self, points_3d_local: torch.Tensor) -> torch.Tensor:
    """
    Convert a batch of 3D point coordinates `points_3d_local` of shape
    (minibatch, ..., 3) in the local coordinate frame of the volume
    to the world coordinates.

    Args:
        **points_3d_local**: A tensor of shape `(minibatch, ..., 3)`
            containing the local volume coordinates (ranging within
            [-1, 1] inside the volume) of the points that will be
            converted to the world coordinates.

    Returns:
        **points_3d_world**: `points_3d_local` converted to the world
            coordinates of the volume of shape `(minibatch, ..., 3)`.
    """
    pts_shape = points_3d_local.shape
    # `reshape` behaves like `view` for contiguous inputs but also accepts
    # non-contiguous tensors (e.g. permuted ones), for which `view` raises.
    flat_points = points_3d_local.reshape(pts_shape[0], -1, 3)
    points_3d_world = self.get_local_to_world_coords_transform().transform_points(
        flat_points
    )
    return points_3d_world.reshape(pts_shape)
def __len__(self) -> int: def __len__(self) -> int:
return self._densities.shape[0] return self._densities.shape[0]
...@@ -530,8 +291,7 @@ class Volumes: ...@@ -530,8 +291,7 @@ class Volumes:
densities=self.densities()[index], densities=self.densities()[index],
) )
# dont forget to update grid_sizes! # dont forget to update grid_sizes!
new._grid_sizes = self.get_grid_sizes()[index] self.locator._copy_transform_and_sizes(new.locator, index=index)
new._local_to_world_transform = self._local_to_world_transform[index]
return new return new
def features(self) -> Optional[torch.Tensor]: def features(self) -> Optional[torch.Tensor]:
...@@ -593,16 +353,6 @@ class Volumes: ...@@ -593,16 +353,6 @@ class Volumes:
x_list = struct_utils.padded_to_list(x, pad_sizes.tolist()) x_list = struct_utils.padded_to_list(x, pad_sizes.tolist())
return x_list return x_list
def get_grid_sizes(self) -> torch.LongTensor:
"""
Returns the sizes of individual volumetric grids in the structure.
Returns:
**grid_sizes**: Tensor of spatial sizes of each of the volumes
of size (batchsize, 3), where i-th row holds (D_i, H_i, W_i).
"""
return self._grid_sizes
def update_padded( def update_padded(
self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None
) -> "Volumes": ) -> "Volumes":
...@@ -656,6 +406,525 @@ class Volumes: ...@@ -656,6 +406,525 @@ class Volumes:
) )
setattr(self, "_" + var_name, x_tensor) setattr(self, "_" + var_name, x_tensor)
def clone(self) -> "Volumes":
"""
Deep copy of Volumes object. All internal tensors are cloned
individually.
Returns:
new Volumes object.
"""
return copy.deepcopy(self)
def to(self, device: Device, copy: bool = False) -> "Volumes":
    """
    Mirror the behaviour of torch.Tensor.to() for Volumes.

    If `copy` is False and the object already lives on the requested
    device, `self` is returned. Otherwise a clone is returned; when the
    device differs, the clone's densities, features and locator are moved
    to the requested device.

    Args:
        device: Device (as str or torch.device) for the new tensor.
        copy: Boolean indicator whether or not to clone self. Default False.

    Returns:
        Volumes object.
    """
    target = make_device(device)
    already_there = self.device == target
    if already_there and not copy:
        return self

    moved = self.clone()
    if already_there:
        # a copy on the same device was requested
        return moved

    moved.device = target
    moved._densities = self._densities.to(target)
    if self._features is not None:
        # pyre-fixme[16]: `Optional` has no attribute `to`.
        moved._features = self.features().to(target)
    self.locator._copy_transform_and_sizes(moved.locator, device=target)
    moved.locator = moved.locator.to(device, copy)
    return moved
def cpu(self) -> "Volumes":
    """Shorthand for `self.to("cpu")`."""
    target = "cpu"
    return self.to(target)
def cuda(self) -> "Volumes":
    """Shorthand for `self.to("cuda")`."""
    target = "cuda"
    return self.to(target)
def get_grid_sizes(self) -> torch.LongTensor:
    """
    Return the sizes of the individual volumetric grids, as reported by
    the internal `VolumeLocator`.

    Returns:
        **grid_sizes**: Tensor of spatial sizes of each of the volumes
            of size (batchsize, 3), where i-th row holds (D_i, H_i, W_i).
    """
    locator = self.locator
    return locator.get_grid_sizes()
def get_local_to_world_coords_transform(self) -> Transform3d:
    """
    Return the transform that maps points from the local coordinate frame
    of the volume to world coordinates, as held by the internal
    `VolumeLocator`.

    Local volume coordinates are scaled s.t. the coordinates along one
    side of the volume are in range [-1, 1].

    Returns:
        **local_to_world_transform**: A Transform3d object converting
            points from local coordinates to the world coordinates.
    """
    locator = self.locator
    return locator.get_local_to_world_coords_transform()
def get_world_to_local_coords_transform(self) -> Transform3d:
    """
    Return the transform that maps points from world coordinates to the
    local coordinate frame of the volume, computed as the inverse of the
    local-to-world transform.

    Local volume coordinates are scaled s.t. the coordinates along one
    side of the volume are in range [-1, 1].

    Returns:
        **world_to_local_transform**: A Transform3d object converting
            points from world coordinates to local coordinates.
    """
    local_to_world = self.get_local_to_world_coords_transform()
    return local_to_world.inverse()
def world_to_local_coords(self, points_3d_world: torch.Tensor) -> torch.Tensor:
    """
    Convert a batch of 3D points from world coordinates to the local
    coordinate frame of the volume by delegating to the internal
    `VolumeLocator`. Local volume coordinates are scaled s.t. the
    coordinates along one side of the volume are in range [-1, 1].

    Args:
        **points_3d_world**: A tensor of shape `(minibatch, ..., 3)`
            containing the world coordinates of the points to convert.

    Returns:
        **points_3d_local**: `points_3d_world` converted to the local
            volume coordinates of shape `(minibatch, ..., 3)`.
    """
    locator = self.locator
    points_3d_local = locator.world_to_local_coords(points_3d_world)
    return points_3d_local
def local_to_world_coords(self, points_3d_local: torch.Tensor) -> torch.Tensor:
    """
    Convert a batch of 3D points from the local coordinate frame of the
    volume to world coordinates by delegating to the internal
    `VolumeLocator`.

    Args:
        **points_3d_local**: A tensor of shape `(minibatch, ..., 3)`
            containing the local volume coordinates of the points to
            convert.

    Returns:
        **points_3d_world**: `points_3d_local` converted to the world
            coordinates of the volume of shape `(minibatch, ..., 3)`.
    """
    locator = self.locator
    points_3d_world = locator.local_to_world_coords(points_3d_local)
    return points_3d_world
def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor:
    """
    Return the coordinates of the centers of all voxels of the volumetric
    grid, as computed by the internal `VolumeLocator`.

    Local coordinates are scaled s.t. the values along one side of the
    volume are in range [-1, 1].

    Args:
        **world_coordinates**: if `True`, the grid is returned in world
            coordinates, otherwise in local coordinates.

    Returns:
        **coordinate_grid**: The grid of coordinates of shape
            `(minibatch, depth, height, width, 3)`, where `minibatch`,
            `depth`, `height` and `width` are the batch size and the
            spatial sizes of the volume `features` or `densities`.
    """
    locator = self.locator
    coordinate_grid = locator.get_coord_grid(world_coordinates)
    return coordinate_grid
class VolumeLocator:
"""
The `VolumeLocator` class keeps track of the locations of the
centers of the volume cells in the local volume coordinates as well as in
the world coordinates for a voxel grid structure in 3D.
Local coordinates:
- Represent the locations of the volume cells in the local coordinate
frame of the volume.
- The center of the voxel indexed with `[·, ·, 0, 0, 0]` in the volume
has its 3D local coordinate set to `[-1, -1, -1]`, while the voxel
at index `[·, ·, depth_i-1, height_i-1, width_i-1]` has its
3D local coordinate set to `[1, 1, 1]`.
- The first/second/third coordinate of each of the 3D per-voxel
XYZ vector denotes the horizontal/vertical/depth-wise position
respectively. I.e the order of the coordinate dimensions in the
volume is reversed w.r.t. the order of the 3D coordinate vectors.
- The intermediate coordinates between `[-1, -1, -1]` and `[1, 1, 1]`
are linearly interpolated over the spatial dimensions of the volume.
- Note that the convention is the same as for the 5D version of the
`torch.nn.functional.grid_sample` function called with
`align_corners==True`.
- Note that the local coordinate convention of `VolumeLocator`
(+X = left to right, +Y = top to bottom, +Z = away from the user)
is *different* from the world coordinate convention of the
renderer for `Meshes` or `Pointclouds`
(+X = right to left, +Y = bottom to top, +Z = away from the user).
World coordinates:
- These define the locations of the centers of the volume cells
in the world coordinates.
- They are specified with the following mapping that converts
points `x_local` in the local coordinates to points `x_world`
in the world coordinates:
```
x_world = (
x_local * (volume_size - 1) * 0.5 * voxel_size
) - volume_translation,
```
here `voxel_size` specifies the size of each voxel of the volume,
and `volume_translation` is the 3D offset of the central voxel of
the volume w.r.t. the origin of the world coordinate frame.
Both `voxel_size` and `volume_translation` are specified in
the world coordinate units. `volume_size` is the spatial size of
the volume in form of a 3D vector `[width, height, depth]`.
- Given the above definition of `x_world`, one can derive the
inverse mapping from `x_world` to `x_local` as follows:
```
x_local = (
(x_world + volume_translation) / (0.5 * voxel_size)
) / (volume_size - 1)
```
- For a trivial volume with `volume_translation==[0, 0, 0]`
and `voxel_size=1`, `x_world` would range
from `-(volume_size-1)/2` to `+(volume_size-1)/2`.
Coordinate tensors that denote the locations of each of the volume cells in
local / world coordinates (with shape `(depth x height x width x 3)`)
can be retrieved by calling the `VolumeLocator.get_coord_grid()` getter with the
appropriate `world_coordinates` argument.
Internally, the mapping between `x_local` and `x_world` is represented
as a `Transform3d` object `VolumeLocator._local_to_world_transform`.
Users can access the relevant transformations with the
`VolumeLocator.get_world_to_local_coords_transform()` and
`VolumeLocator.get_local_to_world_coords_transform()`
functions.
Example coordinate conversion:
- For a "trivial" volume with `voxel_size = 1.`,
`volume_translation=[0., 0., 0.]`, and the spatial size of
`DxHxW = 5x5x5`, the point `x_world = (-2, 0, 2)` gets mapped
to `x_local=(-1, 0, 1)`.
- For a "trivial" volume `v` with `voxel_size = 1.`,
`volume_translation=[0., 0., 0.]`, the following holds:
```
torch.nn.functional.grid_sample(
v.densities(),
v.get_coord_grid(world_coordinates=False),
align_corners=True,
) == v.densities(),
```
i.e. sampling the volume at trivial local coordinates
(no scaling with `voxel_size` or shift with `volume_translation`)
results in the same volume.
"""
def __init__(
    self,
    batch_size: int,
    grid_sizes: Union[
        torch.LongTensor, Tuple[int, int, int], List[torch.LongTensor]
    ],
    device: torch.device,
    voxel_size: _VoxelSize = 1.0,
    volume_translation: _Translation = (0.0, 0.0, 0.0),
) -> None:
    """
    **batch_size**: Batch size of the underlying grids.
    **grid_sizes**: Represents the resolutions of different grids in the
        batch. The three entries per grid follow the row layout returned
        by `get_grid_sizes()`, i.e. (D, H, W) = (depth, height, width).
        Can be
        a) a tuple of ints of form (D, H, W); all the grids in the batch
            then have the same resolution
        b) a list/tuple of length `batch_size` of lists/tuples/LongTensors
            of form (D, H, W)
        c) a torch.Tensor of shape (batch_size, 3)
    **device**: The torch device on which the internal tensors are placed.
    **voxel_size**: Denotes the size of each volume voxel in world units.
        Has to be one of:
        a) A scalar (square voxels)
        b) 3-tuple or a 3-list of scalars
        c) a Tensor of shape (3,)
        d) a Tensor of shape (minibatch, 3)
        e) a Tensor of shape (minibatch, 1)
        f) a Tensor of shape (1,) (square voxels)
    **volume_translation**: Denotes the 3D translation of the center
        of the volume in world units. Has to be one of:
        a) 3-tuple or a 3-list of scalars
        b) a Tensor of shape (3,)
        c) a Tensor of shape (minibatch, 3)
        d) a Tensor of shape (1, 3)
    """
    # `device` and `_batch_size` must be set first: they are read by
    # `_convert_grid_sizes2tensor` below.
    self.device = device
    self._batch_size = batch_size
    self._grid_sizes = self._convert_grid_sizes2tensor(grid_sizes)
    # per-axis maximum over the batch, i.e. the size of the padded grid
    self._resolution = tuple(torch.max(self._grid_sizes.cpu(), dim=0).values)
    # set the local_to_world transform
    self._set_local_to_world_transform(
        voxel_size=voxel_size, volume_translation=volume_translation
    )
def _convert_grid_sizes2tensor(
self, x: Union[torch.LongTensor, List[torch.LongTensor], Tuple[int, int, int]]
) -> torch.LongTensor:
"""
Handle the grid_sizes argument to the constructor.
"""
if isinstance(x, (list, tuple)):
if isinstance(x[0], (torch.LongTensor, list, tuple)):
if self._batch_size != len(x):
raise ValueError("x should have a batch size of 'batch_size'")
# pyre-ignore[6]
if any(len(x_) != 3 for x_ in x):
raise ValueError(
"`grid_sizes` has to be a list of 3-dim tensors of shape: "
"(height, width, depth)"
)
x_shapes = torch.stack(
[
torch.tensor(
# pyre-ignore[6]
list(x_),
dtype=torch.long,
device=self.device,
)
for x_ in x
],
dim=0,
)
elif isinstance(x[0], int):
x_shapes = torch.stack(
[
torch.tensor(list(x), dtype=torch.long, device=self.device)
for _ in range(self._batch_size)
],
dim=0,
)
else:
raise ValueError(
"`grid_sizes` can be a list/tuple of int or torch.Tensor not of "
+ "{type(x[0])}."
)
elif torch.is_tensor(x):
if x.ndim != 2:
raise ValueError(
"`grid_sizes` has to be a 2-dim tensor of shape: (minibatch, 3)"
)
x_shapes = x.to(self.device)
else:
raise ValueError(
"grid_sizes must be either a list of tensors with shape (H, W, D), tensor with"
"shape (batch_size, H, W, D) or a tuple of (H, W, D)."
)
# pyre-ignore[7]
return x_shapes
def _voxel_size_translation_to_transform(
    self,
    voxel_size: torch.Tensor,
    volume_translation: torch.Tensor,
    batch_size: int,
) -> Transform3d:
    """
    Build the internal local-to-world `Transform3d` from the `voxel_size`
    and `volume_translation` constructor arguments.

    The returned transform implements

        x_world = x_local * (volume_size - 1) * 0.5 * voxel_size
                  - volume_translation

    so that its inverse is

        x_local = (x_world + volume_translation)
                  / (0.5 * voxel_size) / (volume_size - 1)

    `batch_size` is part of the signature but is not read by this method.
    """
    # grid sizes come in z, y, x order; reorder the columns to x, y, z
    sizes_xyz = self.get_grid_sizes().float()[:, [2, 1, 0]]
    half_extent = (sizes_xyz - 1) * voxel_size * 0.5
    scaling = Scale(half_extent, device=self.device)
    return scaling.translate(-volume_translation)
def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor:
    """
    Return the coordinates of the centers of all voxels of the volumetric
    grid, either in local or in world coordinates.

    Local coordinates are scaled s.t. the values along one side of the
    volume are in range [-1, 1].

    Args:
        **world_coordinates**: if `True`, the grid is returned in world
            coordinates, otherwise in local coordinates.

    Returns:
        **coordinate_grid**: The grid of coordinates of shape
            `(minibatch, depth, height, width, 3)`, where `minibatch`,
            `depth`, `height` and `width` are the batch size and the
            spatial sizes of the grid.
    """
    # TODO(dnovotny): Implement caching of the coordinate grid.
    coordinate_grid = self._calculate_coordinate_grid(
        world_coordinates=world_coordinates
    )
    return coordinate_grid
def _calculate_coordinate_grid(
    self, world_coordinates: bool = True
) -> torch.Tensor:
    """
    Calculate the 3D coordinate grid of the volumetric grid either
    in local (`world_coordinates=False`) or
    world coordinates (`world_coordinates=True`).

    The grid is first generated at the resolution stored in
    `self._resolution` and then rescaled per batch element so that, for
    each grid, index 0 maps to -1 and the grid's own last index maps to +1
    along every axis.

    Returns:
        Tensor of shape `(minibatch, depth, height, width, 3)` whose last
        dimension holds (x, y, z) coordinates.
    """
    # batch size and (depth, height, width) of the common grid
    ba, (de, he, wi) = self._batch_size, self._resolution
    grid_sizes = self.get_grid_sizes()
    # generate coordinate axes
    vol_axes = [
        torch.linspace(-1.0, 1.0, r, dtype=torch.float32, device=self.device)
        for r in (de, he, wi)
    ]
    # generate per-coord meshgrids
    Z, Y, X = meshgrid_ij(vol_axes)
    # stack the coord grids ... this order matches the coordinate convention
    # of torch.nn.grid_sample
    vol_coords_local = torch.stack((X, Y, Z), dim=3)[None].repeat(ba, 1, 1, 1, 1)
    # get grid sizes relative to the maximal volume size
    # NOTE(review): a grid dimension of size 1 makes `grid_sizes - 1` zero
    # in this division - confirm such inputs cannot occur.
    grid_sizes_relative = (
        torch.tensor([[de, he, wi]], device=grid_sizes.device, dtype=torch.float32)
        - 1
    ) / (grid_sizes - 1).float()
    if (grid_sizes_relative != 1.0).any():
        # if any of the relative sizes != 1.0, adjust the grid
        # (columns are reordered from z, y, x to x, y, z and broadcast over
        # the three spatial dimensions)
        grid_sizes_relative_reshape = grid_sizes_relative[:, [2, 1, 0]][
            :, None, None, None
        ]
        # c -> rel * (c + 1) - 1: keeps -1 fixed and stretches the rest
        vol_coords_local *= grid_sizes_relative_reshape
        vol_coords_local += grid_sizes_relative_reshape - 1
    if world_coordinates:
        vol_coords = self.local_to_world_coords(vol_coords_local)
    else:
        vol_coords = vol_coords_local
    return vol_coords
def get_local_to_world_coords_transform(self) -> Transform3d:
    """
    Return the stored transform that maps points from the local coordinate
    frame of the volume to world coordinates.

    Local volume coordinates are scaled s.t. the coordinates along one
    side of the volume are in range [-1, 1].

    Returns:
        **local_to_world_transform**: A Transform3d object converting
            points from local coordinates to the world coordinates.
    """
    local_to_world_transform = self._local_to_world_transform
    return local_to_world_transform
def get_world_to_local_coords_transform(self) -> Transform3d:
    """
    Return the transform that maps points from world coordinates to the
    local coordinate frame of the volume. It is obtained by inverting the
    local-to-world transform.

    Local volume coordinates are scaled s.t. the coordinates along one
    side of the volume are in range [-1, 1].

    Returns:
        **world_to_local_transform**: A Transform3d object converting
            points from world coordinates to local coordinates.
    """
    local_to_world = self.get_local_to_world_coords_transform()
    return local_to_world.inverse()
def world_to_local_coords(self, points_3d_world: torch.Tensor) -> torch.Tensor:
    """
    Convert a batch of 3D point coordinates `points_3d_world` of shape
    (minibatch, ..., 3) in the world coordinates to the local coordinate
    frame of the volume. Local volume coordinates are scaled s.t. the
    coordinates along one side of the volume are in range [-1, 1].

    Args:
        **points_3d_world**: A tensor of shape `(minibatch, ..., 3)`
            containing the world coordinates of the points that will be
            converted to the local volume coordinates.

    Returns:
        **points_3d_local**: `points_3d_world` converted to the local
            volume coordinates of shape `(minibatch, ..., 3)`.
    """
    pts_shape = points_3d_world.shape
    # `reshape` behaves like `view` for contiguous inputs but also accepts
    # non-contiguous tensors (e.g. permuted ones), for which `view` raises.
    flat_points = points_3d_world.reshape(pts_shape[0], -1, 3)
    points_3d_local = self.get_world_to_local_coords_transform().transform_points(
        flat_points
    )
    return points_3d_local.reshape(pts_shape)
def local_to_world_coords(self, points_3d_local: torch.Tensor) -> torch.Tensor:
    """
    Convert a batch of 3D point coordinates `points_3d_local` of shape
    (minibatch, ..., 3) in the local coordinate frame of the volume
    to the world coordinates.

    Args:
        **points_3d_local**: A tensor of shape `(minibatch, ..., 3)`
            containing the local volume coordinates (ranging within
            [-1, 1] inside the volume) of the points that will be
            converted to the world coordinates.

    Returns:
        **points_3d_world**: `points_3d_local` converted to the world
            coordinates of the volume of shape `(minibatch, ..., 3)`.
    """
    pts_shape = points_3d_local.shape
    # `reshape` behaves like `view` for contiguous inputs but also accepts
    # non-contiguous tensors (e.g. permuted ones), for which `view` raises.
    flat_points = points_3d_local.reshape(pts_shape[0], -1, 3)
    points_3d_world = self.get_local_to_world_coords_transform().transform_points(
        flat_points
    )
    return points_3d_world.reshape(pts_shape)
def get_grid_sizes(self) -> torch.LongTensor:
"""
Returns the sizes of individual volumetric grids in the structure.
Returns:
**grid_sizes**: Tensor of spatial sizes of each of the volumes
of size (batchsize, 3), where i-th row holds (D_i, H_i, W_i).
"""
return self._grid_sizes
def _set_local_to_world_transform( def _set_local_to_world_transform(
self, self,
voxel_size: _VoxelSize = 1.0, voxel_size: _VoxelSize = 1.0,
...@@ -690,17 +959,104 @@ class Volumes: ...@@ -690,17 +959,104 @@ class Volumes:
voxel_size, volume_translation, len(self) voxel_size, volume_translation, len(self)
) )
def clone(self) -> "Volumes": def _copy_transform_and_sizes(
self,
other: "VolumeLocator",
device: Optional[torch.device] = None,
index: Optional[
Union[int, List[int], Tuple[int], slice, torch.Tensor]
] = _ALL_CONTENT,
) -> None:
""" """
Deep copy of Volumes object. All internal tensors are cloned Copies the local to world transform and grid sizes to other VolumeLocator object
individually. and moves it to specified device. Operates in place on other.
Returns: Args:
new Volumes object. other: VolumeLocator object to which to copy
device: torch.device on which to put the result, defatults to self.device
index: Specifies which parts to copy.
Can be an int, slice, list of ints or a boolean or a long tensor.
Defaults to all items (`:`).
""" """
return copy.deepcopy(self) device = device if device is not None else self.device
other._grid_sizes = self._grid_sizes[index].to(device)
other._local_to_world_transform = self.get_local_to_world_coords_transform()[
index
].to(device)
def to(self, device: Device, copy: bool = False) -> "Volumes": def _handle_voxel_size(
self, voxel_size: _VoxelSize, batch_size: int
) -> torch.Tensor:
"""
Handle the `voxel_size` argument to the `VolumeLocator` constructor.
"""
err_msg = (
"voxel_size has to be either a 3-tuple of scalars, or a scalar, or"
" a torch.Tensor of shape (3,) or (1,) or (minibatch, 3) or (minibatch, 1)."
)
if isinstance(voxel_size, (float, int)):
# convert a scalar to a 3-element tensor
voxel_size = torch.full(
(1, 3), voxel_size, device=self.device, dtype=torch.float32
)
elif isinstance(voxel_size, torch.Tensor):
if voxel_size.numel() == 1:
# convert a single-element tensor to a 3-element one
voxel_size = voxel_size.view(-1).repeat(3)
elif len(voxel_size.shape) == 2 and (
voxel_size.shape[0] == batch_size and voxel_size.shape[1] == 1
):
voxel_size = voxel_size.repeat(1, 3)
return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
def _handle_volume_translation(
    self, translation: _Translation, batch_size: int
) -> torch.Tensor:
    """
    Process the `volume_translation` argument of the `VolumeLocator`
    constructor.

    The work is delegated to `_convert_volume_property_to_tensor`, which is
    handed the translation-specific error message built here.

    Args:
        translation: The user-supplied volume translation.
        batch_size: The number of volumes in the batch.

    Returns:
        Whatever `_convert_volume_property_to_tensor` returns for
        `translation`.
    """
    translation_err_msg = (
        "`volume_translation` has to be either a 3-tuple of scalars, or "
        "a Tensor of shape (1,3) or (minibatch, 3) or (3,)`."
    )
    return self._convert_volume_property_to_tensor(
        x=translation,
        batch_size=batch_size,
        err_msg=translation_err_msg,
    )
def __len__(self) -> int:
    """Return the batch size stored in `self._batch_size`."""
    batch_size = self._batch_size
    return batch_size
def _convert_volume_property_to_tensor(
    self, x: _Vector, batch_size: int, err_msg: str
) -> torch.Tensor:
    """
    Handle the `volume_translation` or `voxel_size` argument to
    the VolumeLocator constructor.

    Args:
        x: Either a list/tuple of 3 scalars, or a `torch.Tensor` of shape
            `(3,)`, `(1, 3)` or `(batch_size, 3)`.
        batch_size: The number of rows `N` of the returned tensor.
        err_msg: Message of the `ValueError` raised for unsupported input.

    Returns:
        A tensor of shape (N, 3) where N is the batch_size. Inputs that
        describe a single 3-vector are repeated along the batch dimension.
        Tensor inputs are moved to `self.device` if they live elsewhere.

    Raises:
        ValueError: If `x` is not a list/tuple/tensor, or does not have one
            of the supported shapes.
    """
    if isinstance(x, (list, tuple)):
        if len(x) != 3:
            raise ValueError(err_msg)
        x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
        return x.repeat((batch_size, 1))

    if not isinstance(x, torch.Tensor):
        raise ValueError(err_msg)

    # Check the rank before indexing into `x.shape`, so that malformed tensors
    # (0-d, 1-D of the wrong length, or more than 2 dims) raise the documented
    # ValueError instead of an IndexError or a failure later on in `repeat`.
    if x.dim() == 1 and x.shape[0] == 3:
        x = x[None]
    elif not (x.dim() == 2 and x.shape[1] == 3 and x.shape[0] in (1, batch_size)):
        raise ValueError(err_msg)

    if x.device != self.device:
        x = x.to(self.device)
    if x.shape[0] == 1:
        # a single 3-vector is shared by all batch elements
        x = x.repeat((batch_size, 1))
    return x
def to(self, device: Device, copy: bool = False) -> "VolumeLocator":
""" """
Match the functionality of torch.Tensor.to() Match the functionality of torch.Tensor.to()
If copy = True or the self Tensor is on a different device, the If copy = True or the self Tensor is on a different device, the
...@@ -713,7 +1069,7 @@ class Volumes: ...@@ -713,7 +1069,7 @@ class Volumes:
copy: Boolean indicator whether or not to clone self. Default False. copy: Boolean indicator whether or not to clone self. Default False.
Returns: Returns:
Volumes object. VolumeLocator object.
""" """
device_ = make_device(device) device_ = make_device(device)
if not copy and self.device == device_: if not copy and self.device == device_:
...@@ -724,18 +1080,24 @@ class Volumes: ...@@ -724,18 +1080,24 @@ class Volumes:
return other return other
other.device = device_ other.device = device_
other._densities = self._densities.to(device_) other._grid_sizes = self._grid_sizes.to(device_)
if self._features is not None:
# pyre-fixme[16]: `Optional` has no attribute `to`.
other._features = self.features().to(device_)
other._local_to_world_transform = self.get_local_to_world_coords_transform().to( other._local_to_world_transform = self.get_local_to_world_coords_transform().to(
device_ device
) )
other._grid_sizes = self._grid_sizes.to(device_)
return other return other
def cpu(self) -> "Volumes": def clone(self) -> "VolumeLocator":
"""
Deep copy of VolumeLocator object. All internal tensors are cloned
individually.
Returns:
new VolumeLocator object.
"""
return copy.deepcopy(self)
def cpu(self) -> "VolumeLocator":
return self.to("cpu") return self.to("cpu")
def cuda(self) -> "Volumes": def cuda(self) -> "VolumeLocator":
return self.to("cuda") return self.to("cuda")
...@@ -11,7 +11,7 @@ import unittest ...@@ -11,7 +11,7 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from pytorch3d.structures.volumes import Volumes from pytorch3d.structures.volumes import VolumeLocator, Volumes
from pytorch3d.transforms import Scale from pytorch3d.transforms import Scale
from .common_testing import TestCaseMixin from .common_testing import TestCaseMixin
...@@ -53,8 +53,8 @@ class TestVolumes(TestCaseMixin, unittest.TestCase): ...@@ -53,8 +53,8 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
for selectedIdx, index in indices: for selectedIdx, index in indices:
self.assertClose(selected.densities()[selectedIdx], v.densities()[index]) self.assertClose(selected.densities()[selectedIdx], v.densities()[index])
self.assertClose( self.assertClose(
v._local_to_world_transform.get_matrix()[index], v.locator._local_to_world_transform.get_matrix()[index],
selected._local_to_world_transform.get_matrix()[selectedIdx], selected.locator._local_to_world_transform.get_matrix()[selectedIdx],
) )
if selected.features() is not None: if selected.features() is not None:
self.assertClose(selected.features()[selectedIdx], v.features()[index]) self.assertClose(selected.features()[selectedIdx], v.features()[index])
...@@ -149,10 +149,55 @@ class TestVolumes(TestCaseMixin, unittest.TestCase): ...@@ -149,10 +149,55 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
v_selected = v[index] v_selected = v[index]
def test_locator_init(self, batch_size=9, resolution=(3, 5, 7)):
    """
    Check that `VolumeLocator` ends up with the same `_grid_sizes` and
    `_resolution` no matter whether `grid_sizes` is passed as a single
    tuple, a list of per-volume sizes, or a tensor.
    """
    cpu = torch.device("cpu")

    with self.subTest("VolumeLocator init with all sizes equal"):
        grid_sizes = [resolution for _ in range(batch_size)]
        locator_tuple = VolumeLocator(
            batch_size=batch_size, grid_sizes=resolution, device=cpu
        )
        locator_list = VolumeLocator(
            batch_size=batch_size, grid_sizes=grid_sizes, device=cpu
        )
        locator_tensor = VolumeLocator(
            batch_size=batch_size,
            grid_sizes=torch.tensor(grid_sizes),
            device=cpu,
        )
        expected_grid_sizes = torch.tensor(grid_sizes)
        expected_resolution = resolution
        for locator in (locator_tuple, locator_list, locator_tensor):
            # unittest assertions (rather than bare `assert`) still run
            # under `python -O` and are reported per subTest
            self.assertTrue(
                torch.allclose(expected_grid_sizes, locator._grid_sizes)
            )
            self.assertEqual(expected_resolution, locator._resolution)

    with self.subTest("VolumeLocator with different sizes in different grids"):
        grid_sizes_list = [
            torch.randint(low=1, high=42, size=(3,)) for _ in range(batch_size)
        ]
        grid_sizes_tensor = torch.cat([el[None] for el in grid_sizes_list])
        locator_list = VolumeLocator(
            batch_size=batch_size,
            grid_sizes=grid_sizes_list,
            device=cpu,
        )
        locator_tensor = VolumeLocator(
            batch_size=batch_size,
            grid_sizes=grid_sizes_tensor,
            device=cpu,
        )
        expected_grid_sizes = grid_sizes_tensor
        # the expected resolution is the per-axis maximum over the batch
        expected_resolution = tuple(torch.max(expected_grid_sizes, dim=0).values)
        for locator in (locator_list, locator_tensor):
            self.assertTrue(
                torch.allclose(expected_grid_sizes, locator._grid_sizes)
            )
            self.assertEqual(expected_resolution, locator._resolution)
def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32): def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32):
""" """
Test the correctness of the conversion between the internal Test the correctness of the conversion between the internal
Transform3D Volumes._local_to_world_transform and the initialization Transform3D Volumes.VolumeLocator._local_to_world_transform and the initialization
from the translation and voxel_size. from the translation and voxel_size.
""" """
...@@ -440,7 +485,10 @@ class TestVolumes(TestCaseMixin, unittest.TestCase): ...@@ -440,7 +485,10 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
for var_name, var in vars(v).items(): for var_name, var in vars(v).items():
if var_name != "device": if var_name != "device":
if var is not None: if var is not None:
self.assertTrue(var.device.type == desired_device.type) self.assertTrue(
var.device.type == desired_device.type,
(var_name, var.device, desired_device),
)
else: else:
self.assertTrue(var.type == desired_device.type) self.assertTrue(var.type == desired_device.type)
...@@ -456,33 +504,38 @@ class TestVolumes(TestCaseMixin, unittest.TestCase): ...@@ -456,33 +504,38 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
) )
densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype) densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype)
volumes = Volumes(densities=densities, features=features) volumes = Volumes(densities=densities, features=features)
locator = VolumeLocator(
batch_size=5, grid_sizes=(3, 5, 7), device=volumes.device
)
for name, obj in (("VolumeLocator", locator), ("Volumes", volumes)):
with self.subTest(f"Moving {name} from/to gpu and cpu"):
# Test support for str and torch.device # Test support for str and torch.device
cpu_device = torch.device("cpu") cpu_device = torch.device("cpu")
converted_volumes = volumes.to("cpu") converted_obj = obj.to("cpu")
self.assertEqual(cpu_device, converted_volumes.device) self.assertEqual(cpu_device, converted_obj.device)
self.assertEqual(cpu_device, volumes.device) self.assertEqual(cpu_device, obj.device)
self.assertIs(volumes, converted_volumes) self.assertIs(obj, converted_obj)
converted_volumes = volumes.to(cpu_device) converted_obj = obj.to(cpu_device)
self.assertEqual(cpu_device, converted_volumes.device) self.assertEqual(cpu_device, converted_obj.device)
self.assertEqual(cpu_device, volumes.device) self.assertEqual(cpu_device, obj.device)
self.assertIs(volumes, converted_volumes) self.assertIs(obj, converted_obj)
cuda_device = torch.device("cuda:0") cuda_device = torch.device("cuda:0")
converted_volumes = volumes.to("cuda:0") converted_obj = obj.to("cuda:0")
self.assertEqual(cuda_device, converted_volumes.device) self.assertEqual(cuda_device, converted_obj.device)
self.assertEqual(cpu_device, volumes.device) self.assertEqual(cpu_device, obj.device)
self.assertIsNot(volumes, converted_volumes) self.assertIsNot(obj, converted_obj)
converted_volumes = volumes.to(cuda_device) converted_obj = obj.to(cuda_device)
self.assertEqual(cuda_device, converted_volumes.device) self.assertEqual(cuda_device, converted_obj.device)
self.assertEqual(cpu_device, volumes.device) self.assertEqual(cpu_device, obj.device)
self.assertIsNot(volumes, converted_volumes) self.assertIsNot(obj, converted_obj)
# Test device placement of internal tensors with self.subTest("Test device placement of internal tensors of Volumes"):
features = features.to(cuda_device) features = features.to(cuda_device)
densities = features.to(cuda_device) densities = features.to(cuda_device)
...@@ -511,6 +564,15 @@ class TestVolumes(TestCaseMixin, unittest.TestCase): ...@@ -511,6 +564,15 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
else: else:
self._check_vars_on_device(volumes_, cuda_device) self._check_vars_on_device(volumes_, cuda_device)
with self.subTest("Test device placement of internal tensors of VolumeLocator"):
for device1, device2 in itertools.combinations(
(torch.device("cpu"), torch.device("cuda:0")), 2
):
locator = locator.to(device1)
locator = locator.to(device2)
self.assertEqual(locator._grid_sizes.device, device2)
self.assertEqual(locator._local_to_world_transform.device, device2)
def _check_padded(self, x_pad, x_list, grid_sizes): def _check_padded(self, x_pad, x_list, grid_sizes):
""" """
Check that padded tensors x_pad are the same as x_list tensors. Check that padded tensors x_pad are the same as x_list tensors.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment