Unverified Commit cdd2142d authored by Jeremy Reizenstein, committed by GitHub

implicitron v0 (#1133)


Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
parent 0e377c68
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/lioryariv/idr/blob/main/code/model/
# implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv
import math
from typing import Sequence
import torch
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer.implicit import HarmonicEmbedding
from torch import nn
from .base import ImplicitFunctionBase
@registry.register
class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
feature_vector_size: int = 3
d_in: int = 3
d_out: int = 1
dims: Sequence[int] = (512, 512, 512, 512, 512, 512, 512, 512)
geometric_init: bool = True
bias: float = 1.0
skip_in: Sequence[int] = ()
weight_norm: bool = True
n_harmonic_functions_xyz: int = 0
pooled_feature_dim: int = 0
encoding_dim: int = 0
def __post_init__(self):
super().__init__()
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
self.embed_fn = None
if self.n_harmonic_functions_xyz > 0:
self.embed_fn = HarmonicEmbedding(
self.n_harmonic_functions_xyz, append_input=True
)
dims[0] = self.embed_fn.get_output_dim()
if self.pooled_feature_dim > 0:
dims[0] += self.pooled_feature_dim
if self.encoding_dim > 0:
dims[0] += self.encoding_dim
self.num_layers = len(dims)
out_dim = 0
layers = []
for layer_idx in range(self.num_layers - 1):
if layer_idx + 1 in self.skip_in:
out_dim = dims[layer_idx + 1] - dims[0]
else:
out_dim = dims[layer_idx + 1]
lin = nn.Linear(dims[layer_idx], out_dim)
if self.geometric_init:
if layer_idx == self.num_layers - 2:
torch.nn.init.normal_(
lin.weight,
mean=math.pi ** 0.5 / dims[layer_idx] ** 0.5,
std=0.0001,
)
torch.nn.init.constant_(lin.bias, -self.bias)
elif self.n_harmonic_functions_xyz > 0 and layer_idx == 0:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
torch.nn.init.normal_(
lin.weight[:, :3], 0.0, 2 ** 0.5 / out_dim ** 0.5
)
elif self.n_harmonic_functions_xyz > 0 and layer_idx in self.skip_in:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, 2 ** 0.5 / out_dim ** 0.5)
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
else:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, 2 ** 0.5 / out_dim ** 0.5)
if self.weight_norm:
lin = nn.utils.weight_norm(lin)
layers.append(lin)
self.linear_layers = torch.nn.ModuleList(layers)
self.out_dim = out_dim
self.softplus = nn.Softplus(beta=100)
# pyre-fixme[14]: `forward` overrides method defined in `ImplicitFunctionBase`
# inconsistently.
def forward(
self,
# ray_bundle: RayBundle,
rays_points_world: torch.Tensor, # TODO: unify the APIs
fun_viewpool=None,
global_code=None,
):
# this field only uses point locations
# rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
if rays_points_world.numel() == 0 or (
self.embed_fn is None and fun_viewpool is None and global_code is None
):
return torch.tensor(
[], device=rays_points_world.device, dtype=rays_points_world.dtype
).view(0, self.out_dim)
embedding = None
if self.embed_fn is not None:
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
embedding = self.embed_fn(rays_points_world)
if fun_viewpool is not None:
assert rays_points_world.ndim == 2
pooled_feature = fun_viewpool(rays_points_world[None])
# TODO: pooled features are 4D!
embedding = torch.cat((embedding, pooled_feature), dim=-1)
if global_code is not None:
assert embedding.ndim == 2
assert global_code.shape[0] == 1 # TODO: generalize to batches!
# This will require changing raytracer code
# embedding = embedding[None].expand(global_code.shape[0], *embedding.shape)
embedding = torch.cat(
(embedding, global_code[0, None, :].expand(*embedding.shape[:-1], -1)),
dim=-1,
)
x = embedding
for layer_idx in range(self.num_layers - 1):
if layer_idx in self.skip_in:
x = torch.cat([x, embedding], dim=-1) / 2 ** 0.5
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
x = self.linear_layers[layer_idx](x)
if layer_idx < self.num_layers - 2:
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
x = self.softplus(x)
return x # TODO: unify the APIs
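# Illustrative usage (a sketch, not part of the original module): the field maps
# each 3D point to `d_out + feature_vector_size` channels, i.e. a scalar SDF
# value plus a 3-channel feature vector with the defaults above. Note that
# `forward` returns an empty tensor unless the harmonic embedding, `fun_viewpool`
# or `global_code` input is active. Assuming the class is first prepared for
# direct construction via `expand_args_fields` from
# pytorch3d.implicitron.tools.config:
#
#   from pytorch3d.implicitron.tools.config import expand_args_fields
#   expand_args_fields(IdrFeatureField)
#   field = IdrFeatureField(n_harmonic_functions_xyz=6)
#   out = field(torch.randn(1024, 3))  # expected shape: (1024, 4)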
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import field
from typing import List, Optional
import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding
from .base import ImplicitFunctionBase
from .utils import create_embeddings_for_implicit_function
class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
n_harmonic_functions_xyz: int = 10
n_harmonic_functions_dir: int = 4
n_hidden_neurons_dir: int = 128
latent_dim: int = 0
input_xyz: bool = True
xyz_ray_dir_in_camera_coords: bool = False
color_dim: int = 3
"""
Args:
n_harmonic_functions_xyz: The number of harmonic functions
used to form the harmonic embedding of 3D point locations.
n_harmonic_functions_dir: The number of harmonic functions
used to form the harmonic embedding of the ray directions.
n_hidden_neurons_xyz: The number of hidden units in the
fully connected layers of the MLP that accepts the 3D point
locations and outputs the occupancy field with the intermediate
features.
n_hidden_neurons_dir: The number of hidden units in the
fully connected layers of the MLP that accepts the intermediate
features and ray directions and outputs the radiance field
(per-point colors).
n_layers_xyz: The number of layers of the MLP that outputs the
occupancy field.
append_xyz: The list of indices of the skip layers of the occupancy MLP.
"""
def __post_init__(self):
super().__init__()
# The harmonic embedding layer converts input 3D coordinates
# to a representation that is more suitable for
# processing with a deep neural network.
self.harmonic_embedding_xyz = HarmonicEmbedding(
self.n_harmonic_functions_xyz, append_input=True
)
self.harmonic_embedding_dir = HarmonicEmbedding(
self.n_harmonic_functions_dir, append_input=True
)
if not self.input_xyz and self.latent_dim <= 0:
raise ValueError("The latent dimension has to be > 0 if xyz is not input!")
embedding_dim_dir = self.harmonic_embedding_dir.get_output_dim()
self.xyz_encoder = self._construct_xyz_encoder(
input_dim=self.get_xyz_embedding_dim()
)
self.intermediate_linear = torch.nn.Linear(
self.n_hidden_neurons_xyz, self.n_hidden_neurons_xyz
)
_xavier_init(self.intermediate_linear)
self.density_layer = torch.nn.Linear(self.n_hidden_neurons_xyz, 1)
_xavier_init(self.density_layer)
# Zero the bias of the density layer to avoid
# a completely transparent initialization.
self.density_layer.bias.data[:] = 0.0 # fixme: Sometimes this is not enough
self.color_layer = torch.nn.Sequential(
LinearWithRepeat(
self.n_hidden_neurons_xyz + embedding_dim_dir, self.n_hidden_neurons_dir
),
torch.nn.ReLU(True),
torch.nn.Linear(self.n_hidden_neurons_dir, self.color_dim),
torch.nn.Sigmoid(),
)
def get_xyz_embedding_dim(self):
return (
self.harmonic_embedding_xyz.get_output_dim() * int(self.input_xyz)
+ self.latent_dim
)
def _construct_xyz_encoder(self, input_dim: int):
raise NotImplementedError()
def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
"""
This function takes per-point `features` predicted by `self.xyz_encoder`
and evaluates the color model in order to attach to each
point a 3D vector of its RGB color.
"""
# Normalize the ray_directions to unit l2 norm.
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
# Obtain the harmonic embedding of the normalized ray directions.
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
return self.color_layer((self.intermediate_linear(features), rays_embedding))
@staticmethod
def allows_multiple_passes() -> bool:
"""
Returns True as this implicit function allows
multiple passes. Overridden from ImplicitFunctionBase.
"""
return True
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
**kwargs,
):
"""
The forward function accepts the parametrizations of
3D points sampled along projection rays. The forward
pass is responsible for attaching a 3D vector
and a 1D scalar representing the point's
RGB color and opacity respectively.
Args:
ray_bundle: A RayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of sampling rays in world coords.
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
fun_viewpool: an optional callback with the signature
fun_viewpool(points) -> pooled_features
where points is a [N_TGT x N x 3] tensor of world coords,
and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
of the features pooled from the context images.
Returns:
rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
denoting the opacity of each ray point.
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
denoting the color of each ray point.
"""
# We first convert the ray parametrizations to world
# coordinates with `ray_bundle_to_ray_points`.
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
embeds = create_embeddings_for_implicit_function(
xyz_world=ray_bundle_to_ray_points(ray_bundle),
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
xyz_embedding_function=self.harmonic_embedding_xyz
if self.input_xyz
else None,
global_code=global_code,
fun_viewpool=fun_viewpool,
xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
camera=camera,
)
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
features = self.xyz_encoder(embeds)
# features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
# NNs operate on the flattened rays; reshaping to the correct spatial size
# TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
features = features.reshape(*rays_points_world.shape[:-1], -1)
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
raw_densities = self.density_layer(features)
# raw_densities.shape = [minibatch x ... x 1] in [0-1]
if self.xyz_ray_dir_in_camera_coords:
if camera is None:
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
directions = ray_bundle.directions @ camera.R
else:
directions = ray_bundle.directions
rays_colors = self._get_colors(features, directions)
# rays_colors.shape = [minibatch x ... x 3] in [0-1]
return raw_densities, rays_colors, {}
@registry.register
class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
transformer_dim_down_factor: float = 1.0
n_hidden_neurons_xyz: int = 256
n_layers_xyz: int = 8
append_xyz: List[int] = field(default_factory=lambda: [5])
def _construct_xyz_encoder(self, input_dim: int):
return MLPWithInputSkips(
self.n_layers_xyz,
input_dim,
self.n_hidden_neurons_xyz,
input_dim,
self.n_hidden_neurons_xyz,
input_skips=self.append_xyz,
)
@registry.register
class NeRFormerImplicitFunction(NeuralRadianceFieldBase):
transformer_dim_down_factor: float = 2.0
n_hidden_neurons_xyz: int = 80
n_layers_xyz: int = 2
append_xyz: List[int] = field(default_factory=lambda: [1])
def _construct_xyz_encoder(self, input_dim: int):
return TransformerWithInputSkips(
self.n_layers_xyz,
input_dim,
self.n_hidden_neurons_xyz,
input_dim,
self.n_hidden_neurons_xyz,
input_skips=self.append_xyz,
dim_down_factor=self.transformer_dim_down_factor,
)
@staticmethod
def requires_pooling_without_aggregation() -> bool:
"""
Returns True as this implicit function needs
pooling without aggregation. Overridden from ImplicitFunctionBase.
"""
return True
class MLPWithInputSkips(torch.nn.Module):
"""
Implements the multi-layer perceptron architecture of the Neural Radiance Field.
As such, `MLPWithInputSkips` is a multi-layer perceptron consisting
of a sequence of linear layers with ReLU activations.
Additionally, for a set of predefined layers `input_skips`, the forward pass
appends a skip tensor `z` to the output of the preceding layer.
Note that this follows the architecture described in the Supplementary
Material (Fig. 7) of [1].
References:
[1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
NeRF: Representing Scenes as Neural Radiance Fields for View
Synthesis, ECCV2020
"""
def _make_affine_layer(self, input_dim, hidden_dim):
l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
_xavier_init(l1)
_xavier_init(l2)
return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)
def _apply_affine_layer(self, layer, x, z):
mu_log_std = layer(z)
mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
std = torch.nn.functional.softplus(log_std)
return (x - mu) * std
def __init__(
self,
n_layers: int = 8,
input_dim: int = 39,
output_dim: int = 256,
skip_dim: int = 39,
hidden_dim: int = 256,
input_skips: List[int] = [5],
skip_affine_trans: bool = False,
no_last_relu=False,
):
"""
Args:
n_layers: The number of linear layers of the MLP.
input_dim: The number of channels of the input tensor.
output_dim: The number of channels of the output.
skip_dim: The number of channels of the tensor `z` appended when
evaluating the skip layers.
hidden_dim: The number of hidden units of the MLP.
input_skips: The list of layer indices at which we append the skip
tensor `z`.
"""
super().__init__()
layers = []
skip_affine_layers = []
for layeri in range(n_layers):
dimin = hidden_dim if layeri > 0 else input_dim
dimout = hidden_dim if layeri + 1 < n_layers else output_dim
if layeri > 0 and layeri in input_skips:
if skip_affine_trans:
skip_affine_layers.append(
self._make_affine_layer(skip_dim, hidden_dim)
)
else:
dimin = hidden_dim + skip_dim
linear = torch.nn.Linear(dimin, dimout)
_xavier_init(linear)
layers.append(
torch.nn.Sequential(linear, torch.nn.ReLU(True))
if not no_last_relu or layeri + 1 < n_layers
else linear
)
self.mlp = torch.nn.ModuleList(layers)
if skip_affine_trans:
self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
self._input_skips = set(input_skips)
self._skip_affine_trans = skip_affine_trans
def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
"""
Args:
x: The input tensor of shape `(..., input_dim)`.
z: The input skip tensor of shape `(..., skip_dim)` which is appended
to layers whose indices are specified by `input_skips`.
Returns:
y: The output tensor of shape `(..., output_dim)`.
"""
y = x
if z is None:
# if the skip tensor is None, we use `x` instead.
z = x
skipi = 0
for li, layer in enumerate(self.mlp):
if li in self._input_skips:
if self._skip_affine_trans:
y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
else:
y = torch.cat((y, z), dim=-1)
skipi += 1
y = layer(y)
return y
class TransformerWithInputSkips(torch.nn.Module):
def __init__(
self,
n_layers: int = 8,
input_dim: int = 39,
output_dim: int = 256,
skip_dim: int = 39,
hidden_dim: int = 64,
input_skips: List[int] = [5],
dim_down_factor: float = 1,
):
"""
Args:
n_layers: The number of transformer layers.
input_dim: The number of channels of the input tensor.
output_dim: The number of channels of the output.
skip_dim: The number of channels of the tensor `z` appended when
evaluating the skip layers.
hidden_dim: The number of hidden units of the first transformer layer;
also used as the feed-forward dimension of every layer.
input_skips: The list of layer indices at which we append the skip
tensor `z`.
dim_down_factor: The factor by which the channel dimensionality is
divided after each transformer layer.
"""
super().__init__()
self.first = torch.nn.Linear(input_dim, hidden_dim)
_xavier_init(self.first)
self.skip_linear = torch.nn.ModuleList()
layers_pool, layers_ray = [], []
dimout = 0
for layeri in range(n_layers):
dimin = int(round(hidden_dim / (dim_down_factor ** layeri)))
dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
print(f"Tr: {dimin} -> {dimout}")
for _i, l in enumerate((layers_pool, layers_ray)):
l.append(
TransformerEncoderLayer(
d_model=[dimin, dimout][_i],
nhead=4,
dim_feedforward=hidden_dim,
dropout=0.0,
d_model_out=dimout,
)
)
if layeri in input_skips:
self.skip_linear.append(torch.nn.Linear(input_dim, dimin))
self.last = torch.nn.Linear(dimout, output_dim)
_xavier_init(self.last)
self.layers_pool, self.layers_ray = (
torch.nn.ModuleList(layers_pool),
torch.nn.ModuleList(layers_ray),
)
self._input_skips = set(input_skips)
def forward(
self,
x: torch.Tensor,
z: Optional[torch.Tensor] = None,
):
"""
Args:
x: The input tensor of shape
`(minibatch, n_pooled_feats, ..., n_ray_pts, input_dim)`.
z: The input skip tensor of shape
`(minibatch, n_pooled_feats, ..., n_ray_pts, skip_dim)`
which is appended to layers whose indices are specified by `input_skips`.
Returns:
y: The output tensor of shape
`(minibatch, 1, ..., n_ray_pts, input_dim)`.
"""
if z is None:
# if the skip tensor is None, we use `x` instead.
z = x
y = self.first(x)
B, n_pool, n_rays, n_pts, dim = y.shape
# y_p shape: (n_pool, n_pts, B, n_rays, dim)
y_p = y.permute(1, 3, 0, 2, 4)
skipi = 0
dimh = dim
for li, (layer_pool, layer_ray) in enumerate(
zip(self.layers_pool, self.layers_ray)
):
y_pool_attn = y_p.reshape(n_pool, n_pts * B * n_rays, dimh)
if li in self._input_skips:
z_skip = self.skip_linear[skipi](z)
y_pool_attn = y_pool_attn + z_skip.permute(1, 3, 0, 2, 4).reshape(
n_pool, n_pts * B * n_rays, dimh
)
skipi += 1
# n_pool x B*n_rays*n_pts x dim
y_pool_attn, pool_attn = layer_pool(y_pool_attn, src_key_padding_mask=None)
dimh = y_pool_attn.shape[-1]
y_ray_attn = (
y_pool_attn.view(n_pool, n_pts, B * n_rays, dimh)
.permute(1, 0, 2, 3)
.reshape(n_pts, n_pool * B * n_rays, dimh)
)
# n_pts x n_pool*B*n_rays x dim
y_ray_attn, ray_attn = layer_ray(
y_ray_attn,
src_key_padding_mask=None,
)
y_p = y_ray_attn.view(n_pts, n_pool, B * n_rays, dimh).permute(1, 0, 2, 3)
y = y_p.view(n_pool, n_pts, B, n_rays, dimh).permute(2, 0, 3, 1, 4)
W = torch.softmax(y[..., :1], dim=1)
y = (y * W).sum(dim=1)
y = self.last(y)
return y
class TransformerEncoderLayer(torch.nn.Module):
r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
This standard encoder layer is based on the paper "Attention Is All You Need".
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of intermediate layer, relu or gelu (default=relu).
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out, attn = encoder_layer(src)
"""
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, d_model_out=-1
):
super(TransformerEncoderLayer, self).__init__()
self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = torch.nn.Linear(d_model, dim_feedforward)
self.dropout = torch.nn.Dropout(dropout)
d_model_out = d_model if d_model_out <= 0 else d_model_out
self.linear2 = torch.nn.Linear(dim_feedforward, d_model_out)
self.norm1 = torch.nn.LayerNorm(d_model)
self.norm2 = torch.nn.LayerNorm(d_model_out)
self.dropout1 = torch.nn.Dropout(dropout)
self.dropout2 = torch.nn.Dropout(dropout)
self.activation = torch.nn.functional.relu
def forward(self, src, src_mask=None, src_key_padding_mask=None):
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
src2, attn = self.self_attn(
src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
)
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
d_out = src2.shape[-1]
src = src[..., :d_out] + self.dropout2(src2)[..., :d_out]
src = self.norm2(src)
return src, attn
def _xavier_init(linear) -> None:
"""
Performs the Xavier weight initialization of the linear layer `linear`.
"""
torch.nn.init.xavier_uniform_(linear.weight.data)
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/vsitzmann/scene-representation-networks
# Copyright (c) 2019 Vincent Sitzmann
from typing import Any, Optional, Tuple, cast
import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding
from .base import ImplicitFunctionBase
from .utils import create_embeddings_for_implicit_function
def _kaiming_normal_init(module: torch.nn.Module) -> None:
if isinstance(module, (torch.nn.Linear, LinearWithRepeat)):
torch.nn.init.kaiming_normal_(
module.weight, a=0.0, nonlinearity="relu", mode="fan_in"
)
class SRNRaymarchFunction(Configurable, torch.nn.Module):
n_harmonic_functions: int = 3 # 0 means raw 3D coord inputs
n_hidden_units: int = 256
n_layers: int = 2
in_features: int = 3
out_features: int = 256
latent_dim: int = 0
xyz_in_camera_coords: bool = False
# The internal network can be set as an output of an SRNHyperNet.
# Note that, in order to avoid Pytorch's automatic registering of the
# raymarch_function module on construction, we input the network wrapped
# as a 1-tuple.
# raymarch_function should ideally be typed as Optional[Tuple[Callable]]
# but Omegaconf.structured doesn't like that. TODO: revisit after new
# release of omegaconf including https://github.com/omry/omegaconf/pull/749 .
raymarch_function: Any = None
def __post_init__(self):
super().__init__()
self._harmonic_embedding = HarmonicEmbedding(
self.n_harmonic_functions, append_input=True
)
input_embedding_dim = (
HarmonicEmbedding.get_output_dim_static(
self.in_features,
self.n_harmonic_functions,
True,
)
+ self.latent_dim
)
if self.raymarch_function is not None:
self._net = self.raymarch_function[0]
else:
self._net = pytorch_prototyping.FCBlock(
hidden_ch=self.n_hidden_units,
num_hidden_layers=self.n_layers,
in_features=input_embedding_dim,
out_features=self.out_features,
)
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
**kwargs,
):
"""
Args:
ray_bundle: A RayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of sampling rays in world coords.
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
fun_viewpool: an optional callback with the signature
fun_viewpool(points) -> pooled_features
where points is a [N_TGT x N x 3] tensor of world coords,
and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
of the features pooled from the context images.
Returns:
rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
denoting the opacity of each ray point.
rays_colors: Set to None.
"""
# We first convert the ray parametrizations to world
# coordinates with `ray_bundle_to_ray_points`.
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
embeds = create_embeddings_for_implicit_function(
xyz_world=ray_bundle_to_ray_points(ray_bundle),
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
# for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`.
xyz_embedding_function=self._harmonic_embedding,
global_code=global_code,
fun_viewpool=fun_viewpool,
xyz_in_camera_coords=self.xyz_in_camera_coords,
camera=camera,
)
# Before running the network, we have to resize embeds to ndims=3,
# otherwise the SRN layers consume huge amounts of memory.
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
raymarch_features = self._net(
embeds.view(embeds.shape[0], -1, embeds.shape[-1])
)
# raymarch_features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
# NNs operate on the flattened rays; reshaping to the correct spatial size
raymarch_features = raymarch_features.reshape(*rays_points_world.shape[:-1], -1)
return raymarch_features, None
class SRNPixelGenerator(Configurable, torch.nn.Module):
n_harmonic_functions: int = 4
n_hidden_units: int = 256
n_hidden_units_color: int = 128
n_layers: int = 2
in_features: int = 256
out_features: int = 3
ray_dir_in_camera_coords: bool = False
def __post_init__(self):
super().__init__()
self._harmonic_embedding = HarmonicEmbedding(
self.n_harmonic_functions, append_input=True
)
self._net = pytorch_prototyping.FCBlock(
hidden_ch=self.n_hidden_units,
num_hidden_layers=self.n_layers,
in_features=self.in_features,
out_features=self.n_hidden_units,
)
self._density_layer = torch.nn.Linear(self.n_hidden_units, 1)
self._density_layer.apply(_kaiming_normal_init)
embedding_dim_dir = self._harmonic_embedding.get_output_dim(input_dims=3)
self._color_layer = torch.nn.Sequential(
LinearWithRepeat(
self.n_hidden_units + embedding_dim_dir,
self.n_hidden_units_color,
),
torch.nn.LayerNorm([self.n_hidden_units_color]),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(self.n_hidden_units_color, self.out_features),
)
self._color_layer.apply(_kaiming_normal_init)
# TODO: merge with NeuralRadianceFieldBase's _get_colors
def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
"""
This function takes per-point `features` predicted by `self.net`
and evaluates the color model in order to attach to each
point a 3D vector of its RGB color.
"""
# Normalize the ray_directions to unit l2 norm.
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
# Obtain the harmonic embedding of the normalized ray directions.
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
rays_embedding = self._harmonic_embedding(rays_directions_normed)
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
return self._color_layer((features, rays_embedding))
def forward(
self,
raymarch_features: torch.Tensor,
ray_bundle: RayBundle,
camera: Optional[CamerasBase] = None,
**kwargs,
):
"""
Args:
raymarch_features: Features from the raymarching network of shape
`(minibatch, ..., self.in_features)`
ray_bundle: A RayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of sampling rays in world coords.
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
Returns:
rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
denoting the opacity of each ray point.
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
denoting the color of each ray point.
"""
# raymarch_features.shape = [minibatch x ... x pts_per_ray x 3]
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
features = self._net(raymarch_features)
# features.shape = [minibatch x ... x self.n_hidden_units]
if self.ray_dir_in_camera_coords:
if camera is None:
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
directions = ray_bundle.directions @ camera.R
else:
directions = ray_bundle.directions
# NNs operate on the flattened rays; reshaping to the correct spatial size
features = features.reshape(*raymarch_features.shape[:-1], -1)
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
raw_densities = self._density_layer(features)
rays_colors = self._get_colors(features, directions)
return raw_densities, rays_colors
class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
"""
This is a raymarching function which has a forward like SRNRaymarchFunction
but instead of the weights being parameters of the module, they
are the output of another network, the hypernet, which takes the global_code
as input. All the dataclass members of SRNRaymarchFunction are here with the
same meaning. In addition, there are members with names ending `_hypernet`
which affect the hypernet.
Because this class may be called repeatedly for the same global_code, the
output of the hypernet is cached in self.cached_srn_raymarch_function.
This member must be manually set to None whenever the global_code changes.
"""
n_harmonic_functions: int = 3 # 0 means raw 3D coord inputs
n_hidden_units: int = 256
n_layers: int = 2
n_hidden_units_hypernet: int = 256
n_layers_hypernet: int = 1
in_features: int = 3
out_features: int = 256
latent_dim_hypernet: int = 0
latent_dim: int = 0
xyz_in_camera_coords: bool = False
def __post_init__(self):
super().__init__()
raymarch_input_embedding_dim = (
HarmonicEmbedding.get_output_dim_static(
self.in_features,
self.n_harmonic_functions,
True,
)
+ self.latent_dim
)
self._hypernet = hyperlayers.HyperFC(
hyper_in_ch=self.latent_dim_hypernet,
hyper_num_hidden_layers=self.n_layers_hypernet,
hyper_hidden_ch=self.n_hidden_units_hypernet,
hidden_ch=self.n_hidden_units,
num_hidden_layers=self.n_layers,
in_ch=raymarch_input_embedding_dim,
out_ch=self.n_hidden_units,
)
self.cached_srn_raymarch_function: Optional[Tuple[SRNRaymarchFunction]] = None
def _run_hypernet(self, global_code: torch.Tensor) -> Tuple[SRNRaymarchFunction]:
"""
Runs the hypernet and returns a 1-tuple containing the generated
srn_raymarch_function.
"""
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
net = self._hypernet(global_code)
# use the hyper-net generated network to instantiate the raymarch module
srn_raymarch_function = SRNRaymarchFunction(
n_harmonic_functions=self.n_harmonic_functions,
n_hidden_units=self.n_hidden_units,
n_layers=self.n_layers,
in_features=self.in_features,
out_features=self.out_features,
latent_dim=self.latent_dim,
xyz_in_camera_coords=self.xyz_in_camera_coords,
raymarch_function=(net,),
)
# move the generated raymarch function to the correct device
srn_raymarch_function.to(global_code.device)
return (srn_raymarch_function,)
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
**kwargs,
):
if global_code is None:
raise ValueError("SRN Hypernetwork requires a non-trivial global code.")
# The raymarching network is cached in case the function is called repeatedly
# across LSTM iterations for the same global_code.
if self.cached_srn_raymarch_function is None:
# generate the raymarching network from the hypernet
# pyre-fixme[16]: `SRNRaymarchHyperNet` has no attribute
self.cached_srn_raymarch_function = self._run_hypernet(global_code)
(srn_raymarch_function,) = cast(
Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function
)
return srn_raymarch_function(
ray_bundle=ray_bundle,
fun_viewpool=fun_viewpool,
camera=camera,
global_code=None, # the hypernetwork takes the global code
)
@registry.register
# pyre-fixme[13]: Uninitialized attribute
class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
raymarch_function: SRNRaymarchFunction
pixel_generator: SRNPixelGenerator
def __post_init__(self):
super().__init__()
run_auto_creation(self)
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
raymarch_features: Optional[torch.Tensor] = None,
**kwargs,
):
predict_colors = raymarch_features is not None
if predict_colors:
return self.pixel_generator(
raymarch_features=raymarch_features,
ray_bundle=ray_bundle,
camera=camera,
**kwargs,
)
else:
return self.raymarch_function(
ray_bundle=ray_bundle,
fun_viewpool=fun_viewpool,
camera=camera,
global_code=global_code,
**kwargs,
)
@registry.register
# pyre-fixme[13]: Uninitialized attribute
class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
"""
This implicit function uses a hypernetwork to generate the
SRNRaymarchingFunction, and this is cached. Whenever the
global_code changes, `on_bind_args` must be called to clear
the cache.
"""
hypernet: SRNRaymarchHyperNet
pixel_generator: SRNPixelGenerator
def __post_init__(self):
super().__init__()
run_auto_creation(self)
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
raymarch_features: Optional[torch.Tensor] = None,
**kwargs,
):
predict_colors = raymarch_features is not None
if predict_colors:
return self.pixel_generator(
raymarch_features=raymarch_features,
ray_bundle=ray_bundle,
camera=camera,
**kwargs,
)
else:
return self.hypernet(
ray_bundle=ray_bundle,
fun_viewpool=fun_viewpool,
camera=camera,
global_code=global_code,
**kwargs,
)
def on_bind_args(self):
"""
The global_code may have changed, so we reset the hypernet.
"""
self.hypernet.cached_srn_raymarch_function = None
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Callable, Optional
import torch
from pytorch3d.renderer.cameras import CamerasBase
def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
"""
Expands the `global_code` of shape (minibatch, dim)
so that it can be appended to `embeds` of shape (minibatch, ..., dim2),
and appends to the last dimension of `embeds`.
"""
bs = embeds.shape[0]
global_code_broadcast = global_code.view(bs, *([1] * (embeds.ndim - 2)), -1).expand(
*embeds.shape[:-1],
global_code.shape[-1],
)
return torch.cat([embeds, global_code_broadcast], dim=-1)
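# Shape example (comments only): with embeds of shape (2, 1, 1024, 64, 39) and
# global_code of shape (2, 128), the code is reshaped to (2, 1, 1, 1, 128),
# expanded to (2, 1, 1024, 64, 128) and concatenated, giving an output of shape
# (2, 1, 1024, 64, 167).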
def create_embeddings_for_implicit_function(
xyz_world: torch.Tensor,
xyz_in_camera_coords: bool,
global_code: Optional[torch.Tensor],
camera: Optional[CamerasBase],
fun_viewpool: Optional[Callable],
xyz_embedding_function: Optional[Callable],
) -> torch.Tensor:
bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
if xyz_in_camera_coords:
if camera is None:
raise ValueError("Camera must be given if xyz_in_camera_coords")
ray_points_for_embed = (
camera.get_world_to_view_transform()
.transform_points(xyz_world.view(bs, -1, 3))
.view(xyz_world.shape)
)
else:
ray_points_for_embed = xyz_world
if xyz_embedding_function is None:
embeds = torch.empty(
bs,
1,
math.prod(spatial_size),
pts_per_ray,
0,
dtype=xyz_world.dtype,
device=xyz_world.device,
)
else:
embeds = xyz_embedding_function(ray_points_for_embed).reshape(
bs,
1,
math.prod(spatial_size),
pts_per_ray,
-1,
) # flatten spatial, add n_src dim
if fun_viewpool is not None:
# viewpooling
embeds_viewpooled = fun_viewpool(xyz_world.reshape(bs, -1, 3))
embed_shape = (
bs,
embeds_viewpooled.shape[1],
math.prod(spatial_size),
pts_per_ray,
-1,
)
embeds_viewpooled = embeds_viewpooled.reshape(*embed_shape)
if embeds is not None:
embeds = torch.cat([embeds.expand(*embed_shape), embeds_viewpooled], dim=-1)
else:
embeds = embeds_viewpooled
if global_code is not None:
# append the broadcasted global code to embeds
embeds = broadcast_global_code(embeds, global_code)
return embeds
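# Shape example (comments only): for xyz_world of shape (2, 1024, 64, 3), a
# HarmonicEmbedding with 10 harmonic functions and append_input=True (63 output
# channels for 3D inputs), no viewpooling and no global code, the returned
# embeds have shape (2, 1, 1024, 64, 63): batch, a singleton source dimension,
# the flattened spatial size, points per ray and the embedding channels.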
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import Dict, Optional
import torch
from pytorch3d.implicitron.tools import metric_utils as utils
from pytorch3d.renderer import utils as rend_utils
class ViewMetrics(torch.nn.Module):
def forward(
self,
image_sampling_grid: torch.Tensor,
images: Optional[torch.Tensor] = None,
images_pred: Optional[torch.Tensor] = None,
depths: Optional[torch.Tensor] = None,
depths_pred: Optional[torch.Tensor] = None,
masks: Optional[torch.Tensor] = None,
masks_pred: Optional[torch.Tensor] = None,
masks_crop: Optional[torch.Tensor] = None,
grad_theta: Optional[torch.Tensor] = None,
density_grid: Optional[torch.Tensor] = None,
keys_prefix: str = "loss_",
mask_renders_by_pred: bool = False,
) -> Dict[str, torch.Tensor]:
"""
Calculates various differentiable metrics useful for supervising
differentiable rendering pipelines.
Args:
image_sampling_grid: A tensor of shape `(B, ..., 2)` containing 2D
image locations at which the predictions are defined.
All ground truth inputs are sampled at these
locations in order to extract values that correspond
to the predictions.
images: A tensor of shape `(B, H, W, 3)` containing ground truth
rgb values.
images_pred: A tensor of shape `(B, ..., 3)` containing predicted
rgb values.
depths: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
depth values.
depths_pred: A tensor of shape `(B, ..., 1)` containing predicted
depth values.
masks: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
foreground masks.
masks_pred: A tensor of shape `(B, ..., 1)` containing predicted
foreground masks.
masks_crop: A binary tensor of the same form as `masks` marking the
valid (in-crop) image region; all metrics are evaluated only inside
this region.
grad_theta: A tensor of shape `(B, ..., 3)` containing an evaluation
of a gradient of a signed distance function w.r.t.
input 3D coordinates used to compute the eikonal loss.
density_grid: A tensor of shape `(B, Hg, Wg, Dg, 1)` containing a
`Hg x Wg x Dg` voxel grid of density values.
keys_prefix: A common prefix for all keys in the output dictionary
containing all metrics.
mask_renders_by_pred: If `True`, masks rendered images by the predicted
`masks_pred` prior to computing all rgb metrics.
Returns:
metrics: A dictionary `{metric_name_i: metric_value_i}` keyed by the
names of the output metrics `metric_name_i` with their corresponding
values `metric_value_i` represented as 0-dimensional float tensors.
The calculated metrics are:
rgb_huber: A robust huber loss between `image_pred` and `image`.
rgb_mse: Mean squared error between `image_pred` and `image`.
rgb_psnr: Peak signal-to-noise ratio between `image_pred` and `image`.
rgb_psnr_fg: Peak signal-to-noise ratio between the foreground
region of `image_pred` and `image` as defined by `mask`.
rgb_mse_fg: Mean squared error between the foreground
region of `image_pred` and `image` as defined by `mask`.
mask_neg_iou: (1 - intersection-over-union) between `mask_pred`
and `mask`.
mask_bce: Binary cross entropy between `mask_pred` and `mask`.
mask_beta_prior: A loss enforcing strictly binary values
of `mask_pred`: `log(mask_pred) + log(1-mask_pred)`
depth_abs: Mean per-pixel L1 distance between
`depth_pred` and `depth`.
depth_abs_fg: Mean per-pixel L1 distance between the foreground
region of `depth_pred` and `depth` as defined by `mask`.
eikonal: Eikonal regularizer `(||grad_theta|| - 1)**2`.
density_tv: The Total Variation regularizer of density
values in `density_grid` (sum of L1 distances of values
of all 4-neighbouring cells).
depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
predicted depth values.
"""
# TODO: extract functions
# reshape from B x ... x DIM to B x DIM x -1 x 1
images_pred, masks_pred, depths_pred = [
_reshape_nongrid_var(x) for x in [images_pred, masks_pred, depths_pred]
]
# reshape the sampling grid as well
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
# now that we use rend_utils.ndc_grid_sample
image_sampling_grid = image_sampling_grid.reshape(
image_sampling_grid.shape[0], -1, 1, 2
)
# closure with the given image_sampling_grid
def sample(tensor, mode):
if tensor is None:
return tensor
return rend_utils.ndc_grid_sample(tensor, image_sampling_grid, mode=mode)
# eval all results in this size
images = sample(images, mode="bilinear")
depths = sample(depths, mode="nearest")
masks = sample(masks, mode="nearest")
masks_crop = sample(masks_crop, mode="nearest")
if masks_crop is None and images_pred is not None:
masks_crop = torch.ones_like(images_pred[:, :1])
if masks_crop is None and depths_pred is not None:
masks_crop = torch.ones_like(depths_pred[:, :1])
preds = {}
if images is not None and images_pred is not None:
# TODO: mask_renders_by_pred is always false; simplify
preds.update(
_rgb_metrics(
images,
images_pred,
masks,
masks_pred,
masks_crop,
mask_renders_by_pred,
)
)
if masks_pred is not None:
preds["mask_beta_prior"] = utils.beta_prior(masks_pred)
if masks is not None and masks_pred is not None:
preds["mask_neg_iou"] = utils.neg_iou_loss(
masks_pred, masks, mask=masks_crop
)
preds["mask_bce"] = utils.calc_bce(masks_pred, masks, mask=masks_crop)
if depths is not None and depths_pred is not None:
assert masks_crop is not None
_, abs_ = utils.eval_depth(
depths_pred, depths, get_best_scale=True, mask=masks_crop, crop=0
)
preds["depth_abs"] = abs_.mean()
if masks is not None:
mask = masks * masks_crop
_, abs_ = utils.eval_depth(
depths_pred, depths, get_best_scale=True, mask=mask, crop=0
)
preds["depth_abs_fg"] = abs_.mean()
# regularizers
if grad_theta is not None:
preds["eikonal"] = _get_eikonal_loss(grad_theta)
if density_grid is not None:
preds["density_tv"] = _get_grid_tv_loss(density_grid)
if depths_pred is not None:
preds["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depths_pred)
if keys_prefix is not None:
preds = {(keys_prefix + k): v for k, v in preds.items()}
return preds
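# Usage sketch (comments only, with hypothetical tensors): given a grid of NDC
# sample locations and matching ground-truth / predicted images, the module
# returns a dict of scalar tensors keyed with `keys_prefix`:
#
#   metrics = ViewMetrics()(
#       image_sampling_grid=xys,      # (B, ..., 2) NDC sample locations
#       images=gt_images,             # ground-truth RGB frames
#       images_pred=rendered_images,  # predictions at the sampled locations
#   )
#   loss = metrics["loss_rgb_huber"]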
def _rgb_metrics(
images, images_pred, masks, masks_pred, masks_crop, mask_renders_by_pred
):
assert masks_crop is not None
if mask_renders_by_pred:
images = images[..., masks_pred.reshape(-1), :]
masks_crop = masks_crop[..., masks_pred.reshape(-1), :]
masks = masks is not None and masks[..., masks_pred.reshape(-1), :]
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
crop_mass = masks_crop.sum().clamp(1.0)
# print("IMAGE:", images.mean().item(), images_pred.mean().item()) # TEMP
preds = {
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
}
if masks is not None:
masks = masks_crop * masks
preds["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
preds["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
return preds
def _get_eikonal_loss(grad_theta):
return ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()
def _get_grid_tv_loss(grid, log_domain: bool = True, eps: float = 1e-5):
if log_domain:
if (grid <= -eps).any():
warnings.warn("Grid has negative values; this will produce NaN loss")
grid = torch.log(grid + eps)
# this is an isotropic version, note that it ignores last rows/cols
return torch.mean(
utils.safe_sqrt(
(grid[..., :-1, :-1, 1:] - grid[..., :-1, :-1, :-1]) ** 2
+ (grid[..., :-1, 1:, :-1] - grid[..., :-1, :-1, :-1]) ** 2
+ (grid[..., 1:, :-1, :-1] - grid[..., :-1, :-1, :-1]) ** 2,
eps=1e-5,
)
)
def _get_depth_neg_penalty_loss(depth):
neg_penalty = depth.clamp(min=None, max=0.0) ** 2
return torch.mean(neg_penalty)
def _reshape_nongrid_var(x):
if x is None:
return None
ba, *_, dim = x.shape
return x.reshape(ba, -1, 1, dim).permute(0, 3, 1, 2).contiguous()
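# Shape example (comments only): a prediction of shape (B, N, 3) becomes
# (B, 3, N, 1) after _reshape_nongrid_var, i.e. channels-first with a flattened
# sample dimension and a singleton trailing dimension.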
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List
import torch
from pytorch3d.implicitron.dataset.utils import is_known_frame
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
NewViewSynthesisPrediction,
)
from pytorch3d.implicitron.tools.point_cloud_utils import (
get_rgbd_point_cloud,
render_point_cloud_pytorch3d,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
class ModelDBIR(torch.nn.Module):
"""
A simple depth-based image rendering model.
"""
def __init__(
self,
image_size: int = 256,
bg_color: float = 0.0,
max_points: int = -1,
):
"""
Initializes a simple DBIR model.
Args:
image_size: The size of the rendered rectangular images.
bg_color: The color of the background.
max_points: Maximum number of points in the point cloud
formed by unprojecting all source view depths.
If more points are present, they are randomly subsampled
to max_points points without replacement.
"""
super().__init__()
self.image_size = image_size
self.bg_color = bg_color
self.max_points = max_points
def forward(
self,
camera: CamerasBase,
image_rgb: torch.Tensor,
depth_map: torch.Tensor,
fg_probability: torch.Tensor,
frame_type: List[str],
**kwargs,
) -> Dict[str, Any]: # TODO: return a namedtuple or dataclass
"""
Given a set of input source cameras images and depth maps, unprojects
all RGBD maps to a colored point cloud and renders into the target views.
Args:
camera: A batch of `N` PyTorch3D cameras.
image_rgb: A batch of `N` images of shape `(N, 3, H, W)`.
depth_map: A batch of `N` depth maps of shape `(N, 1, H, W)`.
fg_probability: A batch of `N` foreground probability maps
of shape `(N, 1, H, W)`.
frame_type: A list of `N` strings containing frame type indicators
which specify target and source views.
Returns:
preds: A dict with the following fields:
nvs_prediction: The rendered colors, depth and mask
of the target views.
point_cloud: The point cloud of the scene. Its renders are
stored in `nvs_prediction`.
"""
is_known = is_known_frame(frame_type)
is_known_idx = torch.where(is_known)[0]
mask_fg = (fg_probability > 0.5).type_as(image_rgb)
point_cloud = get_rgbd_point_cloud(
camera[is_known_idx],
image_rgb[is_known_idx],
depth_map[is_known_idx],
mask_fg[is_known_idx],
)
pcl_size = int(point_cloud.num_points_per_cloud())
if (self.max_points > 0) and (pcl_size > self.max_points):
prm = torch.randperm(pcl_size)[: self.max_points]
point_cloud = Pointclouds(
point_cloud.points_padded()[:, prm, :],
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
features=point_cloud.features_padded()[:, prm, :],
)
is_target_idx = torch.where(~is_known)[0]
depth_render, image_render, mask_render = [], [], []
# render into target frames in a for loop to save memory
for tgt_idx in is_target_idx:
_image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
camera[int(tgt_idx)],
point_cloud,
render_size=(self.image_size, self.image_size),
point_radius=1e-2,
topk=10,
bg_color=self.bg_color,
)
_image_render = _image_render.clamp(0.0, 1.0)
# the mask is the set of pixels with opacity bigger than eps
_mask_render = (_mask_render > 1e-4).float()
depth_render.append(_depth_render)
image_render.append(_image_render)
mask_render.append(_mask_render)
nvs_prediction = NewViewSynthesisPrediction(
**{
k: torch.cat(v, dim=0)
for k, v in zip(
["depth_render", "image_render", "mask_render"],
[depth_render, image_render, mask_render],
)
}
)
preds = {
"nvs_prediction": nvs_prediction,
"point_cloud": point_cloud,
}
return preds
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
import torch
from pytorch3d.implicitron.tools.config import ReplaceableBase
class EvaluationMode(Enum):
TRAINING = "training"
EVALUATION = "evaluation"
class RenderSamplingMode(Enum):
MASK_SAMPLE = "mask_sample"
FULL_GRID = "full_grid"
@dataclass
class RendererOutput:
"""
A structure for storing the output of a renderer.
Args:
features: rendered features (usually RGB colors), (B, ..., C) tensor.
depths: rendered ray-termination depth map, in NDC coordinates, (B, ..., 1) tensor.
masks: rendered object mask, values in [0, 1], (B, ..., 1) tensor.
prev_stage: for multi-pass renderers (e.g. in NeRF),
a reference to the output of the previous stage.
normals: surface normals, for renderers that estimate them; (B, ..., 3) tensor.
points: ray-termination points in the world coordinates, (B, ..., 3) tensor.
aux: dict for implementation-specific renderer outputs.
"""
features: torch.Tensor
depths: torch.Tensor
masks: torch.Tensor
prev_stage: Optional[RendererOutput] = None
normals: Optional[torch.Tensor] = None
points: Optional[torch.Tensor] = None # TODO: redundant with depths
aux: Dict[str, Any] = field(default_factory=lambda: {})
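# Construction sketch (comments only, with hypothetical tensors): a two-pass
# renderer such as NeRF's coarse-to-fine scheme can chain its passes through
# `prev_stage`:
#
#   coarse = RendererOutput(features=f_c, depths=d_c, masks=m_c)
#   fine = RendererOutput(features=f_f, depths=d_f, masks=m_f, prev_stage=coarse)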
class ImplicitFunctionWrapper(torch.nn.Module):
def __init__(self, fn: torch.nn.Module):
super().__init__()
self._fn = fn
self.bound_args = {}
def bind_args(self, **bound_args):
self.bound_args = bound_args
self._fn.on_bind_args()
def unbind_args(self):
self.bound_args = {}
def forward(self, *args, **kwargs):
return self._fn(*args, **{**kwargs, **self.bound_args})
class BaseRenderer(ABC, ReplaceableBase):
"""
Base class for all Renderer implementations.
"""
def __init__(self):
super().__init__()
@abstractmethod
def forward(
self,
ray_bundle,
implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs
) -> RendererOutput:
"""
Each Renderer should implement its own forward function
that returns an instance of RendererOutput.
Args:
ray_bundle: A RayBundle object containing the following variables:
origins: A tensor of shape (minibatch, ..., 3) denoting
the origins of the rendering rays.
directions: A tensor of shape (minibatch, ..., 3)
containing the direction vectors of rendering rays.
lengths: A tensor of shape
(minibatch, ..., num_points_per_ray) containing the
lengths at which the ray points are sampled.
The coordinates of the points on the rays are thus computed
as `origins + lengths * directions`.
xys: A tensor of shape
(minibatch, ..., 2) containing the
xy locations of each ray's pixel in the NDC screen space.
implicit_functions: List of ImplicitFunctionWrappers which define the
implicit function methods to be used. Most Renderers only allow
a single implicit function. Currently, only the MultiPassEARenderer
allows specifying multiple values in the list.
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering.
**kwargs: In addition to the named args, custom keyword args can be specified.
For example in the SignedDistanceFunctionRenderer, an object_mask is
required which needs to be passed via the kwargs.
Returns:
instance of RendererOutput
"""
pass
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Tuple
import torch
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer import RayBundle
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
@registry.register
class LSTMRenderer(BaseRenderer, torch.nn.Module):
"""
Implements the learnable LSTM raymarching function from SRN [1].
Settings:
num_raymarch_steps: The number of LSTM raymarching steps.
init_depth: Initializes the bias of the last raymarching LSTM layer so that
the farthest point from the camera reaches a far z-plane that
lies `init_depth` units from the camera plane.
init_depth_noise_std: The standard deviation of the random normal noise
added to the initial depth of each marched ray.
hidden_size: The dimensionality of the LSTM's hidden state.
n_feature_channels: The number of feature channels returned by the
implicit_function evaluated at each raymarching step.
verbose: If `True`, prints raymarching debug info.
References:
[1] Sitzmann, V. and Zollhöfer, M. and Wetzstein, G.
"Scene representation networks: Continuous 3d-structure-aware
neural scene representations." NeurIPS 2019.
"""
num_raymarch_steps: int = 10
init_depth: float = 17.0
init_depth_noise_std: float = 5e-4
hidden_size: int = 16
n_feature_channels: int = 256
verbose: bool = False
def __post_init__(self):
super().__init__()
self._lstm = torch.nn.LSTMCell(
input_size=self.n_feature_channels,
hidden_size=self.hidden_size,
)
self._lstm.apply(_init_recurrent_weights)
_lstm_forget_gate_init(self._lstm)
self._out_layer = torch.nn.Linear(self.hidden_size, 1)
one_step = self.init_depth / self.num_raymarch_steps
self._out_layer.bias.data.fill_(one_step)
self._out_layer.weight.data.normal_(mean=0.0, std=1e-3)
def forward(
self,
ray_bundle: RayBundle,
implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs,
) -> RendererOutput:
"""
Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
implicit_functions: A single-element list of ImplicitFunctionWrappers which
defines the implicit function to be used.
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering, specifically the RayPointRefiner and the density_noise_std.
Returns:
instance of RendererOutput
"""
if len(implicit_functions) != 1:
raise ValueError("LSTM renderer expects a single implicit function.")
implicit_function = implicit_functions[0]
if ray_bundle.lengths.shape[-1] != 1:
raise ValueError(
"LSTM renderer requires a ray-bundle with a single point per ray"
+ " which is the initial raymarching point."
)
# jitter the initial depths
ray_bundle_t = ray_bundle._replace(
lengths=ray_bundle.lengths
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
)
states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
signed_distance = torch.zeros_like(ray_bundle_t.lengths)
raymarch_features = None
for t in range(self.num_raymarch_steps + 1):
# move signed_distance along each ray
ray_bundle_t = ray_bundle_t._replace(
lengths=ray_bundle_t.lengths + signed_distance
)
# eval the raymarching function
raymarch_features, _ = implicit_function(
ray_bundle_t,
raymarch_features=None,
)
if self.verbose:
# print some stats
print(
f"{t}: mu={float(signed_distance.mean()):1.2e};"
+ f" std={float(signed_distance.std()):1.2e};"
# pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
# typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
# param but got `Tensor`.
+ f" mu_d={float(ray_bundle_t.lengths.mean()):1.2e};"
# pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
# typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
# param but got `Tensor`.
+ f" std_d={float(ray_bundle_t.lengths.std()):1.2e};"
)
if t == self.num_raymarch_steps:
break
# run the lstm marcher
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
state_h, state_c = self._lstm(
raymarch_features.view(-1, raymarch_features.shape[-1]),
states[-1],
)
if state_h.requires_grad:
state_h.register_hook(lambda x: x.clamp(min=-10, max=10))
# predict the next step size
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
signed_distance = self._out_layer(state_h).view(ray_bundle_t.lengths.shape)
# log the lstm states
states.append((state_h, state_c))
opacity_logits, features = implicit_function(
raymarch_features=raymarch_features,
ray_bundle=ray_bundle_t,
)
mask = torch.sigmoid(opacity_logits)
depth = ray_bundle_t.lengths * ray_bundle_t.directions.norm(
dim=-1, keepdim=True
)
return RendererOutput(
features=features[..., 0, :],
depths=depth,
masks=mask[..., 0, :],
)
def _init_recurrent_weights(self) -> None:
# copied from SRN codebase
for m in self.modules():
if type(m) in [torch.nn.GRU, torch.nn.LSTM, torch.nn.RNN]:
for name, param in m.named_parameters():
if "weight_ih" in name:
torch.nn.init.kaiming_normal_(param.data)
elif "weight_hh" in name:
torch.nn.init.orthogonal_(param.data)
elif "bias" in name:
param.data.fill_(0)
def _lstm_forget_gate_init(lstm_layer) -> None:
# copied from SRN codebase
for name, parameter in lstm_layer.named_parameters():
if "bias" not in name:
continue
n = parameter.size(0)
start, end = n // 4, n // 2
parameter.data[start:end].fill_(1.0)
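# A minimal standalone sketch (toy sizes; not part of the LSTM renderer above) of the
# LSTM ray-marching idea mirrored by forward(): an LSTMCell consumes per-point features
# and predicts a signed step that advances each ray depth.
def _lstm_marching_sketch(num_steps: int = 3):
    import torch
    n_rays, feat_dim, hidden_size = 4, 16, 16
    lstm = torch.nn.LSTMCell(input_size=feat_dim, hidden_size=hidden_size)
    out_layer = torch.nn.Linear(hidden_size, 1)
    lengths = torch.zeros(n_rays, 1)  # current depth along each ray
    state = None
    for _ in range(num_steps):
        features = torch.randn(n_rays, feat_dim)  # stand-in for the implicit function output
        state = lstm(features, state)
        signed_distance = out_layer(state[0])  # predicted per-ray step size
        lengths = lengths + signed_distance
    return lengths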
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
import torch
from pytorch3d.implicitron.tools.config import registry
from .base import BaseRenderer, EvaluationMode, RendererOutput
from .ray_point_refiner import RayPointRefiner
from .raymarcher import GenericRaymarcher
@registry.register
class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
"""
Implements the multi-pass rendering function, in particular,
with emission-absorption ray marching used in NeRF [1]. First, it evaluates
opacity-based ray-point weights and then optionally (in case more implicit
functions are given) resamples points using importance sampling and evaluates
new weights.
During each ray marching pass, features, depth map, and masks
are integrated: Let o_i be the opacity estimated by the implicit function,
and d_i be the offset between points `i` and `i+1` along the respective ray.
Ray marching is performed using the following equations:
```
    ray_opacity_n = cap_fn(sum_i=1^n d_i * o_i),
    weight_n = weight_fn(cap_fn(d_n * o_n), 1 - ray_opacity_{n-1}),
```
and the final rendered quantities are computed by a dot-product of ray values
with the weights, e.g. `features = sum_n(weight_n * ray_features_n)`.
See below for possible values of `cap_fn` and `weight_fn`.
Settings:
n_pts_per_ray_fine_training: The number of points sampled per ray for the
fine rendering pass during training.
n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the
fine rendering pass during evaluation.
stratified_sampling_coarse_training: Enable/disable stratified sampling during
training.
stratified_sampling_coarse_evaluation: Enable/disable stratified sampling during
evaluation.
        append_coarse_samples_to_fine: Add the coarse ray points to the fine points
            after sampling.
bg_color: The background color. A tuple of either 1 element or of D elements,
where D matches the feature dimensionality; it is broadcasted when necessary.
density_noise_std_train: Standard deviation of the noise added to the
opacity field.
capping_function: The capping function of the raymarcher.
Options:
- "exponential" (`cap_fn(x) = 1 - exp(-x)`)
- "cap1" (`cap_fn(x) = min(x, 1)`)
Set to "exponential" for the standard Emission Absorption raymarching.
weight_function: The weighting function of the raymarcher.
Options:
- "product" (`weight_fn(w, x) = w * x`)
- "minimum" (`weight_fn(w, x) = min(w, x)`)
Set to "product" for the standard Emission Absorption raymarching.
background_opacity: The raw opacity value (i.e. before exponentiation)
of the background.
blend_output: If `True`, alpha-blends the output renders with the
background color using the rendered opacity mask.
References:
[1] Mildenhall, Ben, et al. "Nerf: Representing scenes as neural radiance
fields for view synthesis." ECCV 2020.
"""
n_pts_per_ray_fine_training: int = 64
n_pts_per_ray_fine_evaluation: int = 64
stratified_sampling_coarse_training: bool = True
stratified_sampling_coarse_evaluation: bool = False
append_coarse_samples_to_fine: bool = True
bg_color: Tuple[float, ...] = (0.0,)
density_noise_std_train: float = 0.0
capping_function: str = "exponential" # exponential | cap1
weight_function: str = "product" # product | minimum
background_opacity: float = 1e10
blend_output: bool = False
def __post_init__(self):
super().__init__()
self._refiners = {
EvaluationMode.TRAINING: RayPointRefiner(
n_pts_per_ray=self.n_pts_per_ray_fine_training,
random_sampling=self.stratified_sampling_coarse_training,
add_input_samples=self.append_coarse_samples_to_fine,
),
EvaluationMode.EVALUATION: RayPointRefiner(
n_pts_per_ray=self.n_pts_per_ray_fine_evaluation,
random_sampling=self.stratified_sampling_coarse_evaluation,
add_input_samples=self.append_coarse_samples_to_fine,
),
}
self._raymarcher = GenericRaymarcher(
1,
self.bg_color,
capping_function=self.capping_function,
weight_function=self.weight_function,
background_opacity=self.background_opacity,
blend_output=self.blend_output,
)
def forward(
self,
ray_bundle,
implicit_functions=[],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs
) -> RendererOutput:
"""
Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
implicit_functions: List of ImplicitFunctionWrappers which
define the implicit functions to be used sequentially in
the raymarching step. The output of raymarching with
implicit_functions[n-1] is refined, and then used as
input for raymarching with implicit_functions[n].
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering
Returns:
instance of RendererOutput
"""
if not implicit_functions:
raise ValueError("EA renderer expects implicit functions")
return self._run_raymarcher(
ray_bundle,
implicit_functions,
None,
evaluation_mode,
)
def _run_raymarcher(
self, ray_bundle, implicit_functions, prev_stage, evaluation_mode
):
density_noise_std = (
self.density_noise_std_train
if evaluation_mode == EvaluationMode.TRAINING
else 0.0
)
features, depth, mask, weights, aux = self._raymarcher(
*implicit_functions[0](ray_bundle),
ray_lengths=ray_bundle.lengths,
density_noise_std=density_noise_std,
)
output = RendererOutput(
features=features, depths=depth, masks=mask, aux=aux, prev_stage=prev_stage
)
# we may need to make a recursive call
if len(implicit_functions) > 1:
fine_ray_bundle = self._refiners[evaluation_mode](ray_bundle, weights)
output = self._run_raymarcher(
fine_ray_bundle,
implicit_functions[1:],
output,
evaluation_mode,
)
return output
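# A minimal structural sketch (stand-in callables and toy shapes; not the renderer
# above) of the coarse-to-fine recursion in _run_raymarcher: each pass yields ray
# weights that drive importance resampling of the ray lengths for the next pass.
def _multipass_sketch():
    import torch
    lengths = torch.linspace(0.1, 8.0, steps=16).view(1, 1, -1)  # coarse samples on one ray
    def run_pass(lengths):
        return torch.rand_like(lengths)  # stand-in for raymarched weights
    def refine(lengths, weights, n_new=16):
        idx = torch.multinomial(weights.view(-1), n_new, replacement=True)
        new_lengths = lengths.view(-1)[idx].view(1, 1, -1)
        return torch.sort(torch.cat([lengths, new_lengths], dim=-1), dim=-1).values
    coarse_weights = run_pass(lengths)
    fine_lengths = refine(lengths, coarse_weights)  # input to the second (fine) pass
    return run_pass(fine_lengths)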
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
from pytorch3d.renderer import RayBundle
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf
@expand_args_fields
# pyre-fixme[13]: Attribute `n_pts_per_ray` is never initialized.
# pyre-fixme[13]: Attribute `random_sampling` is never initialized.
class RayPointRefiner(Configurable, torch.nn.Module):
"""
Implements the importance sampling of points along rays.
The input is a `RayBundle` object with a `ray_weights` tensor
which specifies the probabilities of sampling a point along each ray.
This raysampler is used for the fine rendering pass of NeRF.
As such, the forward pass accepts the RayBundle output by the
raysampling of the coarse rendering pass. Hence, it does not
take cameras as input.
Args:
n_pts_per_ray: The number of points to sample along each ray.
random_sampling: If `False`, returns equispaced percentiles of the
distribution defined by the input weights, otherwise performs
sampling from that distribution.
        add_input_samples: If `True`, the newly sampled points are concatenated
            with the input samples and both are returned.
"""
n_pts_per_ray: int
random_sampling: bool
add_input_samples: bool = True
def __post_init__(self) -> None:
super().__init__()
def forward(
self,
input_ray_bundle: RayBundle,
ray_weights: torch.Tensor,
**kwargs,
) -> RayBundle:
"""
Args:
input_ray_bundle: An instance of `RayBundle` specifying the
source rays for sampling of the probability distribution.
ray_weights: A tensor of shape
                `(..., input_ray_bundle.lengths.shape[-1])` with non-negative
elements defining the probability distribution to sample
ray points from.
Returns:
ray_bundle: A new `RayBundle` instance containing the input ray
points together with `n_pts_per_ray` additionally sampled
points per ray. For each ray, the lengths are sorted.
"""
z_vals = input_ray_bundle.lengths
with torch.no_grad():
z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
z_samples = sample_pdf(
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
self.n_pts_per_ray,
det=not self.random_sampling,
).view(*z_vals.shape[:-1], self.n_pts_per_ray)
if self.add_input_samples:
# Add the new samples to the input ones.
z_vals = torch.cat((z_vals, z_samples), dim=-1)
else:
z_vals = z_samples
# Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1)
return RayBundle(
origins=input_ray_bundle.origins,
directions=input_ray_bundle.directions,
lengths=z_vals,
xys=input_ray_bundle.xys,
)
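# A minimal sketch (toy bins and weights; independent of RayPointRefiner) of the
# deterministic inverse-CDF resampling that `sample_pdf` performs for the fine pass:
# given per-bin weights along a ray, place new samples at equispaced percentiles.
def _inverse_cdf_resampling_sketch():
    import torch
    bin_edges = torch.linspace(0.0, 1.0, steps=6)        # 5 coarse bins along a ray
    weights = torch.tensor([0.05, 0.1, 0.6, 0.2, 0.05])  # coarse-pass weights per bin
    pdf = weights / weights.sum()
    cdf = torch.cat([torch.zeros(1), torch.cumsum(pdf, dim=0)])
    u = torch.linspace(0.0, 1.0, steps=8)                # equispaced percentiles (det=True)
    idx = torch.searchsorted(cdf, u, right=True).clamp(1, len(weights))
    low, high = cdf[idx - 1], cdf[idx]
    t = (u - low) / (high - low).clamp(min=1e-8)
    return torch.lerp(bin_edges[idx - 1], bin_edges[idx], t)  # new sample locations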
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import field
from typing import Optional, Tuple
import torch
from pytorch3d.implicitron.tools import camera_utils
from pytorch3d.implicitron.tools.config import Configurable
from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle
from pytorch3d.renderer.cameras import CamerasBase
from .base import EvaluationMode, RenderSamplingMode
class RaySampler(Configurable, torch.nn.Module):
"""
Samples a fixed number of points along rays which are in turn sampled for
each camera in a batch.
    This class utilizes `NDCMultinomialRaysampler` which allows rays to be sampled
    either randomly from an input foreground saliency mask
    (`RenderSamplingMode.MASK_SAMPLE`) or on a rectangular image grid
    (`RenderSamplingMode.FULL_GRID`). The sampling mode can be set separately
for training and evaluation by setting `self.sampling_mode_training`
    and `self.sampling_mode_evaluation` accordingly.
The class allows two modes of sampling points along the rays:
1) Sampling between fixed near and far z-planes:
Active when `self.scene_extent <= 0`, samples points along each ray
with approximately uniform spacing of z-coordinates between
the minimum depth `self.min_depth` and the maximum depth `self.max_depth`.
This sampling is useful for rendering scenes where the camera is
            at a constant distance from the focal point of the scene.
2) Adaptive near/far plane estimation around the world scene center:
Active when `self.scene_extent > 0`. Samples points on each
ray between near and far planes whose depths are determined based on
the distance from the camera center to a predefined scene center.
More specifically,
`min_depth = max(
(self.scene_center-camera_center).norm() - self.scene_extent, eps
)` and
`max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
This sampling is ideal for object-centric scenes whose contents are
centered around a known `self.scene_center` and fit into a bounding sphere
with a radius of `self.scene_extent`.
Similar to the sampling mode, the sampling parameters can be set separately
for training and evaluation.
Settings:
image_width: The horizontal size of the image grid.
image_height: The vertical size of the image grid.
scene_center: The xyz coordinates of the center of the scene used
along with `scene_extent` to compute the min and max depth planes
for sampling ray-points.
scene_extent: The radius of the scene bounding sphere centered at `scene_center`.
If `scene_extent <= 0`, the raysampler samples points between
`self.min_depth` and `self.max_depth` depths instead.
sampling_mode_training: The ray sampling mode for training. This should be a str
option from the RenderSamplingMode Enum
sampling_mode_evaluation: Same as above but for evaluation.
n_pts_per_ray_training: The number of points sampled along each ray during training.
n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
        n_rays_per_image_sampled_from_mask: The number of rays sampled from the
            image grid when `RenderSamplingMode.MASK_SAMPLE` is active.
min_depth: The minimum depth of a ray-point. Active when `self.scene_extent > 0`.
max_depth: The maximum depth of a ray-point. Active when `self.scene_extent > 0`.
stratified_point_sampling_training: if set, performs stratified random sampling
along the ray; otherwise takes ray points at deterministic offsets.
stratified_point_sampling_evaluation: Same as above but for evaluation.
"""
image_width: int = 400
image_height: int = 400
scene_center: Tuple[float, float, float] = field(
default_factory=lambda: (0.0, 0.0, 0.0)
)
scene_extent: float = 0.0
sampling_mode_training: str = "mask_sample"
sampling_mode_evaluation: str = "full_grid"
n_pts_per_ray_training: int = 64
n_pts_per_ray_evaluation: int = 64
n_rays_per_image_sampled_from_mask: int = 1024
min_depth: float = 0.1
max_depth: float = 8.0
# stratified sampling vs taking points at deterministic offsets
stratified_point_sampling_training: bool = True
stratified_point_sampling_evaluation: bool = False
def __post_init__(self):
super().__init__()
self.scene_center = torch.FloatTensor(self.scene_center)
self._sampling_mode = {
EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
EvaluationMode.EVALUATION: RenderSamplingMode(
self.sampling_mode_evaluation
),
}
self._raysamplers = {
EvaluationMode.TRAINING: NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_training,
min_depth=self.min_depth,
max_depth=self.max_depth,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.TRAINING]
== RenderSamplingMode.MASK_SAMPLE
else None,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_training,
),
EvaluationMode.EVALUATION: NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_evaluation,
min_depth=self.min_depth,
max_depth=self.max_depth,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.EVALUATION]
== RenderSamplingMode.MASK_SAMPLE
else None,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_evaluation,
),
}
def forward(
self,
cameras: CamerasBase,
evaluation_mode: EvaluationMode,
mask: Optional[torch.Tensor] = None,
) -> RayBundle:
"""
Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted.
evaluation_mode: one of `EvaluationMode.TRAINING` or
`EvaluationMode.EVALUATION` which determines the sampling mode
that is used.
mask: Active for the `RenderSamplingMode.MASK_SAMPLE` sampling mode.
Defines a non-negative mask of shape
`(batch_size, image_height, image_width)` where each per-pixel
value is proportional to the probability of sampling the
corresponding pixel's ray.
Returns:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
"""
sample_mask = None
if (
# pyre-fixme[29]
self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE
and mask is not None
):
sample_mask = torch.nn.functional.interpolate(
mask,
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
# `List[int]`.
size=[self.image_height, self.image_width],
mode="nearest",
)[:, 0]
if self.scene_extent > 0.0:
# Override the min/max depth set in initialization based on the
# input cameras.
min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
cameras, self.scene_center, self.scene_extent
)
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
# torch.Tensor, torch.nn.Module]` is not a function.
ray_bundle = self._raysamplers[evaluation_mode](
cameras=cameras,
mask=sample_mask,
min_depth=float(min_depth[0]) if self.scene_extent > 0.0 else None,
max_depth=float(max_depth[0]) if self.scene_extent > 0.0 else None,
)
return ray_bundle
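# A minimal sketch (toy camera centers) mirroring the adaptive near/far plane
# formula quoted in the RaySampler docstring for the `scene_extent > 0` mode:
#   min_depth = max(||scene_center - camera_center|| - scene_extent, eps)
#   max_depth = ||scene_center - camera_center|| + scene_extent
def _adaptive_depth_bounds_sketch():
    import torch
    scene_center = torch.tensor([0.0, 0.0, 0.0])
    scene_extent = 1.5
    camera_centers = torch.tensor([[0.0, 0.0, 4.0], [3.0, 0.0, 0.0]])
    dist = (scene_center - camera_centers).norm(dim=-1)
    min_depth = (dist - scene_extent).clamp(min=1e-3)
    max_depth = dist + scene_extent
    return min_depth, max_depth  # ([2.5, 1.5], [5.5, 4.5])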
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/lioryariv/idr
# Copyright (c) 2020 Lior Yariv
from typing import Any, Callable, Tuple
import torch
import torch.nn as nn
from pytorch3d.implicitron.tools.config import Configurable
class RayTracing(Configurable, nn.Module):
"""
Finds the intersection points of rays with the implicit surface defined
by a signed distance function (SDF). The algorithm follows the pipeline:
1. Initialise start and end points on rays by the intersections with
the circumscribing sphere.
2. Run sphere tracing from both ends.
3. Divide the untraced segments of non-convergent rays into uniform
intervals and find the one with the sign transition.
4. Run the secant method to estimate the point of the sign transition.
Args:
object_bounding_sphere: The radius of the initial sphere circumscribing
the object.
sdf_threshold: Absolute SDF value small enough for the sphere tracer
to consider it a surface.
line_search_step: Length of the backward correction on sphere tracing
iterations.
line_step_iters: Number of backward correction iterations.
sphere_tracing_iters: Maximum number of sphere tracing iterations
(the actual number of iterations may be smaller if all ray
intersections are found).
        n_steps: Number of intervals sampled for non-convergent rays.
n_secant_steps: Number of iterations in the secant algorithm.
"""
object_bounding_sphere: float = 1.0
sdf_threshold: float = 5.0e-5
line_search_step: float = 0.5
line_step_iters: int = 1
sphere_tracing_iters: int = 10
n_steps: int = 100
n_secant_steps: int = 8
def __post_init__(self):
super().__init__()
def forward(
self,
sdf: Callable[[torch.Tensor], torch.Tensor],
cam_loc: torch.Tensor,
object_mask: torch.BoolTensor,
ray_directions: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
sdf: A callable that takes a (N, 3) tensor of points and returns
a tensor of (N,) SDF values.
cam_loc: A tensor of (B, N, 3) ray origins.
            object_mask: A (B, N) tensor of boolean indicators denoting whether
                a sampled pixel corresponds to the rendered object or the background.
ray_directions: A tensor of (B, N, 3) ray directions.
Returns:
curr_start_points: A tensor of (B*N, 3) found intersection points
with the implicit surface.
network_object_mask: A tensor of (B*N,) indicators denoting whether
intersections were found.
acc_start_dis: A tensor of (B*N,) distances from the ray origins
                to intersection points.
"""
batch_size, num_pixels, _ = ray_directions.shape
device = cam_loc.device
sphere_intersections, mask_intersect = _get_sphere_intersection(
cam_loc, ray_directions, r=self.object_bounding_sphere
)
(
curr_start_points,
unfinished_mask_start,
acc_start_dis,
acc_end_dis,
min_dis,
max_dis,
) = self.sphere_tracing(
batch_size,
num_pixels,
sdf,
cam_loc,
ray_directions,
mask_intersect,
sphere_intersections,
)
network_object_mask = acc_start_dis < acc_end_dis
# The non convergent rays should be handled by the sampler
sampler_mask = unfinished_mask_start
sampler_net_obj_mask = torch.zeros_like(
sampler_mask, dtype=torch.bool, device=device
)
if sampler_mask.sum() > 0:
sampler_min_max = torch.zeros((batch_size, num_pixels, 2), device=device)
sampler_min_max.reshape(-1, 2)[sampler_mask, 0] = acc_start_dis[
sampler_mask
]
sampler_min_max.reshape(-1, 2)[sampler_mask, 1] = acc_end_dis[sampler_mask]
sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(
sdf, cam_loc, object_mask, ray_directions, sampler_min_max, sampler_mask
)
curr_start_points[sampler_mask] = sampler_pts[sampler_mask]
acc_start_dis[sampler_mask] = sampler_dists[sampler_mask]
network_object_mask[sampler_mask] = sampler_net_obj_mask[sampler_mask]
if not self.training:
return curr_start_points, network_object_mask, acc_start_dis
        # in case we are training, we also update curr_start_points and acc_start_dis
        # for rays that miss the surface or the bounding sphere
ray_directions = ray_directions.reshape(-1, 3)
mask_intersect = mask_intersect.reshape(-1)
object_mask = object_mask.reshape(-1)
in_mask = ~network_object_mask & object_mask & ~sampler_mask
out_mask = ~object_mask & ~sampler_mask
# pyre-fixme[16]: `Tensor` has no attribute `__invert__`.
mask_left_out = (in_mask | out_mask) & ~mask_intersect
if (
mask_left_out.sum() > 0
): # project the origin to the not intersect points on the sphere
cam_left_out = cam_loc.reshape(-1, 3)[mask_left_out]
rays_left_out = ray_directions[mask_left_out]
acc_start_dis[mask_left_out] = -torch.bmm(
rays_left_out.view(-1, 1, 3), cam_left_out.view(-1, 3, 1)
).squeeze()
curr_start_points[mask_left_out] = (
cam_left_out + acc_start_dis[mask_left_out].unsqueeze(1) * rays_left_out
)
mask = (in_mask | out_mask) & mask_intersect
if mask.sum() > 0:
min_dis[network_object_mask & out_mask] = acc_start_dis[
network_object_mask & out_mask
]
min_mask_points, min_mask_dist = self.minimal_sdf_points(
sdf, cam_loc, ray_directions, mask, min_dis, max_dis
)
curr_start_points[mask] = min_mask_points
acc_start_dis[mask] = min_mask_dist
return curr_start_points, network_object_mask, acc_start_dis
def sphere_tracing(
self,
batch_size: int,
num_pixels: int,
sdf: Callable[[torch.Tensor], torch.Tensor],
cam_loc: torch.Tensor,
ray_directions: torch.Tensor,
mask_intersect: torch.Tensor,
sphere_intersections: torch.Tensor,
) -> Tuple[Any, Any, Any, Any, Any, Any]:
"""
        Run the sphere tracing algorithm for at most `sphere_tracing_iters` iterations
        from both sides of the bounding sphere intersection.
Args:
batch_size:
num_pixels:
sdf:
cam_loc:
ray_directions:
mask_intersect:
sphere_intersections:
Returns:
curr_start_points:
unfinished_mask_start:
acc_start_dis:
acc_end_dis:
min_dis:
max_dis:
"""
device = cam_loc.device
sphere_intersections_points = (
cam_loc[..., None, :]
+ sphere_intersections[..., None] * ray_directions[..., None, :]
)
unfinished_mask_start = mask_intersect.reshape(-1).clone()
unfinished_mask_end = mask_intersect.reshape(-1).clone()
# Initialize start current points
curr_start_points = torch.zeros(batch_size * num_pixels, 3, device=device)
curr_start_points[unfinished_mask_start] = sphere_intersections_points[
:, :, 0, :
].reshape(-1, 3)[unfinished_mask_start]
acc_start_dis = torch.zeros(batch_size * num_pixels, device=device)
acc_start_dis[unfinished_mask_start] = sphere_intersections.reshape(-1, 2)[
unfinished_mask_start, 0
]
# Initialize end current points
curr_end_points = torch.zeros(batch_size * num_pixels, 3, device=device)
curr_end_points[unfinished_mask_end] = sphere_intersections_points[
:, :, 1, :
].reshape(-1, 3)[unfinished_mask_end]
acc_end_dis = torch.zeros(batch_size * num_pixels, device=device)
acc_end_dis[unfinished_mask_end] = sphere_intersections.reshape(-1, 2)[
unfinished_mask_end, 1
]
# Initialise min and max depth
min_dis = acc_start_dis.clone()
max_dis = acc_end_dis.clone()
# Iterate on the rays (from both sides) till finding a surface
iters = 0
# TODO: sdf should also pass info about batches
next_sdf_start = torch.zeros_like(acc_start_dis)
next_sdf_start[unfinished_mask_start] = sdf(
curr_start_points[unfinished_mask_start]
)
next_sdf_end = torch.zeros_like(acc_end_dis)
next_sdf_end[unfinished_mask_end] = sdf(curr_end_points[unfinished_mask_end])
while True:
# Update sdf
curr_sdf_start = torch.zeros_like(acc_start_dis)
curr_sdf_start[unfinished_mask_start] = next_sdf_start[
unfinished_mask_start
]
curr_sdf_start[curr_sdf_start <= self.sdf_threshold] = 0
curr_sdf_end = torch.zeros_like(acc_end_dis)
curr_sdf_end[unfinished_mask_end] = next_sdf_end[unfinished_mask_end]
curr_sdf_end[curr_sdf_end <= self.sdf_threshold] = 0
# Update masks
unfinished_mask_start = unfinished_mask_start & (
curr_sdf_start > self.sdf_threshold
)
unfinished_mask_end = unfinished_mask_end & (
curr_sdf_end > self.sdf_threshold
)
if (
unfinished_mask_start.sum() == 0 and unfinished_mask_end.sum() == 0
) or iters == self.sphere_tracing_iters:
break
iters += 1
# Make step
# Update distance
acc_start_dis = acc_start_dis + curr_sdf_start
acc_end_dis = acc_end_dis - curr_sdf_end
# Update points
curr_start_points = (
cam_loc
+ acc_start_dis.reshape(batch_size, num_pixels, 1) * ray_directions
).reshape(-1, 3)
curr_end_points = (
cam_loc
+ acc_end_dis.reshape(batch_size, num_pixels, 1) * ray_directions
).reshape(-1, 3)
# Fix points which wrongly crossed the surface
next_sdf_start = torch.zeros_like(acc_start_dis)
next_sdf_start[unfinished_mask_start] = sdf(
curr_start_points[unfinished_mask_start]
)
next_sdf_end = torch.zeros_like(acc_end_dis)
next_sdf_end[unfinished_mask_end] = sdf(
curr_end_points[unfinished_mask_end]
)
not_projected_start = next_sdf_start < 0
not_projected_end = next_sdf_end < 0
not_proj_iters = 0
while (
not_projected_start.sum() > 0 or not_projected_end.sum() > 0
) and not_proj_iters < self.line_step_iters:
# Step backwards
acc_start_dis[not_projected_start] -= (
(1 - self.line_search_step) / (2 ** not_proj_iters)
) * curr_sdf_start[not_projected_start]
curr_start_points[not_projected_start] = (
cam_loc
+ acc_start_dis.reshape(batch_size, num_pixels, 1) * ray_directions
).reshape(-1, 3)[not_projected_start]
acc_end_dis[not_projected_end] += (
(1 - self.line_search_step) / (2 ** not_proj_iters)
) * curr_sdf_end[not_projected_end]
curr_end_points[not_projected_end] = (
cam_loc
+ acc_end_dis.reshape(batch_size, num_pixels, 1) * ray_directions
).reshape(-1, 3)[not_projected_end]
# Calc sdf
next_sdf_start[not_projected_start] = sdf(
curr_start_points[not_projected_start]
)
next_sdf_end[not_projected_end] = sdf(
curr_end_points[not_projected_end]
)
# Update mask
not_projected_start = next_sdf_start < 0
not_projected_end = next_sdf_end < 0
not_proj_iters += 1
unfinished_mask_start = unfinished_mask_start & (
acc_start_dis < acc_end_dis
)
unfinished_mask_end = unfinished_mask_end & (acc_start_dis < acc_end_dis)
return (
curr_start_points,
unfinished_mask_start,
acc_start_dis,
acc_end_dis,
min_dis,
max_dis,
)
def ray_sampler(
self,
sdf: Callable[[torch.Tensor], torch.Tensor],
cam_loc: torch.Tensor,
object_mask: torch.Tensor,
ray_directions: torch.Tensor,
sampler_min_max: torch.Tensor,
sampler_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sample the ray in a given range and run secant on rays which have sign transition.
Args:
sdf:
cam_loc:
object_mask:
ray_directions:
sampler_min_max:
sampler_mask:
Returns:
"""
batch_size, num_pixels, _ = ray_directions.shape
device = cam_loc.device
n_total_pxl = batch_size * num_pixels
sampler_pts = torch.zeros(n_total_pxl, 3, device=device)
sampler_dists = torch.zeros(n_total_pxl, device=device)
intervals_dist = torch.linspace(0, 1, steps=self.n_steps, device=device).view(
1, 1, -1
)
pts_intervals = sampler_min_max[:, :, 0].unsqueeze(-1) + intervals_dist * (
sampler_min_max[:, :, 1] - sampler_min_max[:, :, 0]
).unsqueeze(-1)
points = (
cam_loc[..., None, :]
+ pts_intervals[..., None] * ray_directions[..., None, :]
)
# Get the non convergent rays
mask_intersect_idx = torch.nonzero(sampler_mask).flatten()
points = points.reshape((-1, self.n_steps, 3))[sampler_mask, :, :]
pts_intervals = pts_intervals.reshape((-1, self.n_steps))[sampler_mask]
sdf_val_all = []
for pnts in torch.split(points.reshape(-1, 3), 100000, dim=0):
sdf_val_all.append(sdf(pnts))
sdf_val = torch.cat(sdf_val_all).reshape(-1, self.n_steps)
tmp = torch.sign(sdf_val) * torch.arange(
self.n_steps, 0, -1, device=device, dtype=torch.float32
).reshape(1, self.n_steps)
# Force argmin to return the first min value
sampler_pts_ind = torch.argmin(tmp, -1)
sampler_pts[mask_intersect_idx] = points[
torch.arange(points.shape[0]), sampler_pts_ind, :
]
sampler_dists[mask_intersect_idx] = pts_intervals[
torch.arange(pts_intervals.shape[0]), sampler_pts_ind
]
true_surface_pts = object_mask.reshape(-1)[sampler_mask]
net_surface_pts = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind] < 0
# take points with minimal SDF value for P_out pixels
p_out_mask = ~(true_surface_pts & net_surface_pts)
n_p_out = p_out_mask.sum()
if n_p_out > 0:
out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1)
sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][
torch.arange(n_p_out), out_pts_idx, :
]
sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[
p_out_mask, :
][torch.arange(n_p_out), out_pts_idx]
# Get Network object mask
sampler_net_obj_mask = sampler_mask.clone()
sampler_net_obj_mask[mask_intersect_idx[~net_surface_pts]] = False
# Run Secant method
secant_pts = (
net_surface_pts & true_surface_pts if self.training else net_surface_pts
)
n_secant_pts = secant_pts.sum()
if n_secant_pts > 0:
# Get secant z predictions
z_high = pts_intervals[
torch.arange(pts_intervals.shape[0]), sampler_pts_ind
][secant_pts]
sdf_high = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind][
secant_pts
]
z_low = pts_intervals[secant_pts][
torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1
]
sdf_low = sdf_val[secant_pts][
torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1
]
cam_loc_secant = cam_loc.reshape(-1, 3)[mask_intersect_idx[secant_pts]]
ray_directions_secant = ray_directions.reshape((-1, 3))[
mask_intersect_idx[secant_pts]
]
z_pred_secant = self.secant(
sdf_low,
sdf_high,
z_low,
z_high,
cam_loc_secant,
ray_directions_secant,
# pyre-fixme[6]: For 7th param expected `Module` but got `(Tensor)
# -> Tensor`.
sdf,
)
# Get points
sampler_pts[mask_intersect_idx[secant_pts]] = (
cam_loc_secant + z_pred_secant.unsqueeze(-1) * ray_directions_secant
)
sampler_dists[mask_intersect_idx[secant_pts]] = z_pred_secant
return sampler_pts, sampler_net_obj_mask, sampler_dists
def secant(
self,
sdf_low: torch.Tensor,
sdf_high: torch.Tensor,
z_low: torch.Tensor,
z_high: torch.Tensor,
cam_loc: torch.Tensor,
ray_directions: torch.Tensor,
sdf: nn.Module,
) -> torch.Tensor:
"""
Runs the secant method for interval [z_low, z_high] for n_secant_steps
"""
z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
for _ in range(self.n_secant_steps):
p_mid = cam_loc + z_pred.unsqueeze(-1) * ray_directions
sdf_mid = sdf(p_mid)
ind_low = sdf_mid > 0
if ind_low.sum() > 0:
z_low[ind_low] = z_pred[ind_low]
sdf_low[ind_low] = sdf_mid[ind_low]
ind_high = sdf_mid < 0
if ind_high.sum() > 0:
z_high[ind_high] = z_pred[ind_high]
sdf_high[ind_high] = sdf_mid[ind_high]
z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
return z_pred
def minimal_sdf_points(
self,
sdf: Callable[[torch.Tensor], torch.Tensor],
cam_loc: torch.Tensor,
ray_directions: torch.Tensor,
mask: torch.Tensor,
min_dis: torch.Tensor,
max_dis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Find points with minimal SDF value on rays for P_out pixels
"""
n_mask_points = mask.sum()
n = self.n_steps
steps = torch.empty(n, device=cam_loc.device).uniform_(0.0, 1.0)
mask_max_dis = max_dis[mask].unsqueeze(-1)
mask_min_dis = min_dis[mask].unsqueeze(-1)
steps = (
steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis)
+ mask_min_dis
)
mask_points = cam_loc.reshape(-1, 3)[mask]
mask_rays = ray_directions[mask, :]
mask_points_all = mask_points.unsqueeze(1).repeat(1, n, 1) + steps.unsqueeze(
-1
) * mask_rays.unsqueeze(1).repeat(1, n, 1)
points = mask_points_all.reshape(-1, 3)
mask_sdf_all = []
for pnts in torch.split(points, 100000, dim=0):
mask_sdf_all.append(sdf(pnts))
mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n)
min_vals, min_idx = mask_sdf_all.min(-1)
min_mask_points = mask_points_all.reshape(-1, n, 3)[
torch.arange(0, n_mask_points), min_idx
]
min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx]
return min_mask_points, min_mask_dist
# TODO: support variable origins
def _get_sphere_intersection(
cam_loc: torch.Tensor, ray_directions: torch.Tensor, r: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
# Input: n_images x 3 ; n_images x n_rays x 3
# Output: n_images * n_rays x 2 (close and far) ; n_images * n_rays
n_imgs, n_pix, _ = ray_directions.shape
device = cam_loc.device
# cam_loc = cam_loc.unsqueeze(-1)
# ray_cam_dot = torch.bmm(ray_directions, cam_loc).squeeze()
ray_cam_dot = (ray_directions * cam_loc).sum(-1) # n_images x n_rays
under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, dim=-1) ** 2 - r ** 2)
under_sqrt = under_sqrt.reshape(-1)
mask_intersect = under_sqrt > 0
sphere_intersections = torch.zeros(n_imgs * n_pix, 2, device=device)
sphere_intersections[mask_intersect] = torch.sqrt(
under_sqrt[mask_intersect]
).unsqueeze(-1) * torch.tensor([-1.0, 1.0], device=device)
sphere_intersections[mask_intersect] -= ray_cam_dot.reshape(-1)[
mask_intersect
].unsqueeze(-1)
sphere_intersections = sphere_intersections.reshape(n_imgs, n_pix, 2)
sphere_intersections = sphere_intersections.clamp_min(0.0)
mask_intersect = mask_intersect.reshape(n_imgs, n_pix)
return sphere_intersections, mask_intersect
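# A minimal 1D sketch (toy unit-sphere SDF along a single ray) of the secant
# refinement used by `RayTracing.secant`: given a bracketing interval [z_low, z_high]
# with sdf(z_low) > 0 > sdf(z_high), iterate the secant update towards the zero crossing.
def _secant_sketch(n_steps: int = 8):
    import torch
    cam_loc = torch.tensor([0.0, 0.0, -3.0])
    ray_dir = torch.tensor([0.0, 0.0, 1.0])
    def sdf(z):
        return (cam_loc + z * ray_dir).norm() - 1.0  # unit-sphere SDF along the ray
    z_low, z_high = torch.tensor(0.5), torch.tensor(3.0)
    sdf_low, sdf_high = sdf(z_low), sdf(z_high)
    z_pred = z_low
    for _ in range(n_steps):
        z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
        sdf_mid = sdf(z_pred)
        if sdf_mid > 0:
            z_low, sdf_low = z_pred, sdf_mid
        else:
            z_high, sdf_high = z_pred, sdf_mid
    return z_pred  # converges towards 2.0, the front intersection with the sphere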
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Dict, Tuple, Union
import torch
from pytorch3d.renderer.implicit.raymarching import _check_raymarcher_inputs
_TTensor = torch.Tensor
class GenericRaymarcher(torch.nn.Module):
"""
This generalizes the `pytorch3d.renderer.EmissionAbsorptionRaymarcher`
and NeuralVolumes' Accumulative ray marcher. It additionally returns
the rendering weights that can be used in the NVS pipeline to carry out
the importance ray-sampling in the refining pass.
Different from `EmissionAbsorptionRaymarcher`, it takes raw
(non-exponentiated) densities.
Args:
bg_color: background_color. Must be of shape (1,) or (feature_dim,)
"""
def __init__(
self,
surface_thickness: int = 1,
bg_color: Union[Tuple[float, ...], _TTensor] = (0.0,),
capping_function: str = "exponential", # exponential | cap1
weight_function: str = "product", # product | minimum
background_opacity: float = 0.0,
density_relu: bool = True,
blend_output: bool = True,
):
"""
Args:
surface_thickness: Denotes the overlap between the absorption
function and the density function.
"""
super().__init__()
self.surface_thickness = surface_thickness
self.density_relu = density_relu
self.background_opacity = background_opacity
self.blend_output = blend_output
if not isinstance(bg_color, torch.Tensor):
bg_color = torch.tensor(bg_color)
if bg_color.ndim != 1:
raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
self.register_buffer("_bg_color", bg_color, persistent=False)
self._capping_function: Callable[[_TTensor], _TTensor] = {
"exponential": lambda x: 1.0 - torch.exp(-x),
"cap1": lambda x: x.clamp(max=1.0),
}[capping_function]
self._weight_function: Callable[[_TTensor, _TTensor], _TTensor] = {
"product": lambda curr, acc: curr * acc,
"minimum": lambda curr, acc: torch.minimum(curr, acc),
}[weight_function]
def forward(
self,
rays_densities: torch.Tensor,
rays_features: torch.Tensor,
aux: Dict[str, Any],
ray_lengths: torch.Tensor,
density_noise_std: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Args:
rays_densities: Per-ray density values represented with a tensor
of shape `(..., n_points_per_ray, 1)`.
rays_features: Per-ray feature values represented with a tensor
of shape `(..., n_points_per_ray, feature_dim)`.
aux: a dictionary with extra information.
ray_lengths: Per-ray depth values represented with a tensor
                of shape `(..., n_points_per_ray)`.
density_noise_std: the magnitude of the noise added to densities.
Returns:
features: A tensor of shape `(..., feature_dim)` containing
the rendered features for each ray.
depth: A tensor of shape `(..., 1)` containing estimated depth.
            opacities: A tensor of shape `(..., 1)` containing rendered opacities.
weights: A tensor of shape `(..., n_points_per_ray)` containing
the ray-specific non-negative opacity weights. In general, they
                do not sum to 1 but never exceed it, i.e.
`(weights.sum(dim=-1) <= 1.0).all()` holds.
"""
_check_raymarcher_inputs(
rays_densities,
rays_features,
ray_lengths,
z_can_be_none=True,
features_can_be_none=False,
density_1d=True,
)
deltas = torch.cat(
(
ray_lengths[..., 1:] - ray_lengths[..., :-1],
self.background_opacity * torch.ones_like(ray_lengths[..., :1]),
),
dim=-1,
)
rays_densities = rays_densities[..., 0]
if density_noise_std > 0.0:
rays_densities = (
rays_densities + torch.randn_like(rays_densities) * density_noise_std
)
if self.density_relu:
rays_densities = torch.relu(rays_densities)
weighted_densities = deltas * rays_densities
capped_densities = self._capping_function(weighted_densities)
rays_opacities = self._capping_function(
torch.cumsum(weighted_densities, dim=-1)
)
opacities = rays_opacities[..., -1:]
absorption_shifted = (-rays_opacities + 1.0).roll(
self.surface_thickness, dims=-1
)
absorption_shifted[..., : self.surface_thickness] = 1.0
weights = self._weight_function(capped_densities, absorption_shifted)
features = (weights[..., None] * rays_features).sum(dim=-2)
depth = (weights * ray_lengths)[..., None].sum(dim=-2)
alpha = opacities if self.blend_output else 1
if self._bg_color.shape[-1] not in [1, features.shape[-1]]:
raise ValueError("Wrong number of background color channels.")
features = alpha * features + (1 - opacities) * self._bg_color
return features, depth, opacities, weights, aux
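# A minimal numeric sketch (toy values, a single ray) of the weight computation in
# GenericRaymarcher.forward above with capping_function="exponential",
# weight_function="product" and surface_thickness=1.
def _ea_weights_sketch():
    import torch
    densities = torch.tensor([0.1, 1.0, 5.0, 0.2])  # raw (non-exponentiated) densities
    deltas = torch.tensor([0.5, 0.5, 0.5, 0.5])     # inter-point offsets along the ray
    def cap(x):
        return 1.0 - torch.exp(-x)
    weighted = deltas * densities
    capped = cap(weighted)
    ray_opacities = cap(torch.cumsum(weighted, dim=-1))
    absorption_shifted = (1.0 - ray_opacities).roll(1, dims=-1)
    absorption_shifted[..., :1] = 1.0
    weights = capped * absorption_shifted           # "product" weighting
    assert float(weights.sum()) <= 1.0 + 1e-6       # weights never exceed 1 in total
    return weights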
# @lint-ignore-every LICENSELINT
# Adapted from RenderingNetwork from IDR
# https://github.com/lioryariv/idr/
# Copyright (c) 2020 Lior Yariv
import torch
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
from torch import nn
class RayNormalColoringNetwork(torch.nn.Module):
def __init__(
self,
feature_vector_size=3,
mode="idr",
d_in=9,
d_out=3,
dims=(512, 512, 512, 512),
weight_norm=True,
n_harmonic_functions_dir=0,
pooled_feature_dim=0,
):
super().__init__()
self.mode = mode
self.output_dimensions = d_out
dims = [d_in + feature_vector_size] + list(dims) + [d_out]
self.embedview_fn = None
if n_harmonic_functions_dir > 0:
self.embedview_fn = HarmonicEmbedding(
n_harmonic_functions_dir, append_input=True
)
dims[0] += self.embedview_fn.get_output_dim() - 3
if pooled_feature_dim > 0:
print("Pooled features in rendering network.")
dims[0] += pooled_feature_dim
self.num_layers = len(dims)
layers = []
for layer_idx in range(self.num_layers - 1):
out_dim = dims[layer_idx + 1]
lin = nn.Linear(dims[layer_idx], out_dim)
if weight_norm:
lin = nn.utils.weight_norm(lin)
layers.append(lin)
self.linear_layers = torch.nn.ModuleList(layers)
self.relu = nn.ReLU()
self.tanh = nn.Tanh()
def forward(
self,
feature_vectors: torch.Tensor,
points,
normals,
ray_bundle: RayBundle,
masks=None,
pooling_fn=None,
):
if masks is not None and not masks.any():
return torch.zeros_like(normals)
view_dirs = ray_bundle.directions
if masks is not None:
# in case of IDR, other outputs are passed here after applying the mask
view_dirs = view_dirs.reshape(view_dirs.shape[0], -1, 3)[
:, masks.reshape(-1)
]
if self.embedview_fn is not None:
view_dirs = self.embedview_fn(view_dirs)
if self.mode == "idr":
rendering_input = torch.cat(
[points, view_dirs, normals, feature_vectors], dim=-1
)
elif self.mode == "no_view_dir":
rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
elif self.mode == "no_normal":
rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1)
else:
raise ValueError(f"Unsupported rendering mode: {self.mode}")
if pooling_fn is not None:
featspool = pooling_fn(points[None])[0]
rendering_input = torch.cat((rendering_input, featspool), dim=-1)
x = rendering_input
for layer_idx in range(self.num_layers - 1):
x = self.linear_layers[layer_idx](x)
if layer_idx < self.num_layers - 2:
x = self.relu(x)
x = self.tanh(x)
return x
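# A minimal usage sketch (random tensors, default "idr" mode, no masks or pooling)
# for the coloring network above; shapes follow the defaults d_in=9, feature_vector_size=3.
def _ray_normal_coloring_sketch():
    import torch
    n = 5
    net = RayNormalColoringNetwork()
    rays = RayBundle(
        origins=torch.zeros(n, 3),
        directions=torch.randn(n, 3),
        lengths=torch.ones(n, 1),
        xys=torch.zeros(n, 2),
    )
    points, normals, feats = torch.randn(n, 3), torch.randn(n, 3), torch.randn(n, 3)
    return net(feats, points, normals, rays)  # (n, 3) values in (-1, 1) after tanh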
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/lioryariv/idr/blob/main/code/model/
# implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv
import functools
import math
from typing import List, Optional, Tuple
import torch
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import get_default_args_field, registry
from pytorch3d.implicitron.tools.utils import evaluating
from pytorch3d.renderer import RayBundle
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
from .ray_tracing import RayTracing
from .rgb_net import RayNormalColoringNetwork
@registry.register
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
render_features_dimensions: int = 3
ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
RayNormalColoringNetwork
)
bg_color: Tuple[float, ...] = (0.0,)
soft_mask_alpha: float = 50.0
    def __post_init__(self):
super().__init__()
render_features_dimensions = self.render_features_dimensions
if len(self.bg_color) not in [1, render_features_dimensions]:
raise ValueError(
f"Background color should have {render_features_dimensions} entries."
)
self.ray_tracer = RayTracing(**self.ray_tracer_args)
self.object_bounding_sphere = self.ray_tracer_args.get("object_bounding_sphere")
self.ray_normal_coloring_network_args[
"feature_vector_size"
] = render_features_dimensions
self._rgb_network = RayNormalColoringNetwork(
**self.ray_normal_coloring_network_args
)
self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
def forward(
self,
ray_bundle: RayBundle,
implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
object_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> RendererOutput:
"""
Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
implicit_functions: single element list of ImplicitFunctionWrappers which
defines the implicit function to be used.
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering.
kwargs:
object_mask: BoolTensor, denoting the silhouette of the object.
This is a required keyword argument for SignedDistanceFunctionRenderer
Returns:
instance of RendererOutput
"""
if len(implicit_functions) != 1:
raise ValueError(
"SignedDistanceFunctionRenderer supports only single pass."
)
if object_mask is None:
raise ValueError("Expected object_mask to be provided in the kwargs")
object_mask = object_mask.bool()
implicit_function = implicit_functions[0]
implicit_function_gradient = functools.partial(gradient, implicit_function)
# object_mask: silhouette of the object
batch_size, *spatial_size, _ = ray_bundle.lengths.shape
num_pixels = math.prod(spatial_size)
cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3)
ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3)
object_mask = object_mask.reshape(batch_size, -1)
with torch.no_grad(), evaluating(implicit_function):
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
points, network_object_mask, dists = self.ray_tracer(
sdf=lambda x: implicit_function(x)[
:, 0
], # TODO: get rid of this wrapper
cam_loc=cam_loc,
object_mask=object_mask,
ray_directions=ray_dirs,
)
# TODO: below, cam_loc might as well be different
depth = dists.reshape(batch_size, num_pixels, 1)
points = (cam_loc + depth * ray_dirs).reshape(-1, 3)
sdf_output = implicit_function(points)[:, 0:1]
# NOTE most of the intermediate variables are flattened for
# no apparent reason (here and in the ray tracer)
ray_dirs = ray_dirs.reshape(-1, 3)
object_mask = object_mask.reshape(-1)
# TODO: move it to loss computation
if evaluation_mode == EvaluationMode.TRAINING:
surface_mask = network_object_mask & object_mask
surface_points = points[surface_mask]
surface_dists = dists[surface_mask].unsqueeze(-1)
surface_ray_dirs = ray_dirs[surface_mask]
surface_cam_loc = cam_loc.reshape(-1, 3)[surface_mask]
surface_output = sdf_output[surface_mask]
N = surface_points.shape[0]
# Sample points for the eikonal loss
# pyre-fixme[9]
eik_bounding_box: float = self.object_bounding_sphere
n_eik_points = batch_size * num_pixels // 2
eikonal_points = torch.empty(
n_eik_points, 3, device=self._bg_color.device
).uniform_(-eik_bounding_box, eik_bounding_box)
eikonal_pixel_points = points.clone()
eikonal_pixel_points = eikonal_pixel_points.detach()
eikonal_points = torch.cat([eikonal_points, eikonal_pixel_points], 0)
points_all = torch.cat([surface_points, eikonal_points], dim=0)
output = implicit_function(surface_points)
surface_sdf_values = output[
:N, 0:1
].detach() # how is it different from sdf_output?
g = implicit_function_gradient(points_all)
surface_points_grad = g[:N, 0, :].clone().detach()
grad_theta = g[N:, 0, :]
differentiable_surface_points = _sample_network(
surface_output,
surface_sdf_values,
surface_points_grad,
surface_dists,
surface_cam_loc,
surface_ray_dirs,
)
else:
surface_mask = network_object_mask
differentiable_surface_points = points[surface_mask]
grad_theta = None
empty_render = differentiable_surface_points.shape[0] == 0
features = implicit_function(differentiable_surface_points)[None, :, 1:]
normals_full = features.new_zeros(
batch_size, *spatial_size, 3, requires_grad=empty_render
)
render_full = (
features.new_ones(
batch_size,
*spatial_size,
self.render_features_dimensions,
requires_grad=empty_render,
)
* self._bg_color
)
mask_full = features.new_ones(
batch_size, *spatial_size, 1, requires_grad=empty_render
)
if not empty_render:
normals = implicit_function_gradient(differentiable_surface_points)[
None, :, 0, :
]
normals_full.view(-1, 3)[surface_mask] = normals
render_full.view(-1, self.render_features_dimensions)[
surface_mask
] = self._rgb_network( # pyre-fixme[29]:
features,
differentiable_surface_points[None],
normals,
ray_bundle,
surface_mask[None, :, None],
pooling_fn=None, # TODO
)
mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
-self.soft_mask_alpha * sdf_output[~surface_mask]
)
# scatter points with surface_mask
points_full = ray_bundle.origins.detach().clone()
points_full.view(-1, 3)[surface_mask] = differentiable_surface_points
# TODO: it is sparse here but otherwise dense
return RendererOutput(
features=render_full,
normals=normals_full,
depths=depth.reshape(batch_size, *spatial_size, 1),
masks=mask_full, # this is a differentiable approximation, see (7) in the paper
points=points_full,
aux={"grad_theta": grad_theta}, # TODO: will be moved to eikonal loss
# TODO: do we need sdf_output, grad_theta? Only for loss probably
)
def _sample_network(
surface_output,
surface_sdf_values,
surface_points_grad,
surface_dists,
surface_cam_loc,
surface_ray_dirs,
eps=1e-4,
):
# t -> t(theta)
surface_ray_dirs_0 = surface_ray_dirs.detach()
surface_points_dot = torch.bmm(
surface_points_grad.view(-1, 1, 3), surface_ray_dirs_0.view(-1, 3, 1)
).squeeze(-1)
dot_sign = (surface_points_dot >= 0).to(surface_points_dot) * 2 - 1
surface_dists_theta = surface_dists - (surface_output - surface_sdf_values) / (
surface_points_dot.abs().clip(eps) * dot_sign
)
# t(theta) -> x(theta,c,v)
surface_points_theta_c_v = surface_cam_loc + surface_dists_theta * surface_ray_dirs
return surface_points_theta_c_v
@torch.enable_grad()
def gradient(module, x):
x.requires_grad_(True)
y = module.forward(x)[:, :1]
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
gradients = torch.autograd.grad(
outputs=y,
inputs=x,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
return gradients.unsqueeze(1)
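# A minimal sketch (toy linear layer standing in for the SDF network) of how the
# `gradient` helper above is used to obtain per-point SDF gradients, e.g. for
# surface normals or the eikonal term.
def _gradient_sketch():
    import torch
    toy_sdf = torch.nn.Linear(3, 1)    # stand-in for the implicit function
    points = torch.randn(5, 3)
    return gradient(toy_sdf, points)   # (5, 1, 3) gradients w.r.t. the points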
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
import math
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn.functional as Fu
import torchvision
from pytorch3d.implicitron.tools.config import Configurable
logger = logging.getLogger(__name__)
MASK_FEATURE_NAME = "mask"
IMAGE_FEATURE_NAME = "image"
_FEAT_DIMS = {
"resnet18": (64, 128, 256, 512),
"resnet34": (64, 128, 256, 512),
"resnet50": (256, 512, 1024, 2048),
"resnet101": (256, 512, 1024, 2048),
"resnet152": (256, 512, 1024, 2048),
}
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
class ResNetFeatureExtractor(Configurable, torch.nn.Module):
"""
    Implements an image feature extractor. Depending on the settings, it allows
    extraction of:
- deep features: A CNN ResNet backbone from torchvision (with/without
pretrained weights) which extracts deep features.
- masks: Segmentation masks.
- images: Raw input RGB images.
Settings:
        name: The name of the ResNet backbone (from torchvision).
        pretrained: If `True`, loads the pretrained torchvision weights.
stages: List of stages from which to extract features.
Features from each stage are returned as key value
pairs in the forward function
        normalize_image: If set, normalizes the RGB values of
            the image based on the ResNet mean/std.
image_rescale: If not 1.0, this rescale factor will be
used to resize the image
first_max_pool: If set, a max pool layer is added after the first
convolutional layer
proj_dim: The number of output channels for the convolutional layers
l2_norm: If set, l2 normalization is applied to the extracted features
add_masks: If set, the masks will be saved in the output dictionary
add_images: If set, the images will be saved in the output dictionary
global_average_pool: If set, global average pooling step is performed
feature_rescale: If not 1.0, this rescale factor will be used to
rescale the output features
"""
name: str = "resnet34"
pretrained: bool = True
stages: Tuple[int, ...] = (1, 2, 3, 4)
normalize_image: bool = True
image_rescale: float = 128 / 800.0
first_max_pool: bool = True
proj_dim: int = 32
l2_norm: bool = True
add_masks: bool = True
add_images: bool = True
    global_average_pool: bool = False  # this can simulate global/non-spatial features
feature_rescale: float = 1.0
def __post_init__(self):
super().__init__()
if self.normalize_image:
# register buffers needed to normalize the image
for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
self.register_buffer(
k,
torch.FloatTensor(v).view(1, 3, 1, 1),
persistent=False,
)
self._feat_dim = {}
if len(self.stages) == 0:
# do not extract any resnet features
pass
else:
net = getattr(torchvision.models, self.name)(pretrained=self.pretrained)
if self.first_max_pool:
self.stem = torch.nn.Sequential(
net.conv1, net.bn1, net.relu, net.maxpool
)
else:
self.stem = torch.nn.Sequential(net.conv1, net.bn1, net.relu)
self.max_stage = max(self.stages)
self.layers = torch.nn.ModuleList()
self.proj_layers = torch.nn.ModuleList()
for stage in range(self.max_stage):
stage_name = f"layer{stage+1}"
feature_name = self._get_resnet_stage_feature_name(stage)
if (stage + 1) in self.stages:
if (
self.proj_dim > 0
and _FEAT_DIMS[self.name][stage] > self.proj_dim
):
proj = torch.nn.Conv2d(
_FEAT_DIMS[self.name][stage],
self.proj_dim,
1,
1,
bias=True,
)
self._feat_dim[feature_name] = self.proj_dim
else:
proj = torch.nn.Identity()
self._feat_dim[feature_name] = _FEAT_DIMS[self.name][stage]
else:
proj = torch.nn.Identity()
self.proj_layers.append(proj)
self.layers.append(getattr(net, stage_name))
if self.add_masks:
self._feat_dim[MASK_FEATURE_NAME] = 1
if self.add_images:
self._feat_dim[IMAGE_FEATURE_NAME] = 3
logger.info(f"Feat extractor total dim = {self.get_feat_dims()}")
self.stages = set(self.stages) # convert to set for faster "in"
def _get_resnet_stage_feature_name(self, stage) -> str:
return f"res_layer_{stage+1}"
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
return (img - self._resnet_mean) / self._resnet_std
def get_feat_dims(self, size_dict: bool = False):
if size_dict:
return copy.deepcopy(self._feat_dim)
# pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
# `values`.
return sum(self._feat_dim.values())
def forward(
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
) -> Dict[Any, torch.Tensor]:
"""
Args:
imgs: A batch of input images of shape `(B, 3, H, W)`.
            masks: A batch of input masks of shape `(B, 1, H, W)`.
Returns:
out_feats: A dict `{f_i: t_i}` keyed by predicted feature names `f_i`
and their corresponding tensors `t_i` of shape `(B, dim_i, H_i, W_i)`.
"""
out_feats = {}
imgs_input = imgs
if self.image_rescale != 1.0:
imgs_resized = Fu.interpolate(
imgs_input,
# pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but
# got `float`.
scale_factor=self.image_rescale,
mode="bilinear",
)
else:
imgs_resized = imgs_input
if self.normalize_image:
imgs_normed = self._resnet_normalize_image(imgs_resized)
else:
imgs_normed = imgs_resized
if len(self.stages) > 0:
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
# is not a function.
feats = self.stem(imgs_normed)
# pyre-fixme[6]: For 1st param expected `Iterable[Variable[_T1]]` but
# got `Union[Tensor, Module]`.
# pyre-fixme[6]: For 2nd param expected `Iterable[Variable[_T2]]` but
# got `Union[Tensor, Module]`.
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
feats = layer(feats)
# just a sanity check below
assert feats.shape[1] == _FEAT_DIMS[self.name][stage]
if (stage + 1) in self.stages:
f = proj(feats)
if self.global_average_pool:
                        f = f.mean(dim=(2, 3), keepdim=True)  # global average pool over spatial dims
if self.l2_norm:
normfac = 1.0 / math.sqrt(len(self.stages))
f = Fu.normalize(f, dim=1) * normfac
feature_name = self._get_resnet_stage_feature_name(stage)
out_feats[feature_name] = f
if self.add_masks:
assert masks is not None
out_feats[MASK_FEATURE_NAME] = masks
if self.add_images:
assert imgs_input is not None
out_feats[IMAGE_FEATURE_NAME] = imgs_resized
if self.feature_rescale != 1.0:
out_feats = {k: self.feature_rescale * f for k, f in out_feats.items()}
# pyre-fixme[7]: Incompatible return type, expected `Dict[typing.Any, Tensor]`
# but got `Dict[typing.Any, float]`
return out_feats
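# A minimal usage sketch for the extractor above (hypothetical sizes; assumes the
# Configurable has been processed with `expand_args_fields`, as implicitron does when
# building models, and that torchvision's resnet34 is available).
def _resnet_feature_extractor_sketch():
    import torch
    from pytorch3d.implicitron.tools.config import expand_args_fields
    expand_args_fields(ResNetFeatureExtractor)
    extractor = ResNetFeatureExtractor(name="resnet34", pretrained=False, stages=(1, 2))
    images = torch.randn(2, 3, 128, 128)
    masks = torch.rand(2, 1, 128, 128)
    feats = extractor(images, masks)  # dict: per-stage feature maps + mask + image
    return {k: v.shape for k, v in feats.items()}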
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, Sequence, Union
import torch
import torch.nn.functional as F
from pytorch3d.implicitron.models.view_pooling.view_sampling import (
cameras_points_cartesian_product,
)
from pytorch3d.implicitron.tools.config import ReplaceableBase, registry
from pytorch3d.ops import wmean
from pytorch3d.renderer.cameras import CamerasBase
class ReductionFunction(Enum):
AVG = "avg" # simple average
MAX = "max" # maximum
STD = "std" # standard deviation
STD_AVG = "std_avg" # average of per-dimension standard deviations
class FeatureAggregatorBase(ABC, ReplaceableBase):
"""
Base class for aggregating features.
Typically, the aggregated features and their masks are output by `ViewSampler`
which samples feature tensors extracted from a set of source images.
Settings:
        exclude_target_view: If `True`, disables pooling from the target view
            to itself (the target view is excluded from the source views).
        exclude_target_view_mask_features: If `True`,
            masks the features sampled from the target view before aggregation.
concatenate_output: If `True`,
concatenate the aggregated features into a single tensor,
otherwise return a dictionary mapping feature names to tensors.
"""
exclude_target_view: bool = True
exclude_target_view_mask_features: bool = True
concatenate_output: bool = True
@abstractmethod
def forward(
self,
feats_sampled: Dict[str, torch.Tensor],
masks_sampled: torch.Tensor,
camera: Optional[CamerasBase] = None,
pts: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
where each `t_i` is a tensor of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
masks_sampled: A binary mask represented as a tensor of shape
`(minibatch, n_source_views, n_samples, 1)` denoting valid
sampled features.
camera: A batch of `n_source_views` `CamerasBase` objects corresponding
to the source view cameras.
pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
3D points whose 2D projections to source views were sampled in
order to generate `feats_sampled` and `masks_sampled`.
Returns:
feats_aggregated: If `concatenate_output==True`, a tensor
of shape `(minibatch, reduce_dim, n_samples, sum(dim_1, ... dim_N))`
containing the concatenation of the aggregated features `feats_sampled`.
`reduce_dim` depends on the specific feature aggregator
implementation and typically equals 1 or `n_source_views`.
If `concatenate_output==False`, the aggregator does not concatenate
the aggregated features and returns a dictionary of per-feature
aggregations `{f_i: t_i_aggregated}` instead. Each `t_i_aggregated`
is of shape `(minibatch, reduce_dim, n_samples, aggr_dim_i)`.
"""
raise NotImplementedError()
@registry.register
class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
"""
This aggregator does not perform any feature aggregation. Depending on the
settings the aggregator allows to mask target view features and concatenate
the outputs.
"""
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, [])
def forward(
self,
feats_sampled: Dict[str, torch.Tensor],
masks_sampled: torch.Tensor,
camera: Optional[CamerasBase] = None,
pts: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
where each `t_i` is a tensor of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
masks_sampled: A binary mask represented as a tensor of shape
`(minibatch, n_source_views, n_samples, 1)` denoting valid
sampled features.
camera: A batch of `n_source_views` `CamerasBase` objects
corresponding to the source view cameras.
pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
3D points whose 2D projections to source views were sampled in
order to generate `feats_sampled` and `masks_sampled`.
Returns:
feats_aggregated: If `concatenate_output==True`, a tensor
of shape `(minibatch, n_source_views, n_samples, sum(dim_1, ... dim_N))`.
If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
with each `t_i_aggregated` of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
"""
if self.exclude_target_view_mask_features:
feats_sampled = _mask_target_view_features(feats_sampled)
feats_aggregated = feats_sampled
if self.concatenate_output:
feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
return feats_aggregated
@registry.register
class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
"""
Aggregates using a set of predefined `reduction_functions` and concatenates
the results of each aggregation function along the
channel dimension. The reduction functions collapse the second (source view)
dimension of the sampled features to a singleton.
Settings:
reduction_functions: A list of `ReductionFunction`s that reduce the
stack of source-view-specific features to a single feature.
"""
reduction_functions: Sequence[ReductionFunction] = (
ReductionFunction.AVG,
ReductionFunction.STD,
)
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
def forward(
self,
feats_sampled: Dict[str, torch.Tensor],
masks_sampled: torch.Tensor,
camera: Optional[CamerasBase] = None,
pts: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
where each `t_i` is a tensor of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
masks_sampled: A binary mask represented as a tensor of shape
`(minibatch, n_source_views, n_samples, 1)` denoting valid
sampled features.
camera: A batch of `n_source_views` `CamerasBase` objects corresponding
to the source view cameras.
pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
3D points whose 2D projections to source views were sampled in
order to generate `feats_sampled` and `masks_sampled`.
Returns:
feats_aggregated: If `concatenate_output==True`, a tensor
of shape `(minibatch, 1, n_samples, sum(dim_1, ... dim_N))`.
If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
with each `t_i_aggregated` of shape `(minibatch, 1, n_samples, aggr_dim_i)`.
"""
pts_batch, n_cameras = masks_sampled.shape[:2]
if self.exclude_target_view_mask_features:
feats_sampled = _mask_target_view_features(feats_sampled)
sampling_mask = _get_view_sampling_mask(
n_cameras,
pts_batch,
masks_sampled.device,
self.exclude_target_view,
)
aggr_weights = masks_sampled[..., 0] * sampling_mask[..., None]
feats_aggregated = {
k: _avgmaxstd_reduction_function(
f,
aggr_weights,
dim=1,
reduction_functions=self.reduction_functions,
)
for k, f in feats_sampled.items()
}
if self.concatenate_output:
feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
return feats_aggregated
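# A minimal, hypothetical usage sketch of the aggregator above (the sizes are
# placeholders and `expand_args_fields` is assumed to prepare the configurable
# class for direct construction):
def _example_reduction_feature_aggregator():
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(ReductionFeatureAggregator)
    aggregator = ReductionFeatureAggregator()
    # (minibatch, n_source_views, n_samples, dim)
    feats = {"feat": torch.rand(2, 4, 8, 16)}
    masks = torch.ones(2, 4, 8, 1)
    out = aggregator(feats_sampled=feats, masks_sampled=masks)
    # AVG and STD are concatenated along the channel dimension: (2, 1, 8, 32)
    return out.shape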
@registry.register
class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
"""
Performs a weighted aggregation using a set of predefined `reduction_functions`
and concatenates the results of each aggregation function along the
channel dimension. The weights are proportional to the cosine of the
angle between the target ray and the source ray:
```
weight = (
dot(target_ray, source_ray) * 0.5 + 0.5 + self.min_ray_angle_weight
)**self.weight_by_ray_angle_gamma
```
The reduction functions collapse the second (source view) dimension
of the sampled features to a singleton.
Settings:
reduction_functions: A list of `ReductionFunction`s that reduce the
stack of source-view-specific features to a single feature.
min_ray_angle_weight: The minimum possible aggregation weight
before raising to the power of `self.weight_by_ray_angle_gamma`.
weight_by_ray_angle_gamma: The exponent of the cosine of the ray angles
used when calculating the angle-based aggregation weights.
"""
reduction_functions: Sequence[ReductionFunction] = (
ReductionFunction.AVG,
ReductionFunction.STD,
)
weight_by_ray_angle_gamma: float = 1.0
min_ray_angle_weight: float = 0.1
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
def forward(
self,
feats_sampled: Dict[str, torch.Tensor],
masks_sampled: torch.Tensor,
camera: Optional[CamerasBase] = None,
pts: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
where each `t_i` is a tensor of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
masks_sampled: A binary mask represented as a tensor of shape
`(minibatch, n_source_views, n_samples, 1)` denoting valid
sampled features.
camera: A batch of `n_source_views` `CamerasBase` objects
corresponding to the source view cameras.
pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
3D points whose 2D projections to source views were sampled in
order to generate `feats_sampled` and `masks_sampled`.
Returns:
feats_aggregated: If `concatenate_output==True`, a tensor
of shape `(minibatch, 1, n_samples, sum(dim_1, ... dim_N))`.
If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
with each `t_i_aggregated` of shape
`(minibatch, 1, n_samples, aggr_dim_i)`.
"""
if camera is None:
raise ValueError("camera cannot be None for angle weighted aggregation")
if pts is None:
raise ValueError("Points cannot be None for angle weighted aggregation")
pts_batch, n_cameras = masks_sampled.shape[:2]
if self.exclude_target_view_mask_features:
feats_sampled = _mask_target_view_features(feats_sampled)
view_sampling_mask = _get_view_sampling_mask(
n_cameras,
pts_batch,
masks_sampled.device,
self.exclude_target_view,
)
aggr_weights = _get_angular_reduction_weights(
view_sampling_mask,
masks_sampled,
camera,
pts,
self.min_ray_angle_weight,
self.weight_by_ray_angle_gamma,
)
assert torch.isfinite(aggr_weights).all()
feats_aggregated = {
k: _avgmaxstd_reduction_function(
f,
aggr_weights,
dim=1,
reduction_functions=self.reduction_functions,
)
for k, f in feats_sampled.items()
}
if self.concatenate_output:
feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
return feats_aggregated
@registry.register
class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
"""
This aggregator does not perform any feature aggregation. It only weights
the features by the weights proportional to the cosine of the
angle between the target ray and the source ray:
```
weight = (
dot(target_ray, source_ray) * 0.5 + 0.5 + self.min_ray_angle_weight
)**self.weight_by_ray_angle_gamma
```
Settings:
min_ray_angle_weight: The minimum possible aggregation weight
before raising to the power of `self.weight_by_ray_angle_gamma`.
weight_by_ray_angle_gamma: The exponent of the cosine of the ray angles
used when calculating the angle-based aggregation weights.
Additionally, the aggregator can mask the target-view features and concatenate
the outputs.
"""
weight_by_ray_angle_gamma: float = 1.0
min_ray_angle_weight: float = 0.1
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, [])
def forward(
self,
feats_sampled: Dict[str, torch.Tensor],
masks_sampled: torch.Tensor,
camera: Optional[CamerasBase] = None,
pts: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
where each `t_i` is a tensor of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
masks_sampled: A binary mask represented as a tensor of shape
`(minibatch, n_source_views, n_samples, 1)` denoting valid
sampled features.
camera: A batch of `n_source_views` `CamerasBase` objects corresponding
to the source view cameras.
pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
3D points whose 2D projections to source views were sampled in
order to generate `feats_sampled` and `masks_sampled`.
Returns:
feats_aggregated: If `concatenate_output==True`, a tensor
of shape `(minibatch, n_source_views, n_samples, sum(dim_1, ... dim_N))`.
If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
with each `t_i_aggregated` of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
"""
if camera is None:
raise ValueError("camera cannot be None for angle weighted aggregation")
if pts is None:
raise ValueError("Points cannot be None for angle weighted aggregation")
pts_batch, n_cameras = masks_sampled.shape[:2]
if self.exclude_target_view_mask_features:
feats_sampled = _mask_target_view_features(feats_sampled)
view_sampling_mask = _get_view_sampling_mask(
n_cameras,
pts_batch,
masks_sampled.device,
self.exclude_target_view,
)
aggr_weights = _get_angular_reduction_weights(
view_sampling_mask,
masks_sampled,
camera,
pts,
self.min_ray_angle_weight,
self.weight_by_ray_angle_gamma,
)
feats_aggregated = {
k: f * aggr_weights[..., None] for k, f in feats_sampled.items()
}
if self.concatenate_output:
feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
return feats_aggregated
def _get_reduction_aggregator_feature_dim(
feats_or_feats_dim: Union[Dict[str, torch.Tensor], int],
reduction_functions: Sequence[ReductionFunction],
):
if isinstance(feats_or_feats_dim, int):
feat_dim = feats_or_feats_dim
else:
feat_dim = int(sum(f.shape[1] for f in feats_or_feats_dim.values()))
if len(reduction_functions) == 0:
return feat_dim
return sum(
_get_reduction_function_output_dim(
reduction_function,
feat_dim,
)
for reduction_function in reduction_functions
)
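# Worked example for the helper above (hypothetical dimensions): a 16-channel
# feature doubles in width under AVG + STD, while STD_AVG adds a single channel.
def _example_reduction_aggregator_feature_dim():
    assert (
        _get_reduction_aggregator_feature_dim(
            16, [ReductionFunction.AVG, ReductionFunction.STD]
        )
        == 32
    )
    assert (
        _get_reduction_aggregator_feature_dim(
            16, [ReductionFunction.AVG, ReductionFunction.STD_AVG]
        )
        == 17
    )
    assert _get_reduction_aggregator_feature_dim(16, []) == 16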
def _get_reduction_function_output_dim(
reduction_function: ReductionFunction,
feat_dim: int,
) -> int:
if reduction_function == ReductionFunction.STD_AVG:
return 1
else:
return feat_dim
def _get_view_sampling_mask(
n_cameras: int,
pts_batch: int,
device: Union[str, torch.device],
exclude_target_view: bool,
):
return (
-torch.eye(n_cameras, device=device, dtype=torch.float32)
* float(exclude_target_view)
+ 1.0
)[:pts_batch]
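# Illustrative example of the mask above: with 3 source cameras and a batch of
# 2 target views, excluding the target view zeroes the diagonal entries.
def _example_view_sampling_mask():
    mask = _get_view_sampling_mask(3, 2, "cpu", exclude_target_view=True)
    # mask == tensor([[0., 1., 1.],
    #                 [1., 0., 1.]])
    return mask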
def _mask_target_view_features(
feats_sampled: Dict[str, torch.Tensor],
):
# mask out the sampled features to be sure we don't use them
# anywhere later
one_feature_sampled = next(iter(feats_sampled.values()))
pts_batch, n_cameras = one_feature_sampled.shape[:2]
view_sampling_mask = _get_view_sampling_mask(
n_cameras,
pts_batch,
one_feature_sampled.device,
True,
)
view_sampling_mask = view_sampling_mask.view(
pts_batch, n_cameras, *([1] * (one_feature_sampled.ndim - 2))
)
return {k: f * view_sampling_mask for k, f in feats_sampled.items()}
def _get_angular_reduction_weights(
view_sampling_mask: torch.Tensor,
masks_sampled: torch.Tensor,
camera: CamerasBase,
pts: torch.Tensor,
min_ray_angle_weight: float,
weight_by_ray_angle_gamma: float,
):
aggr_weights = masks_sampled.clone()[..., 0]
assert not any(v is None for v in [camera, pts])
angle_weight = _get_ray_angle_weights(
camera,
pts,
min_ray_angle_weight,
weight_by_ray_angle_gamma,
)
assert torch.isfinite(angle_weight).all()
# multiply the final aggr weights with ray angles
view_sampling_mask = view_sampling_mask.view(
*view_sampling_mask.shape[:2], *([1] * (aggr_weights.ndim - 2))
)
aggr_weights = (
aggr_weights * angle_weight.reshape_as(aggr_weights) * view_sampling_mask
)
return aggr_weights
def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor):
n_cameras = camera.R.shape[0]
pts_batch = pts.shape[0]
camera_rep, pts_rep = cameras_points_cartesian_product(camera, pts)
# unlike get_camera_center() (commented out below), this does not randomly produce NaNs
cam_centers_rep = -torch.bmm(
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
# torch.Tensor, torch.nn.modules.module.Module]` is not a function.
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N...
camera_rep.T[:, None],
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N...
camera_rep.R.permute(0, 2, 1),
).reshape(-1, *([1] * (pts.ndim - 2)), 3)
# cam_centers_rep = camera_rep.get_camera_center().reshape(
# -1, *([1]*(pts.ndim - 2)), 3
# )
ray_dirs = F.normalize(pts_rep - cam_centers_rep, dim=-1)
# camera_rep = [ pts_rep = [
# camera[0] pts[0],
# camera[0] pts[1],
# camera[0] ...,
# ... pts[batch_pts-1],
# camera[1] pts[0],
# camera[1] pts[1],
# camera[1] ...,
# ... pts[batch_pts-1],
# ... ...,
# camera[n_cameras-1] pts[0],
# camera[n_cameras-1] pts[1],
# camera[n_cameras-1] ...,
# ... pts[batch_pts-1],
# ] ]
ray_dirs_reshape = ray_dirs.view(n_cameras, pts_batch, -1, 3)
# [
# [pts_0 in cam_0, pts_1 in cam_0, ..., pts_m in cam_0],
# [pts_0 in cam_1, pts_1 in cam_1, ..., pts_m in cam_1],
# ...
# [pts_0 in cam_n, pts_1 in cam_n, ..., pts_m in cam_n],
# ]
ray_dirs_pts = torch.stack([ray_dirs_reshape[i, i] for i in range(pts_batch)])
ray_dir_dot_prods = (ray_dirs_pts[None] * ray_dirs_reshape).sum(
dim=-1
) # pts_batch x n_cameras x n_pts
return ray_dir_dot_prods.transpose(0, 1)
def _get_ray_angle_weights(
camera: CamerasBase,
pts: torch.Tensor,
min_ray_angle_weight: float,
weight_by_ray_angle_gamma: float,
):
ray_dir_dot_prods = _get_ray_dir_dot_prods(
camera, pts
) # pts_batch x n_cameras x n_pts
angle_weight_01 = ray_dir_dot_prods * 0.5 + 0.5 # [-1, 1] to [0, 1]
angle_weight = (angle_weight_01 + min_ray_angle_weight) ** weight_by_ray_angle_gamma
return angle_weight
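# Worked example of the weighting above with the default settings
# (min_ray_angle_weight=0.1, weight_by_ray_angle_gamma=1.0); the dot products
# below are hypothetical:
def _example_ray_angle_weight():
    dots = torch.tensor([1.0, 0.0, -1.0])  # parallel, orthogonal, opposite rays
    weights = (dots * 0.5 + 0.5 + 0.1) ** 1.0
    # weights == tensor([1.1000, 0.6000, 0.1000])
    return weights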
def _avgmaxstd_reduction_function(
x: torch.Tensor,
w: torch.Tensor,
reduction_functions: Sequence[ReductionFunction],
dim: int = 1,
):
"""
Args:
x: Features to aggregate. Tensor of shape `(batch, n_views, ..., dim)`.
w: Aggregation weights. Tensor of shape `(batch, n_views, ...,)`.
dim: the dimension along which to aggregate.
reduction_functions: The set of reduction functions.
Returns:
x_aggr: Aggregation of `x` to a tensor of shape `(batch, 1, ..., dim_aggregate)`.
"""
pooled_features = []
mu = None
std = None
if ReductionFunction.AVG in reduction_functions:
# average pool
mu = _avg_reduction_function(x, w, dim=dim)
pooled_features.append(mu)
if ReductionFunction.STD in reduction_functions:
# standard-dev pool
std = _std_reduction_function(x, w, dim=dim, mu=mu)
pooled_features.append(std)
if ReductionFunction.STD_AVG in reduction_functions:
# average-of-standard-dev pool
stdavg = _std_avg_reduction_function(x, w, dim=dim, mu=mu, std=std)
pooled_features.append(stdavg)
if ReductionFunction.MAX in reduction_functions:
max_ = _max_reduction_function(x, w, dim=dim)
pooled_features.append(max_)
# cat all results along the feature dimension (the last dim)
x_aggr = torch.cat(pooled_features, dim=-1)
# zero out features that were all masked out
any_active = (w.max(dim=dim, keepdim=True).values > 1e-4).type_as(x_aggr)
x_aggr = x_aggr * any_active[..., None]
# some asserts to check that everything was done right
assert torch.isfinite(x_aggr).all()
assert x_aggr.shape[1] == 1
return x_aggr
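# Shape sketch for the reduction above (hypothetical sizes): AVG + STD pooling
# of a (batch=2, n_views=4, n_samples=8, dim=16) tensor over the view dimension
# produces a (2, 1, 8, 32) tensor.
def _example_avgmaxstd_reduction():
    x = torch.rand(2, 4, 8, 16)
    w = torch.ones(2, 4, 8)
    x_aggr = _avgmaxstd_reduction_function(
        x, w, reduction_functions=[ReductionFunction.AVG, ReductionFunction.STD]
    )
    assert x_aggr.shape == (2, 1, 8, 32)
    return x_aggr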
def _avg_reduction_function(
x: torch.Tensor,
w: torch.Tensor,
dim: int = 1,
):
mu = wmean(x, w, dim=dim, eps=1e-2)
return mu
def _std_reduction_function(
x: torch.Tensor,
w: torch.Tensor,
dim: int = 1,
mu: Optional[torch.Tensor] = None, # pre-computed mean
):
if mu is None:
mu = _avg_reduction_function(x, w, dim=dim)
std = wmean((x - mu) ** 2, w, dim=dim, eps=1e-2).clamp(1e-4).sqrt()
# FIXME: somehow this is extremely heavy in mem?
return std
def _std_avg_reduction_function(
x: torch.Tensor,
w: torch.Tensor,
dim: int = 1,
mu: Optional[torch.Tensor] = None, # pre-computed mean
std: Optional[torch.Tensor] = None, # pre-computed std
):
if std is None:
std = _std_reduction_function(x, w, dim=dim, mu=mu)
stdmean = std.mean(dim=-1, keepdim=True)
return stdmean
def _max_reduction_function(
x: torch.Tensor,
w: torch.Tensor,
dim: int = 1,
big_M_factor: float = 10.0,
):
big_M = x.max(dim=dim, keepdim=True).values.abs() * big_M_factor
max_ = (x * w - ((1 - w) * big_M)).max(dim=dim, keepdim=True).values
return max_
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple, Union
import torch
from pytorch3d.implicitron.tools.config import Configurable
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.utils import ndc_grid_sample
class ViewSampler(Configurable, torch.nn.Module):
"""
Implements sampling of image-based features at the 2d projections of a set
of 3D points.
Args:
masked_sampling: If `True`, the `sampled_masks` output of `self.forward`
contains the input `masks` sampled at the 2d projections. Otherwise,
all entries of `sampled_masks` are set to 1.
sampling_mode: Controls the mode of the `torch.nn.functional.grid_sample`
function used to interpolate the sampled feature tensors at the
locations of the 2d projections.
"""
masked_sampling: bool = False
sampling_mode: str = "bilinear"
def __post_init__(self):
super().__init__()
def forward(
self,
*, # force kw args
pts: torch.Tensor,
seq_id_pts: Union[List[int], List[str], torch.LongTensor],
camera: CamerasBase,
seq_id_camera: Union[List[int], List[str], torch.LongTensor],
feats: Dict[str, torch.Tensor],
masks: Optional[torch.Tensor],
**kwargs,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
"""
Project each point cloud from a batch of point clouds to corresponding
input cameras and sample features at the 2D projection locations.
Args:
pts: A tensor of shape `[pts_batch x n_pts x 3]` in world coords.
seq_id_pts: LongTensor of shape `[pts_batch]` denoting the ids of the scenes
from which `pts` were extracted, or a list of string names.
camera: `n_cameras` cameras, each corresponding to a batch element of `feats`.
seq_id_camera: LongTensor of shape `[n_cameras]` denoting the ids of the scenes
corresponding to cameras in `camera`, or a list of string names.
feats: a dict of tensors of per-image features `{feat_i: T_i}`.
Each tensor `T_i` is of shape `[n_cameras x dim_i x H_i x W_i]`.
masks: `[n_cameras x 1 x H x W]`, define valid image regions
for sampling `feats`.
Returns:
sampled_feats: Dict of sampled features `{feat_i: sampled_T_i}`.
Each `sampled_T_i` of shape `[pts_batch, n_cameras, n_pts, dim_i]`.
sampled_masks: A tensor with mask of the sampled features
of shape `(pts_batch, n_cameras, n_pts, 1)`.
"""
# convert sequence ids to long tensors
seq_id_pts, seq_id_camera = [
handle_seq_id(seq_id, pts.device) for seq_id in [seq_id_pts, seq_id_camera]
]
if self.masked_sampling and masks is None:
raise ValueError(
"Masks have to be provided for `self.masked_sampling==True`"
)
# project pts to all cameras and sample feats from the locations of
# the 2D projections
sampled_feats_all_cams, sampled_masks_all_cams = project_points_and_sample(
pts,
feats,
camera,
masks if self.masked_sampling else None,
sampling_mode=self.sampling_mode,
)
# generate the mask that invalidates features sampled from
# non-corresponding cameras
camera_pts_mask = (seq_id_camera[None] == seq_id_pts[:, None])[
..., None, None
].to(pts)
# mask the sampled features and masks
sampled_feats = {
k: f * camera_pts_mask for k, f in sampled_feats_all_cams.items()
}
sampled_masks = sampled_masks_all_cams * camera_pts_mask
return sampled_feats, sampled_masks
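# A minimal, hypothetical usage sketch of `ViewSampler` (camera parameters,
# feature sizes and sequence ids are placeholders; `expand_args_fields` is
# assumed to prepare the configurable class for direct construction):
def _example_view_sampler():
    from pytorch3d.implicitron.tools.config import expand_args_fields
    from pytorch3d.renderer.cameras import PerspectiveCameras

    expand_args_fields(ViewSampler)
    sampler = ViewSampler()
    n_cameras, pts_batch, n_pts = 2, 2, 5
    cameras = PerspectiveCameras(
        R=torch.eye(3)[None].expand(n_cameras, 3, 3), T=torch.zeros(n_cameras, 3)
    )
    feats = {"rgb": torch.rand(n_cameras, 3, 16, 16)}
    # points in front of the cameras (positive depth)
    pts = torch.rand(pts_batch, n_pts, 3) + torch.tensor([0.0, 0.0, 2.0])
    sampled_feats, sampled_masks = sampler(
        pts=pts,
        seq_id_pts=[0, 1],
        camera=cameras,
        seq_id_camera=[0, 1],
        feats=feats,
        masks=None,
    )
    # sampled_feats["rgb"].shape == (pts_batch, n_cameras, n_pts, 3)
    # sampled_masks.shape == (pts_batch, n_cameras, n_pts, 1)
    return sampled_feats, sampled_masks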
def project_points_and_sample(
pts: torch.Tensor,
feats: Dict[str, torch.Tensor],
camera: CamerasBase,
masks: Optional[torch.Tensor],
eps: float = 1e-2,
sampling_mode: str = "bilinear",
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
"""
Project each point cloud from a batch of point clouds to all input cameras
and sample features at the 2D projection locations.
Args:
pts: `(pts_batch, n_pts, 3)` tensor containing a batch of 3D point clouds.
feats: A dict `{feat_i: feat_T_i}` of features to sample,
where each `feat_T_i` is a tensor of shape
`(n_cameras, feat_i_dim, feat_i_H, feat_i_W)`
of `feat_i_dim`-dimensional features extracted from `n_cameras`
source views.
camera: A batch of `n_cameras` cameras corresponding to their feature
tensors `feat_T_i` from `feats`.
masks: A tensor of shape `(n_cameras, 1, mask_H, mask_W)` denoting
valid locations for sampling.
eps: A small constant controlling the minimum depth of projections
of `pts` to avoid divisions by zero in the projection operation.
sampling_mode: Sampling mode of the grid sampler.
Returns:
sampled_feats: Dict of sampled features `{feat_i: sampled_T_i}`.
Each `sampled_T_i` is of shape
`(pts_batch, n_cameras, n_pts, feat_i_dim)`.
sampled_masks: A tensor with the mask of the sampled features
of shape `(pts_batch, n_cameras, n_pts, 1)`.
If `masks` is `None`, the returned `sampled_masks` will be
filled with 1s.
"""
n_cameras = camera.R.shape[0]
pts_batch = pts.shape[0]
n_pts = pts.shape[1:-1]
camera_rep, pts_rep = cameras_points_cartesian_product(camera, pts)
# The eps here is super-important to avoid NaNs in backprop!
proj_rep = camera_rep.transform_points(
pts_rep.reshape(n_cameras * pts_batch, -1, 3), eps=eps
)[..., :2]
# [ pts1 in cam1, pts2 in cam1, pts3 in cam1,
# pts1 in cam2, pts2 in cam2, pts3 in cam2,
# pts1 in cam3, pts2 in cam3, pts3 in cam3 ]
# reshape for the grid sampler
sampling_grid_ndc = proj_rep.view(n_cameras, pts_batch, -1, 2)
# [ [pts1 in cam1, pts2 in cam1, pts3 in cam1],
# [pts1 in cam2, pts2 in cam2, pts3 in cam2],
# [pts1 in cam3, pts2 in cam3, pts3 in cam3] ]
# n_cameras x pts_batch x n_pts x 2
# sample both feats
feats_sampled = {
k: ndc_grid_sample(
f,
sampling_grid_ndc,
mode=sampling_mode,
align_corners=False,
)
.permute(2, 0, 3, 1)
.reshape(pts_batch, n_cameras, *n_pts, -1)
for k, f in feats.items()
} # {k: pts_batch x n_cameras x *n_pts x dim} for each feat type "k"
if masks is not None:
# sample masks
masks_sampled = (
ndc_grid_sample(
masks,
sampling_grid_ndc,
mode=sampling_mode,
align_corners=False,
)
.permute(2, 0, 3, 1)
.reshape(pts_batch, n_cameras, *n_pts, 1)
)
else:
masks_sampled = sampling_grid_ndc.new_ones(pts_batch, n_cameras, *n_pts, 1)
return feats_sampled, masks_sampled
def handle_seq_id(
seq_id: Union[torch.LongTensor, List[str], List[int]],
device,
) -> torch.LongTensor:
"""
Converts the input sequence id to a LongTensor.
Args:
seq_id: A sequence of sequence ids.
device: The target device of the output.
Returns:
long_seq_id: `seq_id` converted to a `LongTensor` and moved to `device`.
"""
if not torch.is_tensor(seq_id):
if isinstance(seq_id[0], str):
seq_id = [hash(s) for s in seq_id]
seq_id = torch.tensor(seq_id, dtype=torch.long, device=device)
return seq_id.to(device)
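# Illustrative example: string ids are hashed to integers, integer ids pass
# through, and both are returned as long tensors on the requested device.
def _example_handle_seq_id():
    from_strings = handle_seq_id(["seq_a", "seq_b"], "cpu")
    from_ints = handle_seq_id([3, 1, 2], "cpu")
    assert from_strings.dtype == from_ints.dtype == torch.long
    return from_strings, from_ints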
def cameras_points_cartesian_product(
camera: CamerasBase, pts: torch.Tensor
) -> Tuple[CamerasBase, torch.Tensor]:
"""
Generates all pairs of elements from `camera` and `pts` and returns
`camera_rep` and `pts_rep` such that:
```
camera_rep = [ pts_rep = [
camera[0] pts[0],
camera[0] pts[1],
camera[0] ...,
... pts[batch_pts-1],
camera[1] pts[0],
camera[1] pts[1],
camera[1] ...,
... pts[batch_pts-1],
... ...,
camera[n_cameras-1] pts[0],
camera[n_cameras-1] pts[1],
camera[n_cameras-1] ...,
... pts[batch_pts-1],
] ]
```
Args:
camera: A batch of `n_cameras` cameras.
pts: A batch of `batch_pts` points of shape `(batch_pts, ..., dim)`
Returns:
camera_rep: A batch of batch_pts*n_cameras cameras such that:
```
camera_rep = [
camera[0]
camera[0]
camera[0]
...
camera[1]
camera[1]
camera[1]
...
...
camera[n_cameras-1]
camera[n_cameras-1]
camera[n_cameras-1]
]
```
pts_rep: Repeated `pts` of shape `(batch_pts*n_cameras, ..., dim)`,
such that:
```
pts_rep = [
pts[0],
pts[1],
...,
pts[batch_pts-1],
pts[0],
pts[1],
...,
pts[batch_pts-1],
...,
pts[0],
pts[1],
...,
pts[batch_pts-1],
]
```
"""
n_cameras = camera.R.shape[0]
batch_pts = pts.shape[0]
pts_rep = pts.repeat(n_cameras, *[1 for _ in pts.shape[1:]])
idx_cams = (
torch.arange(n_cameras)[:, None]
.expand(
n_cameras,
batch_pts,
)
.reshape(batch_pts * n_cameras)
)
camera_rep = camera[idx_cams]
return camera_rep, pts_rep
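# Shape sketch for the cartesian product above (hypothetical sizes): 2 cameras
# and a batch of 3 point clouds give 6 camera/point-cloud pairs, with each
# camera repeated once per point-cloud batch element.
def _example_cameras_points_cartesian_product():
    from pytorch3d.renderer.cameras import PerspectiveCameras

    cameras = PerspectiveCameras(
        R=torch.eye(3)[None].expand(2, 3, 3), T=torch.zeros(2, 3)
    )
    pts = torch.rand(3, 5, 3)  # (batch_pts, n_pts, 3)
    camera_rep, pts_rep = cameras_points_cartesian_product(cameras, pts)
    assert camera_rep.R.shape == (6, 3, 3)
    assert pts_rep.shape == (6, 5, 3)
    return camera_rep, pts_rep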
# a copy-paste from https://github.com/vsitzmann/scene-representation-networks/blob/master/hyperlayers.py
# fmt: off
# flake8: noqa
'''Pytorch implementations of hyper-network modules.
isort:skip_file
'''
import functools
import torch
import torch.nn as nn
from . import pytorch_prototyping
def partialclass(cls, *args, **kwds):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwds)
return NewCls
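# Illustrative example of `partialclass`, mirroring how `HyperFC` below
# pre-binds the hypernetwork sizes (the numbers here are placeholders):
def _example_partialclass():
    PreconfLookupLayer = partialclass(LookupLayer, num_objects=4)
    # only the remaining constructor arguments need to be supplied
    return PreconfLookupLayer(in_ch=8, out_ch=16)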
class LookupLayer(nn.Module):
def __init__(self, in_ch, out_ch, num_objects):
super().__init__()
self.out_ch = out_ch
self.lookup_lin = LookupLinear(in_ch, out_ch, num_objects=num_objects)
self.norm_nl = nn.Sequential(
nn.LayerNorm([self.out_ch], elementwise_affine=False), nn.ReLU(inplace=True)
)
def forward(self, obj_idx):
net = nn.Sequential(self.lookup_lin(obj_idx), self.norm_nl)
return net
class LookupFC(nn.Module):
def __init__(
self,
hidden_ch,
num_hidden_layers,
num_objects,
in_ch,
out_ch,
outermost_linear=False,
):
super().__init__()
self.layers = nn.ModuleList()
self.layers.append(
LookupLayer(in_ch=in_ch, out_ch=hidden_ch, num_objects=num_objects)
)
for i in range(num_hidden_layers):
self.layers.append(
LookupLayer(in_ch=hidden_ch, out_ch=hidden_ch, num_objects=num_objects)
)
if outermost_linear:
self.layers.append(
LookupLinear(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)
)
else:
self.layers.append(
LookupLayer(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)
)
def forward(self, obj_idx):
net = []
for i in range(len(self.layers)):
net.append(self.layers[i](obj_idx))
return nn.Sequential(*net)
class LookupLinear(nn.Module):
def __init__(self, in_ch, out_ch, num_objects):
super().__init__()
self.in_ch = in_ch
self.out_ch = out_ch
self.hypo_params = nn.Embedding(num_objects, in_ch * out_ch + out_ch)
for i in range(num_objects):
nn.init.kaiming_normal_(
self.hypo_params.weight.data[i, : self.in_ch * self.out_ch].view(
self.out_ch, self.in_ch
),
a=0.0,
nonlinearity="relu",
mode="fan_in",
)
self.hypo_params.weight.data[i, self.in_ch * self.out_ch :].fill_(0.0)
def forward(self, obj_idx):
hypo_params = self.hypo_params(obj_idx)
# Indices explicit to catch errors in shape of output layer
weights = hypo_params[..., : self.in_ch * self.out_ch]
biases = hypo_params[
..., self.in_ch * self.out_ch : (self.in_ch * self.out_ch) + self.out_ch
]
biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch)
weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch)
return BatchLinear(weights=weights, biases=biases)
class HyperLayer(nn.Module):
"""A hypernetwork that predicts a single Dense Layer, including LayerNorm and a ReLU."""
def __init__(
self, in_ch, out_ch, hyper_in_ch, hyper_num_hidden_layers, hyper_hidden_ch
):
super().__init__()
self.hyper_linear = HyperLinear(
in_ch=in_ch,
out_ch=out_ch,
hyper_in_ch=hyper_in_ch,
hyper_num_hidden_layers=hyper_num_hidden_layers,
hyper_hidden_ch=hyper_hidden_ch,
)
self.norm_nl = nn.Sequential(
nn.LayerNorm([out_ch], elementwise_affine=False), nn.ReLU(inplace=True)
)
def forward(self, hyper_input):
"""
:param hyper_input: input to hypernetwork.
:return: nn.Module; predicted fully connected network.
"""
return nn.Sequential(self.hyper_linear(hyper_input), self.norm_nl)
class HyperFC(nn.Module):
"""Builds a hypernetwork that predicts a fully connected neural network."""
def __init__(
self,
hyper_in_ch,
hyper_num_hidden_layers,
hyper_hidden_ch,
hidden_ch,
num_hidden_layers,
in_ch,
out_ch,
outermost_linear=False,
):
super().__init__()
PreconfHyperLinear = partialclass(
HyperLinear,
hyper_in_ch=hyper_in_ch,
hyper_num_hidden_layers=hyper_num_hidden_layers,
hyper_hidden_ch=hyper_hidden_ch,
)
PreconfHyperLayer = partialclass(
HyperLayer,
hyper_in_ch=hyper_in_ch,
hyper_num_hidden_layers=hyper_num_hidden_layers,
hyper_hidden_ch=hyper_hidden_ch,
)
self.layers = nn.ModuleList()
self.layers.append(PreconfHyperLayer(in_ch=in_ch, out_ch=hidden_ch))
for i in range(num_hidden_layers):
self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=hidden_ch))
if outermost_linear:
self.layers.append(PreconfHyperLinear(in_ch=hidden_ch, out_ch=out_ch))
else:
self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=out_ch))
def forward(self, hyper_input):
"""
:param hyper_input: Input to hypernetwork.
:return: nn.Module; Predicted fully connected neural network.
"""
net = []
for i in range(len(self.layers)):
net.append(self.layers[i](hyper_input))
return nn.Sequential(*net)
class BatchLinear(nn.Module):
def __init__(self, weights, biases):
"""Implements a batch linear layer.
:param weights: Shape: (batch, out_ch, in_ch)
:param biases: Shape: (batch, 1, out_ch)
"""
super().__init__()
self.weights = weights
self.biases = biases
def __repr__(self):
return "BatchLinear(in_ch=%d, out_ch=%d)" % (
self.weights.shape[-1],
self.weights.shape[-2],
)
def forward(self, input):
output = input.matmul(
self.weights.permute(
*[i for i in range(len(self.weights.shape) - 2)], -1, -2
)
)
output += self.biases
return output
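# Shape sketch (hypothetical sizes): a batch of 4 predicted linear layers
# mapping 8 -> 16 features, applied to 4 point sets of 10 points each.
def _example_batch_linear():
    weights = torch.rand(4, 16, 8)  # (batch, out_ch, in_ch)
    biases = torch.zeros(4, 1, 16)  # (batch, 1, out_ch)
    layer = BatchLinear(weights=weights, biases=biases)
    out = layer(torch.rand(4, 10, 8))
    assert out.shape == (4, 10, 16)
    return out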
def last_hyper_layer_init(m) -> None:
if type(m) == nn.Linear:
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity="relu", mode="fan_in")
# pyre-fixme[41]: `data` cannot be reassigned. It is a read-only property.
m.weight.data *= 1e-1
class HyperLinear(nn.Module):
"""A hypernetwork that predicts a single linear layer (weights & biases)."""
def __init__(
self, in_ch, out_ch, hyper_in_ch, hyper_num_hidden_layers, hyper_hidden_ch
):
super().__init__()
self.in_ch = in_ch
self.out_ch = out_ch
self.hypo_params = pytorch_prototyping.FCBlock(
in_features=hyper_in_ch,
hidden_ch=hyper_hidden_ch,
num_hidden_layers=hyper_num_hidden_layers,
out_features=(in_ch * out_ch) + out_ch,
outermost_linear=True,
)
self.hypo_params[-1].apply(last_hyper_layer_init)
def forward(self, hyper_input):
hypo_params = self.hypo_params(hyper_input)
# Indices explicit to catch errors in shape of output layer
weights = hypo_params[..., : self.in_ch * self.out_ch]
biases = hypo_params[
..., self.in_ch * self.out_ch : (self.in_ch * self.out_ch) + self.out_ch
]
biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch)
weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch)
return BatchLinear(weights=weights, biases=biases)
# a copy-paste from https://raw.githubusercontent.com/vsitzmann/pytorch_prototyping/10f49b1e7df38a58fd78451eac91d7ac1a21df64/pytorch_prototyping.py
# fmt: off
# flake8: noqa
'''A number of custom pytorch modules with sane defaults that I find useful for model prototyping.
isort:skip_file
'''
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn import functional as F
class FCLayer(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_features, out_features),
nn.LayerNorm([out_features]),
nn.ReLU(inplace=True),
)
def forward(self, input):
return self.net(input)
# From https://gist.github.com/wassname/ecd2dac6fc8f9918149853d17e3abf02
class LayerNormConv2d(nn.Module):
def __init__(self, num_features, eps=1e-5, affine=True):
super().__init__()
self.num_features = num_features
self.affine = affine
self.eps = eps
if self.affine:
self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
self.beta = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
shape = [-1] + [1] * (x.dim() - 1)
mean = x.view(x.size(0), -1).mean(1).view(*shape)
std = x.view(x.size(0), -1).std(1).view(*shape)
y = (x - mean) / (std + self.eps)
if self.affine:
shape = [1, -1] + [1] * (x.dim() - 2)
y = self.gamma.view(*shape) * y + self.beta.view(*shape)
return y
class FCBlock(nn.Module):
def __init__(
self,
hidden_ch,
num_hidden_layers,
in_features,
out_features,
outermost_linear=False,
):
super().__init__()
self.net = []
self.net.append(FCLayer(in_features=in_features, out_features=hidden_ch))
for i in range(num_hidden_layers):
self.net.append(FCLayer(in_features=hidden_ch, out_features=hidden_ch))
if outermost_linear:
self.net.append(nn.Linear(in_features=hidden_ch, out_features=out_features))
else:
self.net.append(FCLayer(in_features=hidden_ch, out_features=out_features))
self.net = nn.Sequential(*self.net)
self.net.apply(self.init_weights)
def __getitem__(self, item):
return self.net[item]
def init_weights(self, m):
if type(m) == nn.Linear:
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity="relu", mode="fan_in")
def forward(self, input):
return self.net(input)
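# A minimal usage sketch of `FCBlock` (hypothetical sizes): a 3 -> 1 MLP with
# two hidden layers of width 64 and a linear output layer.
def _example_fc_block():
    net = FCBlock(
        hidden_ch=64,
        num_hidden_layers=2,
        in_features=3,
        out_features=1,
        outermost_linear=True,
    )
    out = net(torch.rand(16, 3))
    assert out.shape == (16, 1)
    return out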
class DownBlock3D(nn.Module):
"""A 3D convolutional downsampling block."""
def __init__(self, in_channels, out_channels, norm=nn.BatchNorm3d):
super().__init__()
self.net = [
nn.ReplicationPad3d(1),
nn.Conv3d(
in_channels,
out_channels,
kernel_size=4,
padding=0,
stride=2,
bias=False if norm is not None else True,
),
]
if norm is not None:
self.net += [norm(out_channels, affine=True)]
self.net += [nn.LeakyReLU(0.2, True)]
self.net = nn.Sequential(*self.net)
def forward(self, x):
return self.net(x)
class UpBlock3D(nn.Module):
"""A 3D convolutional upsampling block."""
def __init__(self, in_channels, out_channels, norm=nn.BatchNorm3d):
super().__init__()
self.net = [
nn.ConvTranspose3d(
in_channels,
out_channels,
kernel_size=4,
stride=2,
padding=1,
bias=False if norm is not None else True,
),
]
if norm is not None:
self.net += [norm(out_channels, affine=True)]
self.net += [nn.ReLU(True)]
self.net = nn.Sequential(*self.net)
def forward(self, x, skipped=None):
if skipped is not None:
input = torch.cat([skipped, x], dim=1)
else:
input = x
return self.net(input)
class Conv3dSame(torch.nn.Module):
"""3D convolution that pads to keep spatial dimensions equal.
Does not support strides. Only cubic kernels (scalar kernel_size) are supported.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
bias=True,
padding_layer=nn.ReplicationPad3d,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param kernel_size: Scalar. Spatial dimensions of kernel (only quadratic kernels supported).
:param bias: Whether or not to use bias.
:param padding_layer: Which padding to use. Default is replication padding.
"""
super().__init__()
ka = kernel_size // 2
kb = ka - 1 if kernel_size % 2 == 0 else ka
self.net = nn.Sequential(
padding_layer((ka, kb, ka, kb, ka, kb)),
nn.Conv3d(in_channels, out_channels, kernel_size, bias=bias, stride=1),
)
def forward(self, x):
return self.net(x)
class Conv2dSame(torch.nn.Module):
"""2D convolution that pads to keep spatial dimensions equal.
Does not support strides. Only square kernels (scalar kernel_size) are supported.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
bias=True,
padding_layer=nn.ReflectionPad2d,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param kernel_size: Scalar. Spatial dimensions of kernel (only quadratic kernels supported).
:param bias: Whether or not to use bias.
:param padding_layer: Which padding to use. Default is reflection padding.
"""
super().__init__()
ka = kernel_size // 2
kb = ka - 1 if kernel_size % 2 == 0 else ka
self.net = nn.Sequential(
padding_layer((ka, kb, ka, kb)),
nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias, stride=1),
)
self.weight = self.net[1].weight
self.bias = self.net[1].bias
def forward(self, x):
return self.net(x)
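# Illustrative check of the "same" padding above: the output spatial size
# matches the input for both odd and even kernel sizes.
def _example_conv2d_same():
    x = torch.rand(1, 3, 32, 32)
    assert Conv2dSame(3, 8, kernel_size=3)(x).shape == (1, 8, 32, 32)
    assert Conv2dSame(3, 8, kernel_size=4)(x).shape == (1, 8, 32, 32)
    return x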
class UpBlock(nn.Module):
"""A 2d-conv upsampling block with a variety of options for upsampling, and following best practices / with
reasonable defaults. (LeakyReLU, kernel size multiple of stride)
"""
def __init__(
self,
in_channels,
out_channels,
post_conv=True,
use_dropout=False,
dropout_prob=0.1,
norm=nn.BatchNorm2d,
upsampling_mode="transpose",
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param post_conv: Whether to have another convolutional layer after the upsampling layer.
:param use_dropout: bool. Whether to use dropout or not.
:param dropout_prob: Float. The dropout probability (if use_dropout is True)
:param norm: Which norm to use. If None, no norm is used. Default is BatchNorm with affine parameters.
:param upsampling_mode: Which upsampling mode:
transpose: Upsampling with stride-2, kernel size 4 transpose convolutions.
bilinear: Feature map is upsampled with bilinear upsampling, then a conv layer.
nearest: Feature map is upsampled with nearest neighbor upsampling, then a conv layer.
shuffle: Feature map is upsampled with pixel shuffling, then a conv layer.
"""
super().__init__()
net = list()
if upsampling_mode == "transpose":
net += [
nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=4,
stride=2,
padding=1,
bias=True if norm is None else False,
)
]
elif upsampling_mode == "bilinear":
net += [nn.UpsamplingBilinear2d(scale_factor=2)]
net += [
Conv2dSame(
in_channels,
out_channels,
kernel_size=3,
bias=True if norm is None else False,
)
]
elif upsampling_mode == "nearest":
net += [nn.UpsamplingNearest2d(scale_factor=2)]
net += [
Conv2dSame(
in_channels,
out_channels,
kernel_size=3,
bias=True if norm is None else False,
)
]
elif upsampling_mode == "shuffle":
net += [nn.PixelShuffle(upscale_factor=2)]
net += [
Conv2dSame(
in_channels // 4,
out_channels,
kernel_size=3,
bias=True if norm is None else False,
)
]
else:
raise ValueError("Unknown upsampling mode!")
if norm is not None:
net += [norm(out_channels, affine=True)]
net += [nn.ReLU(True)]
if use_dropout:
net += [nn.Dropout2d(dropout_prob, False)]
if post_conv:
net += [
Conv2dSame(
out_channels,
out_channels,
kernel_size=3,
bias=True if norm is None else False,
)
]
if norm is not None:
net += [norm(out_channels, affine=True)]
net += [nn.ReLU(True)]
if use_dropout:
net += [nn.Dropout2d(0.1, False)]
self.net = nn.Sequential(*net)
def forward(self, x, skipped=None):
if skipped is not None:
input = torch.cat([skipped, x], dim=1)
else:
input = x
return self.net(input)
class DownBlock(nn.Module):
"""A 2D-conv downsampling block following best practices / with reasonable defaults
(LeakyReLU, kernel size multiple of stride)
"""
def __init__(
self,
in_channels,
out_channels,
prep_conv=True,
middle_channels=None,
use_dropout=False,
dropout_prob=0.1,
norm=nn.BatchNorm2d,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param prep_conv: Whether to have another convolutional layer before the downsampling layer.
:param middle_channels: If prep_conv is true, this sets the number of channels between the prep and downsampling
convs.
:param use_dropout: bool. Whether to use dropout or not.
:param dropout_prob: Float. The dropout probability (if use_dropout is True)
:param norm: Which norm to use. If None, no norm is used. Default is BatchNorm with affine parameters.
"""
super().__init__()
if middle_channels is None:
middle_channels = in_channels
net = list()
if prep_conv:
net += [
nn.ReflectionPad2d(1),
nn.Conv2d(
in_channels,
middle_channels,
kernel_size=3,
padding=0,
stride=1,
bias=True if norm is None else False,
),
]
if norm is not None:
net += [norm(middle_channels, affine=True)]
net += [nn.LeakyReLU(0.2, True)]
if use_dropout:
net += [nn.Dropout2d(dropout_prob, False)]
net += [
nn.ReflectionPad2d(1),
nn.Conv2d(
middle_channels,
out_channels,
kernel_size=4,
padding=0,
stride=2,
bias=True if norm is None else False,
),
]
if norm is not None:
net += [norm(out_channels, affine=True)]
net += [nn.LeakyReLU(0.2, True)]
if use_dropout:
net += [nn.Dropout2d(dropout_prob, False)]
self.net = nn.Sequential(*net)
def forward(self, x):
return self.net(x)
class Unet3d(nn.Module):
"""A 3d-Unet implementation with sane defaults."""
def __init__(
self,
in_channels,
out_channels,
nf0,
num_down,
max_channels,
norm=nn.BatchNorm3d,
outermost_linear=False,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param nf0: Number of features at highest level of U-Net
:param num_down: Number of downsampling stages.
:param max_channels: Maximum number of channels (channels multiply by 2 with every downsampling stage)
:param norm: Which norm to use. If None, no norm is used. Default is BatchNorm with affine parameters.
:param outermost_linear: Whether the output layer should be a linear layer or a nonlinear one.
"""
super().__init__()
assert num_down > 0, "Need at least one downsampling layer in UNet3d."
# Define the in block
self.in_layer = [Conv3dSame(in_channels, nf0, kernel_size=3, bias=False)]
if norm is not None:
self.in_layer += [norm(nf0, affine=True)]
self.in_layer += [nn.LeakyReLU(0.2, True)]
self.in_layer = nn.Sequential(*self.in_layer)
# Define the center UNet block. The feature map has height and width 1 --> no batchnorm.
self.unet_block = UnetSkipConnectionBlock3d(
int(min(2 ** (num_down - 1) * nf0, max_channels)),
int(min(2 ** (num_down - 1) * nf0, max_channels)),
norm=None,
)
for i in list(range(0, num_down - 1))[::-1]:
self.unet_block = UnetSkipConnectionBlock3d(
int(min(2 ** i * nf0, max_channels)),
int(min(2 ** (i + 1) * nf0, max_channels)),
submodule=self.unet_block,
norm=norm,
)
# Define the out layer. Each unet block concatenates its inputs with its outputs - so the output layer
# automatically receives the output of the in_layer and the output of the last unet layer.
self.out_layer = [
Conv3dSame(2 * nf0, out_channels, kernel_size=3, bias=outermost_linear)
]
if not outermost_linear:
if norm is not None:
self.out_layer += [norm(out_channels, affine=True)]
self.out_layer += [nn.ReLU(True)]
self.out_layer = nn.Sequential(*self.out_layer)
def forward(self, x):
in_layer = self.in_layer(x)
unet = self.unet_block(in_layer)
out_layer = self.out_layer(unet)
return out_layer
class UnetSkipConnectionBlock3d(nn.Module):
"""Helper class for building a 3D unet."""
def __init__(self, outer_nc, inner_nc, norm=nn.BatchNorm3d, submodule=None):
super().__init__()
if submodule is None:
model = [
DownBlock3D(outer_nc, inner_nc, norm=norm),
UpBlock3D(inner_nc, outer_nc, norm=norm),
]
else:
model = [
DownBlock3D(outer_nc, inner_nc, norm=norm),
submodule,
UpBlock3D(2 * inner_nc, outer_nc, norm=norm),
]
self.model = nn.Sequential(*model)
def forward(self, x):
forward_passed = self.model(x)
return torch.cat([x, forward_passed], 1)
class UnetSkipConnectionBlock(nn.Module):
"""Helper class for building a 2D unet."""
def __init__(
self,
outer_nc,
inner_nc,
upsampling_mode,
norm=nn.BatchNorm2d,
submodule=None,
use_dropout=False,
dropout_prob=0.1,
):
super().__init__()
if submodule is None:
model = [
DownBlock(
outer_nc,
inner_nc,
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=norm,
),
UpBlock(
inner_nc,
outer_nc,
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=norm,
upsampling_mode=upsampling_mode,
),
]
else:
model = [
DownBlock(
outer_nc,
inner_nc,
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=norm,
),
submodule,
UpBlock(
2 * inner_nc,
outer_nc,
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=norm,
upsampling_mode=upsampling_mode,
),
]
self.model = nn.Sequential(*model)
def forward(self, x):
forward_passed = self.model(x)
return torch.cat([x, forward_passed], 1)
class Unet(nn.Module):
"""A 2d-Unet implementation with sane defaults."""
def __init__(
self,
in_channels,
out_channels,
nf0,
num_down,
max_channels,
use_dropout,
upsampling_mode="transpose",
dropout_prob=0.1,
norm=nn.BatchNorm2d,
outermost_linear=False,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param nf0: Number of features at highest level of U-Net
:param num_down: Number of downsampling stages.
:param max_channels: Maximum number of channels (channels multiply by 2 with every downsampling stage)
:param use_dropout: Whether to use dropout or not.
:param dropout_prob: Dropout probability if use_dropout=True.
:param upsampling_mode: Which type of upsampling should be used. See "UpBlock" for documentation.
:param norm: Which norm to use. If None, no norm is used. Default is BatchNorm with affine parameters.
:param outermost_linear: Whether the output layer should be a linear layer or a nonlinear one.
"""
super().__init__()
assert num_down > 0, "Need at least one downsampling layer in UNet."
# Define the in block
self.in_layer = [
Conv2dSame(
in_channels, nf0, kernel_size=3, bias=True if norm is None else False
)
]
if norm is not None:
self.in_layer += [norm(nf0, affine=True)]
self.in_layer += [nn.LeakyReLU(0.2, True)]
if use_dropout:
self.in_layer += [nn.Dropout2d(dropout_prob)]
self.in_layer = nn.Sequential(*self.in_layer)
# Define the center UNet block
self.unet_block = UnetSkipConnectionBlock(
min(2 ** (num_down - 1) * nf0, max_channels),
min(2 ** (num_down - 1) * nf0, max_channels),
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=None, # Innermost has no norm (spatial dimension 1)
upsampling_mode=upsampling_mode,
)
for i in list(range(0, num_down - 1))[::-1]:
self.unet_block = UnetSkipConnectionBlock(
min(2 ** i * nf0, max_channels),
min(2 ** (i + 1) * nf0, max_channels),
use_dropout=use_dropout,
dropout_prob=dropout_prob,
submodule=self.unet_block,
norm=norm,
upsampling_mode=upsampling_mode,
)
# Define the out layer. Each unet block concatenates its inputs with its outputs - so the output layer
# automatically receives the output of the in_layer and the output of the last unet layer.
self.out_layer = [
Conv2dSame(
2 * nf0,
out_channels,
kernel_size=3,
bias=outermost_linear or (norm is None),
)
]
if not outermost_linear:
if norm is not None:
self.out_layer += [norm(out_channels, affine=True)]
self.out_layer += [nn.ReLU(True)]
if use_dropout:
self.out_layer += [nn.Dropout2d(dropout_prob)]
self.out_layer = nn.Sequential(*self.out_layer)
self.out_layer_weight = self.out_layer[0].weight
def forward(self, x):
in_layer = self.in_layer(x)
unet = self.unet_block(in_layer)
out_layer = self.out_layer(unet)
return out_layer
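# A minimal usage sketch of `Unet` (hypothetical sizes): the input spatial
# resolution should be divisible by 2 ** num_down so the skip connections
# line up.
def _example_unet():
    net = Unet(
        in_channels=3,
        out_channels=3,
        nf0=16,
        num_down=3,
        max_channels=128,
        use_dropout=False,
        outermost_linear=True,
    )
    out = net(torch.rand(1, 3, 64, 64))
    assert out.shape == (1, 3, 64, 64)
    return out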
class Identity(nn.Module):
"""Helper module to allow Downsampling and Upsampling nets to default to identity if they receive an empty list."""
def __init__(self):
super().__init__()
def forward(self, input):
return input
class DownsamplingNet(nn.Module):
"""A subnetwork that downsamples a 2D feature map with strided convolutions."""
def __init__(
self,
per_layer_out_ch,
in_channels,
use_dropout,
dropout_prob=0.1,
last_layer_one=False,
norm=nn.BatchNorm2d,
):
"""
:param per_layer_out_ch: python list of integers. Defines the number of output channels per layer. Length of
list defines the number of downsampling steps (each step downsamples by a factor of 2).
:param in_channels: Number of input channels.
:param use_dropout: Whether or not to use dropout.
:param dropout_prob: Dropout probability.
:param last_layer_one: Whether the output of the last layer will have a spatial size of 1. In that case,
the last layer will not have batchnorm, else, it will.
:param norm: Which norm to use. Defaults to BatchNorm.
"""
super().__init__()
if not len(per_layer_out_ch):
self.downs = Identity()
else:
self.downs = list()
self.downs.append(
DownBlock(
in_channels,
per_layer_out_ch[0],
use_dropout=use_dropout,
dropout_prob=dropout_prob,
middle_channels=per_layer_out_ch[0],
norm=norm,
)
)
for i in range(0, len(per_layer_out_ch) - 1):
if last_layer_one and (i == len(per_layer_out_ch) - 2):
norm = None
self.downs.append(
DownBlock(
per_layer_out_ch[i],
per_layer_out_ch[i + 1],
dropout_prob=dropout_prob,
use_dropout=use_dropout,
norm=norm,
)
)
self.downs = nn.Sequential(*self.downs)
def forward(self, input):
return self.downs(input)
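# Shape sketch (hypothetical sizes): each entry of `per_layer_out_ch` halves
# the spatial resolution, so two entries downsample 64x64 -> 16x16.
def _example_downsampling_net():
    net = DownsamplingNet([32, 64], in_channels=3, use_dropout=False)
    out = net(torch.rand(1, 3, 64, 64))
    assert out.shape == (1, 64, 16, 16)
    return out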
class UpsamplingNet(nn.Module):
"""A subnetwork that upsamples a 2D feature map with a variety of upsampling options."""
def __init__(
self,
per_layer_out_ch,
in_channels,
upsampling_mode,
use_dropout,
dropout_prob=0.1,
first_layer_one=False,
norm=nn.BatchNorm2d,
):
"""
:param per_layer_out_ch: python list of integers. Defines the number of output channels per layer. Length of
list defines number of upsampling steps (each step upsamples by factor of 2.)
:param in_channels: Number of input channels.
:param upsampling_mode: Mode of upsampling. For documentation, see class "UpBlock"
:param use_dropout: Whether or not to use dropout.
:param dropout_prob: Dropout probability.
:param first_layer_one: Whether the input to the first layer will have a spatial size of 1. In that case,
the first layer will not have a norm, else, it will.
:param norm: Which norm to use. Defaults to BatchNorm.
"""
super().__init__()
if not len(per_layer_out_ch):
self.ups = Identity()
else:
self.ups = list()
self.ups.append(
UpBlock(
in_channels,
per_layer_out_ch[0],
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=None if first_layer_one else norm,
upsampling_mode=upsampling_mode,
)
)
for i in range(0, len(per_layer_out_ch) - 1):
self.ups.append(
UpBlock(
per_layer_out_ch[i],
per_layer_out_ch[i + 1],
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=norm,
upsampling_mode=upsampling_mode,
)
)
self.ups = nn.Sequential(*self.ups)
def forward(self, input):
return self.ups(input)