OpenDAS / pytorch3d — Commits — cdd2142d

Unverified commit cdd2142d, authored Mar 21, 2022 by Jeremy Reizenstein, committed by GitHub on Mar 21, 2022.

implicitron v0 (#1133)

Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
Parent: 0e377c68
Changes: 90. Showing 20 changed files with 5570 additions and 0 deletions (+5570, -0).
Changed files shown:

  pytorch3d/implicitron/models/implicit_function/idr_feature_field.py              +142  -0
  pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py          +542  -0
  pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py  +411  -0
  pytorch3d/implicitron/models/implicit_function/utils.py                           +90  -0
  pytorch3d/implicitron/models/metrics.py                                          +230  -0
  pytorch3d/implicitron/models/model_dbir.py                                       +139  -0
  pytorch3d/implicitron/models/renderer/base.py                                    +118  -0
  pytorch3d/implicitron/models/renderer/lstm_renderer.py                           +179  -0
  pytorch3d/implicitron/models/renderer/multipass_ea.py                            +171  -0
  pytorch3d/implicitron/models/renderer/ray_point_refiner.py                        +87  -0
  pytorch3d/implicitron/models/renderer/ray_sampler.py                             +190  -0
  pytorch3d/implicitron/models/renderer/ray_tracing.py                             +573  -0
  pytorch3d/implicitron/models/renderer/raymarcher.py                              +143  -0
  pytorch3d/implicitron/models/renderer/rgb_net.py                                 +101  -0
  pytorch3d/implicitron/models/renderer/sdf_renderer.py                            +253  -0
  pytorch3d/implicitron/models/resnet_feature_extractor.py                         +218  -0
  pytorch3d/implicitron/models/view_pooling/feature_aggregation.py                 +666  -0
  pytorch3d/implicitron/models/view_pooling/view_sampling.py                       +291  -0
  pytorch3d/implicitron/third_party/hyperlayers.py                                 +254  -0
  pytorch3d/implicitron/third_party/pytorch_prototyping.py                         +772  -0
pytorch3d/implicitron/models/implicit_function/idr_feature_field.py  (new file, 0 → 100644)

# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/lioryariv/idr/blob/main/code/model/
# implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv
import math
from typing import Sequence

import torch
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer.implicit import HarmonicEmbedding
from torch import nn

from .base import ImplicitFunctionBase


@registry.register
class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
    feature_vector_size: int = 3
    d_in: int = 3
    d_out: int = 1
    dims: Sequence[int] = (512, 512, 512, 512, 512, 512, 512, 512)
    geometric_init: bool = True
    bias: float = 1.0
    skip_in: Sequence[int] = ()
    weight_norm: bool = True
    n_harmonic_functions_xyz: int = 0
    pooled_feature_dim: int = 0
    encoding_dim: int = 0

    def __post_init__(self):
        super().__init__()

        dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]

        self.embed_fn = None
        if self.n_harmonic_functions_xyz > 0:
            self.embed_fn = HarmonicEmbedding(
                self.n_harmonic_functions_xyz, append_input=True
            )
            dims[0] = self.embed_fn.get_output_dim()
        if self.pooled_feature_dim > 0:
            dims[0] += self.pooled_feature_dim
        if self.encoding_dim > 0:
            dims[0] += self.encoding_dim

        self.num_layers = len(dims)

        out_dim = 0
        layers = []
        for layer_idx in range(self.num_layers - 1):
            if layer_idx + 1 in self.skip_in:
                out_dim = dims[layer_idx + 1] - dims[0]
            else:
                out_dim = dims[layer_idx + 1]

            lin = nn.Linear(dims[layer_idx], out_dim)

            if self.geometric_init:
                if layer_idx == self.num_layers - 2:
                    torch.nn.init.normal_(
                        lin.weight,
                        mean=math.pi ** 0.5 / dims[layer_idx] ** 0.5,
                        std=0.0001,
                    )
                    torch.nn.init.constant_(lin.bias, -self.bias)
                elif self.n_harmonic_functions_xyz > 0 and layer_idx == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, 2 ** 0.5 / out_dim ** 0.5)
                elif self.n_harmonic_functions_xyz > 0 and layer_idx in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, 2 ** 0.5 / out_dim ** 0.5)
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, 2 ** 0.5 / out_dim ** 0.5)

            if self.weight_norm:
                lin = nn.utils.weight_norm(lin)

            layers.append(lin)
        self.linear_layers = torch.nn.ModuleList(layers)
        self.out_dim = out_dim
        self.softplus = nn.Softplus(beta=100)

    # pyre-fixme[14]: `forward` overrides method defined in `ImplicitFunctionBase`
    #  inconsistently.
    def forward(
        self,
        # ray_bundle: RayBundle,
        rays_points_world: torch.Tensor,  # TODO: unify the APIs
        fun_viewpool=None,
        global_code=None,
    ):
        # this field only uses point locations
        # rays_points_world = ray_bundle_to_ray_points(ray_bundle)
        # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]

        if rays_points_world.numel() == 0 or (
            self.embed_fn is None and fun_viewpool is None and global_code is None
        ):
            return torch.tensor(
                [],
                device=rays_points_world.device,
                dtype=rays_points_world.dtype,
            ).view(0, self.out_dim)

        embedding = None
        if self.embed_fn is not None:
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
            embedding = self.embed_fn(rays_points_world)

        if fun_viewpool is not None:
            assert rays_points_world.ndim == 2
            pooled_feature = fun_viewpool(rays_points_world[None])
            # TODO: pooled features are 4D!
            embedding = torch.cat((embedding, pooled_feature), dim=-1)

        if global_code is not None:
            assert embedding.ndim == 2
            assert global_code.shape[0] == 1  # TODO: generalize to batches!
            # This will require changing raytracer code
            # embedding = embedding[None].expand(global_code.shape[0], *embedding.shape)
            embedding = torch.cat(
                (embedding, global_code[0, None, :].expand(*embedding.shape[:-1], -1)),
                dim=-1,
            )

        x = embedding
        for layer_idx in range(self.num_layers - 1):
            if layer_idx in self.skip_in:
                x = torch.cat([x, embedding], dim=-1) / 2 ** 0.5

            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
            x = self.linear_layers[layer_idx](x)

            if layer_idx < self.num_layers - 2:
                # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
                x = self.softplus(x)

        return x
        # TODO: unify the APIs
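A minimal usage sketch (editor's illustration, not part of the commit): IdrFeatureField is a Configurable-style dataclass, so the assumed workflow is to expand its fields through the implicitron config tools before constructing it directly; the exact invocation may differ.

# Editor's sketch, not part of the commit. Assumes the class above is importable and
# that expand_args_fields prepares it for direct construction (an assumption about
# the intended config workflow).
import torch
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(IdrFeatureField)
field_fn = IdrFeatureField(n_harmonic_functions_xyz=6)   # harmonic embedding of xyz
pts = torch.randn(1024, 3)                               # flat (N, 3) world points
out = field_fn(pts)                                      # (N, d_out + feature_vector_size) == (1024, 4)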
pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py  (new file, 0 → 100644)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import field
from typing import List, Optional

import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding

from .base import ImplicitFunctionBase
from .utils import create_embeddings_for_implicit_function


class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
    n_harmonic_functions_xyz: int = 10
    n_harmonic_functions_dir: int = 4
    n_hidden_neurons_dir: int = 128
    latent_dim: int = 0
    input_xyz: bool = True
    xyz_ray_dir_in_camera_coords: bool = False
    color_dim: int = 3
    """
    Args:
        n_harmonic_functions_xyz: The number of harmonic functions
            used to form the harmonic embedding of 3D point locations.
        n_harmonic_functions_dir: The number of harmonic functions
            used to form the harmonic embedding of the ray directions.
        n_hidden_neurons_xyz: The number of hidden units in the
            fully connected layers of the MLP that accepts the 3D point
            locations and outputs the occupancy field with the intermediate
            features.
        n_hidden_neurons_dir: The number of hidden units in the
            fully connected layers of the MLP that accepts the intermediate
            features and ray directions and outputs the radiance field
            (per-point colors).
        n_layers_xyz: The number of layers of the MLP that outputs the
            occupancy field.
        append_xyz: The list of indices of the skip layers of the occupancy MLP.
    """

    def __post_init__(self):
        super().__init__()
        # The harmonic embedding layer converts input 3D coordinates
        # to a representation that is more suitable for
        # processing with a deep neural network.
        self.harmonic_embedding_xyz = HarmonicEmbedding(
            self.n_harmonic_functions_xyz, append_input=True
        )
        self.harmonic_embedding_dir = HarmonicEmbedding(
            self.n_harmonic_functions_dir, append_input=True
        )
        if not self.input_xyz and self.latent_dim <= 0:
            raise ValueError("The latent dimension has to be > 0 if xyz is not input!")

        embedding_dim_dir = self.harmonic_embedding_dir.get_output_dim()

        self.xyz_encoder = self._construct_xyz_encoder(
            input_dim=self.get_xyz_embedding_dim()
        )

        self.intermediate_linear = torch.nn.Linear(
            self.n_hidden_neurons_xyz, self.n_hidden_neurons_xyz
        )
        _xavier_init(self.intermediate_linear)

        self.density_layer = torch.nn.Linear(self.n_hidden_neurons_xyz, 1)
        _xavier_init(self.density_layer)

        # Zero the bias of the density layer to avoid
        # a completely transparent initialization.
        self.density_layer.bias.data[:] = 0.0  # fixme: Sometimes this is not enough

        self.color_layer = torch.nn.Sequential(
            LinearWithRepeat(
                self.n_hidden_neurons_xyz + embedding_dim_dir, self.n_hidden_neurons_dir
            ),
            torch.nn.ReLU(True),
            torch.nn.Linear(self.n_hidden_neurons_dir, self.color_dim),
            torch.nn.Sigmoid(),
        )

    def get_xyz_embedding_dim(self):
        return (
            self.harmonic_embedding_xyz.get_output_dim() * int(self.input_xyz)
            + self.latent_dim
        )

    def _construct_xyz_encoder(self, input_dim: int):
        raise NotImplementedError()

    def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
        """
        This function takes per-point `features` predicted by `self.xyz_encoder`
        and evaluates the color model in order to attach to each
        point a 3D vector of its RGB color.
        """
        # Normalize the ray_directions to unit l2 norm.
        rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
        # Obtain the harmonic embedding of the normalized ray directions.
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        return self.color_layer((self.intermediate_linear(features), rays_embedding))

    @staticmethod
    def allows_multiple_passes() -> bool:
        """
        Returns True as this implicit function allows
        multiple passes. Overridden from ImplicitFunctionBase.
        """
        return True

    def forward(
        self,
        ray_bundle: RayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        **kwargs,
    ):
        """
        The forward function accepts the parametrizations of
        3D points sampled along projection rays. The forward
        pass is responsible for attaching a 3D vector
        and a 1D scalar representing the point's
        RGB color and opacity respectively.

        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
            fun_viewpool: an optional callback with the signature
                    fun_viewpool(points) -> pooled_features
                where points is a [N_TGT x N x 3] tensor of world coords,
                and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
                of the features pooled from the context images.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
                denoting the color of each ray point.
        """
        # We first convert the ray parametrizations to world
        # coordinates with `ray_bundle_to_ray_points`.
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)
        # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]

        embeds = create_embeddings_for_implicit_function(
            xyz_world=ray_bundle_to_ray_points(ray_bundle),
            # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
            #  for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
            xyz_embedding_function=self.harmonic_embedding_xyz
            if self.input_xyz
            else None,
            global_code=global_code,
            fun_viewpool=fun_viewpool,
            xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
            camera=camera,
        )

        # embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        features = self.xyz_encoder(embeds)
        # features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
        # NNs operate on the flattened rays; reshaping to the correct spatial size
        # TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
        features = features.reshape(*rays_points_world.shape[:-1], -1)

        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        raw_densities = self.density_layer(features)
        # raw_densities.shape = [minibatch x ... x 1] in [0-1]

        if self.xyz_ray_dir_in_camera_coords:
            if camera is None:
                raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")

            directions = ray_bundle.directions @ camera.R
        else:
            directions = ray_bundle.directions

        rays_colors = self._get_colors(features, directions)
        # rays_colors.shape = [minibatch x ... x 3] in [0-1]

        return raw_densities, rays_colors, {}


@registry.register
class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
    transformer_dim_down_factor: float = 1.0
    n_hidden_neurons_xyz: int = 256
    n_layers_xyz: int = 8
    append_xyz: List[int] = field(default_factory=lambda: [5])

    def _construct_xyz_encoder(self, input_dim: int):
        return MLPWithInputSkips(
            self.n_layers_xyz,
            input_dim,
            self.n_hidden_neurons_xyz,
            input_dim,
            self.n_hidden_neurons_xyz,
            input_skips=self.append_xyz,
        )


@registry.register
class NeRFormerImplicitFunction(NeuralRadianceFieldBase):
    transformer_dim_down_factor: float = 2.0
    n_hidden_neurons_xyz: int = 80
    n_layers_xyz: int = 2
    append_xyz: List[int] = field(default_factory=lambda: [1])

    def _construct_xyz_encoder(self, input_dim: int):
        return TransformerWithInputSkips(
            self.n_layers_xyz,
            input_dim,
            self.n_hidden_neurons_xyz,
            input_dim,
            self.n_hidden_neurons_xyz,
            input_skips=self.append_xyz,
            dim_down_factor=self.transformer_dim_down_factor,
        )

    @staticmethod
    def requires_pooling_without_aggregation() -> bool:
        """
        Returns True as this implicit function needs
        pooling without aggregation. Overridden from ImplicitFunctionBase.
        """
        return True


class MLPWithInputSkips(torch.nn.Module):
    """
    Implements the multi-layer perceptron architecture of the Neural Radiance Field.

    As such, `MLPWithInputSkips` is a multi layer perceptron consisting
    of a sequence of linear layers with ReLU activations.

    Additionally, for a set of predefined layers `input_skips`, the forward pass
    appends a skip tensor `z` to the output of the preceding layer.

    Note that this follows the architecture described in the Supplementary
    Material (Fig. 7) of [1].

    References:
        [1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
            and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
            NeRF: Representing Scenes as Neural Radiance Fields for View
            Synthesis, ECCV2020
    """

    def _make_affine_layer(self, input_dim, hidden_dim):
        l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
        l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
        _xavier_init(l1)
        _xavier_init(l2)
        return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)

    def _apply_affine_layer(self, layer, x, z):
        mu_log_std = layer(z)
        mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
        std = torch.nn.functional.softplus(log_std)
        return (x - mu) * std

    def __init__(
        self,
        n_layers: int = 8,
        input_dim: int = 39,
        output_dim: int = 256,
        skip_dim: int = 39,
        hidden_dim: int = 256,
        input_skips: List[int] = [5],
        skip_affine_trans: bool = False,
        no_last_relu=False,
    ):
        """
        Args:
            n_layers: The number of linear layers of the MLP.
            input_dim: The number of channels of the input tensor.
            output_dim: The number of channels of the output.
            skip_dim: The number of channels of the tensor `z` appended when
                evaluating the skip layers.
            hidden_dim: The number of hidden units of the MLP.
            input_skips: The list of layer indices at which we append the skip
                tensor `z`.
        """
        super().__init__()
        layers = []
        skip_affine_layers = []
        for layeri in range(n_layers):
            dimin = hidden_dim if layeri > 0 else input_dim
            dimout = hidden_dim if layeri + 1 < n_layers else output_dim

            if layeri > 0 and layeri in input_skips:
                if skip_affine_trans:
                    skip_affine_layers.append(
                        self._make_affine_layer(skip_dim, hidden_dim)
                    )
                else:
                    dimin = hidden_dim + skip_dim

            linear = torch.nn.Linear(dimin, dimout)
            _xavier_init(linear)
            layers.append(
                torch.nn.Sequential(linear, torch.nn.ReLU(True))
                if not no_last_relu or layeri + 1 < n_layers
                else linear
            )
        self.mlp = torch.nn.ModuleList(layers)
        if skip_affine_trans:
            self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
        self._input_skips = set(input_skips)
        self._skip_affine_trans = skip_affine_trans

    def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
        """
        Args:
            x: The input tensor of shape `(..., input_dim)`.
            z: The input skip tensor of shape `(..., skip_dim)` which is appended
                to layers whose indices are specified by `input_skips`.
        Returns:
            y: The output tensor of shape `(..., output_dim)`.
        """
        y = x
        if z is None:
            # if the skip tensor is None, we use `x` instead.
            z = x
        skipi = 0
        for li, layer in enumerate(self.mlp):
            if li in self._input_skips:
                if self._skip_affine_trans:
                    y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
                else:
                    y = torch.cat((y, z), dim=-1)
                skipi += 1
            y = layer(y)
        return y


class TransformerWithInputSkips(torch.nn.Module):
    def __init__(
        self,
        n_layers: int = 8,
        input_dim: int = 39,
        output_dim: int = 256,
        skip_dim: int = 39,
        hidden_dim: int = 64,
        input_skips: List[int] = [5],
        dim_down_factor: float = 1,
    ):
        """
        Args:
            n_layers: The number of linear layers of the MLP.
            input_dim: The number of channels of the input tensor.
            output_dim: The number of channels of the output.
            skip_dim: The number of channels of the tensor `z` appended when
                evaluating the skip layers.
            hidden_dim: The number of hidden units of the MLP.
            input_skips: The list of layer indices at which we append the skip
                tensor `z`.
        """
        super().__init__()

        self.first = torch.nn.Linear(input_dim, hidden_dim)
        _xavier_init(self.first)

        self.skip_linear = torch.nn.ModuleList()

        layers_pool, layers_ray = [], []
        dimout = 0
        for layeri in range(n_layers):
            dimin = int(round(hidden_dim / (dim_down_factor ** layeri)))
            dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
            print(f"Tr: {dimin} -> {dimout}")
            for _i, l in enumerate((layers_pool, layers_ray)):
                l.append(
                    TransformerEncoderLayer(
                        d_model=[dimin, dimout][_i],
                        nhead=4,
                        dim_feedforward=hidden_dim,
                        dropout=0.0,
                        d_model_out=dimout,
                    )
                )

            if layeri in input_skips:
                self.skip_linear.append(torch.nn.Linear(input_dim, dimin))

        self.last = torch.nn.Linear(dimout, output_dim)
        _xavier_init(self.last)

        self.layers_pool, self.layers_ray = (
            torch.nn.ModuleList(layers_pool),
            torch.nn.ModuleList(layers_ray),
        )
        self._input_skips = set(input_skips)

    def forward(
        self,
        x: torch.Tensor,
        z: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x: The input tensor of shape
                `(minibatch, n_pooled_feats, ..., n_ray_pts, input_dim)`.
            z: The input skip tensor of shape
                `(minibatch, n_pooled_feats, ..., n_ray_pts, skip_dim)`
                which is appended to layers whose indices are specified by `input_skips`.
        Returns:
            y: The output tensor of shape
                `(minibatch, 1, ..., n_ray_pts, input_dim)`.
        """
        if z is None:
            # if the skip tensor is None, we use `x` instead.
            z = x

        y = self.first(x)

        B, n_pool, n_rays, n_pts, dim = y.shape

        # y_p in n_pool, n_pts, B x n_rays x dim
        y_p = y.permute(1, 3, 0, 2, 4)

        skipi = 0
        dimh = dim
        for li, (layer_pool, layer_ray) in enumerate(
            zip(self.layers_pool, self.layers_ray)
        ):
            y_pool_attn = y_p.reshape(n_pool, n_pts * B * n_rays, dimh)
            if li in self._input_skips:
                z_skip = self.skip_linear[skipi](z)
                y_pool_attn = y_pool_attn + z_skip.permute(1, 3, 0, 2, 4).reshape(
                    n_pool, n_pts * B * n_rays, dimh
                )
                skipi += 1
            # n_pool x B*n_rays*n_pts x dim
            y_pool_attn, pool_attn = layer_pool(y_pool_attn, src_key_padding_mask=None)
            dimh = y_pool_attn.shape[-1]

            y_ray_attn = (
                y_pool_attn.view(n_pool, n_pts, B * n_rays, dimh)
                .permute(1, 0, 2, 3)
                .reshape(n_pts, n_pool * B * n_rays, dimh)
            )
            # n_pts x n_pool*B*n_rays x dim
            y_ray_attn, ray_attn = layer_ray(
                y_ray_attn,
                src_key_padding_mask=None,
            )

            y_p = y_ray_attn.view(n_pts, n_pool, B * n_rays, dimh).permute(1, 0, 2, 3)

        y = y_p.view(n_pool, n_pts, B, n_rays, dimh).permute(2, 0, 3, 1, 4)

        W = torch.softmax(y[..., :1], dim=1)
        y = (y * W).sum(dim=1)
        y = self.last(y)

        return y


class TransformerEncoderLayer(torch.nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, d_model_out=-1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = torch.nn.Linear(d_model, dim_feedforward)
        self.dropout = torch.nn.Dropout(dropout)
        d_model_out = d_model if d_model_out <= 0 else d_model_out
        self.linear2 = torch.nn.Linear(dim_feedforward, d_model_out)
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model_out)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        self.activation = torch.nn.functional.relu

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src2, attn = self.self_attn(
            src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        d_out = src2.shape[-1]
        src = src[..., :d_out] + self.dropout2(src2)[..., :d_out]
        src = self.norm2(src)
        return src, attn


def _xavier_init(linear) -> None:
    """
    Performs the Xavier weight initialization of the linear layer `linear`.
    """
    torch.nn.init.xavier_uniform_(linear.weight.data)
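The NeRF MLP with input skips defined above is a plain torch.nn.Module, so it can be exercised in isolation. A short sketch (editor's illustration, not part of the commit), assuming the definitions above are importable:

import torch

# 39 = harmonic embedding dim of 3D points with 6 harmonic functions plus the appended input
mlp = MLPWithInputSkips(
    n_layers=8, input_dim=39, output_dim=256, skip_dim=39, hidden_dim=256, input_skips=[5]
)
x = torch.randn(2, 1024, 39)   # flattened, embedded ray points
y = mlp(x)                     # the embedding is re-appended before layer 5
print(y.shape)                 # torch.Size([2, 1024, 256])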
pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py  (new file, 0 → 100644)

# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/vsitzmann/scene-representation-networks
# Copyright (c) 2019 Vincent Sitzmann
from typing import Any, Optional, Tuple, cast

import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding

from .base import ImplicitFunctionBase
from .utils import create_embeddings_for_implicit_function


def _kaiming_normal_init(module: torch.nn.Module) -> None:
    if isinstance(module, (torch.nn.Linear, LinearWithRepeat)):
        torch.nn.init.kaiming_normal_(
            module.weight, a=0.0, nonlinearity="relu", mode="fan_in"
        )


class SRNRaymarchFunction(Configurable, torch.nn.Module):
    n_harmonic_functions: int = 3  # 0 means raw 3D coord inputs
    n_hidden_units: int = 256
    n_layers: int = 2
    in_features: int = 3
    out_features: int = 256
    latent_dim: int = 0
    xyz_in_camera_coords: bool = False

    # The internal network can be set as an output of an SRNHyperNet.
    # Note that, in order to avoid Pytorch's automatic registering of the
    # raymarch_function module on construction, we input the network wrapped
    # as a 1-tuple.

    # raymarch_function should ideally be typed as Optional[Tuple[Callable]]
    # but Omegaconf.structured doesn't like that. TODO: revisit after new
    # release of omegaconf including https://github.com/omry/omegaconf/pull/749 .
    raymarch_function: Any = None

    def __post_init__(self):
        super().__init__()
        self._harmonic_embedding = HarmonicEmbedding(
            self.n_harmonic_functions, append_input=True
        )
        input_embedding_dim = (
            HarmonicEmbedding.get_output_dim_static(
                self.in_features,
                self.n_harmonic_functions,
                True,
            )
            + self.latent_dim
        )

        if self.raymarch_function is not None:
            self._net = self.raymarch_function[0]
        else:
            self._net = pytorch_prototyping.FCBlock(
                hidden_ch=self.n_hidden_units,
                num_hidden_layers=self.n_layers,
                in_features=input_embedding_dim,
                out_features=self.out_features,
            )

    def forward(
        self,
        ray_bundle: RayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        **kwargs,
    ):
        """
        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
            fun_viewpool: an optional callback with the signature
                    fun_viewpool(points) -> pooled_features
                where points is a [N_TGT x N x 3] tensor of world coords,
                and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
                of the features pooled from the context images.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: Set to None.
        """
        # We first convert the ray parametrizations to world
        # coordinates with `ray_bundle_to_ray_points`.
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)

        embeds = create_embeddings_for_implicit_function(
            xyz_world=ray_bundle_to_ray_points(ray_bundle),
            # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
            #  for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`.
            xyz_embedding_function=self._harmonic_embedding,
            global_code=global_code,
            fun_viewpool=fun_viewpool,
            xyz_in_camera_coords=self.xyz_in_camera_coords,
            camera=camera,
        )

        # Before running the network, we have to resize embeds to ndims=3,
        # otherwise the SRN layers consume huge amounts of memory.
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        raymarch_features = self._net(
            embeds.view(embeds.shape[0], -1, embeds.shape[-1])
        )
        # raymarch_features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
        # NNs operate on the flattened rays; reshaping to the correct spatial size
        raymarch_features = raymarch_features.reshape(
            *rays_points_world.shape[:-1], -1
        )

        return raymarch_features, None


class SRNPixelGenerator(Configurable, torch.nn.Module):
    n_harmonic_functions: int = 4
    n_hidden_units: int = 256
    n_hidden_units_color: int = 128
    n_layers: int = 2
    in_features: int = 256
    out_features: int = 3
    ray_dir_in_camera_coords: bool = False

    def __post_init__(self):
        super().__init__()
        self._harmonic_embedding = HarmonicEmbedding(
            self.n_harmonic_functions, append_input=True
        )
        self._net = pytorch_prototyping.FCBlock(
            hidden_ch=self.n_hidden_units,
            num_hidden_layers=self.n_layers,
            in_features=self.in_features,
            out_features=self.n_hidden_units,
        )
        self._density_layer = torch.nn.Linear(self.n_hidden_units, 1)
        self._density_layer.apply(_kaiming_normal_init)
        embedding_dim_dir = self._harmonic_embedding.get_output_dim(input_dims=3)
        self._color_layer = torch.nn.Sequential(
            LinearWithRepeat(
                self.n_hidden_units + embedding_dim_dir,
                self.n_hidden_units_color,
            ),
            torch.nn.LayerNorm([self.n_hidden_units_color]),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(self.n_hidden_units_color, self.out_features),
        )
        self._color_layer.apply(_kaiming_normal_init)

    # TODO: merge with NeuralRadianceFieldBase's _get_colors
    def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
        """
        This function takes per-point `features` predicted by `self.net`
        and evaluates the color model in order to attach to each
        point a 3D vector of its RGB color.
        """
        # Normalize the ray_directions to unit l2 norm.
        rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
        # Obtain the harmonic embedding of the normalized ray directions.
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        rays_embedding = self._harmonic_embedding(rays_directions_normed)
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        return self._color_layer((features, rays_embedding))

    def forward(
        self,
        raymarch_features: torch.Tensor,
        ray_bundle: RayBundle,
        camera: Optional[CamerasBase] = None,
        **kwargs,
    ):
        """
        Args:
            raymarch_features: Features from the raymarching network of shape
                `(minibatch, ..., self.in_features)`
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
                denoting the color of each ray point.
        """
        # raymarch_features.shape = [minibatch x ... x pts_per_ray x 3]
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        features = self._net(raymarch_features)
        # features.shape = [minibatch x ... x self.n_hidden_units]

        if self.ray_dir_in_camera_coords:
            if camera is None:
                raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")

            directions = ray_bundle.directions @ camera.R
        else:
            directions = ray_bundle.directions

        # NNs operate on the flattened rays; reshaping to the correct spatial size
        features = features.reshape(*raymarch_features.shape[:-1], -1)

        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        raw_densities = self._density_layer(features)

        rays_colors = self._get_colors(features, directions)

        return raw_densities, rays_colors


class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
    """
    This is a raymarching function which has a forward like SRNRaymarchFunction
    but instead of the weights being parameters of the module, they
    are the output of another network, the hypernet, which takes the global_code
    as input. All the dataclass members of SRNRaymarchFunction are here with the
    same meaning. In addition, there are members with names ending `_hypernet`
    which affect the hypernet.

    Because this class may be called repeatedly for the same global_code, the
    output of the hypernet is cached in self.cached_srn_raymarch_function.
    This member must be manually set to None whenever the global_code changes.
    """

    n_harmonic_functions: int = 3  # 0 means raw 3D coord inputs
    n_hidden_units: int = 256
    n_layers: int = 2
    n_hidden_units_hypernet: int = 256
    n_layers_hypernet: int = 1
    in_features: int = 3
    out_features: int = 256
    latent_dim_hypernet: int = 0
    latent_dim: int = 0
    xyz_in_camera_coords: bool = False

    def __post_init__(self):
        super().__init__()
        raymarch_input_embedding_dim = (
            HarmonicEmbedding.get_output_dim_static(
                self.in_features,
                self.n_harmonic_functions,
                True,
            )
            + self.latent_dim
        )

        self._hypernet = hyperlayers.HyperFC(
            hyper_in_ch=self.latent_dim_hypernet,
            hyper_num_hidden_layers=self.n_layers_hypernet,
            hyper_hidden_ch=self.n_hidden_units_hypernet,
            hidden_ch=self.n_hidden_units,
            num_hidden_layers=self.n_layers,
            in_ch=raymarch_input_embedding_dim,
            out_ch=self.n_hidden_units,
        )

        self.cached_srn_raymarch_function: Optional[Tuple[SRNRaymarchFunction]] = None

    def _run_hypernet(self, global_code: torch.Tensor) -> Tuple[SRNRaymarchFunction]:
        """
        Runs the hypernet and returns a 1-tuple containing the generated
        srn_raymarch_function.
        """
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        net = self._hypernet(global_code)

        # use the hyper-net generated network to instantiate the raymarch module
        srn_raymarch_function = SRNRaymarchFunction(
            n_harmonic_functions=self.n_harmonic_functions,
            n_hidden_units=self.n_hidden_units,
            n_layers=self.n_layers,
            in_features=self.in_features,
            out_features=self.out_features,
            latent_dim=self.latent_dim,
            xyz_in_camera_coords=self.xyz_in_camera_coords,
            raymarch_function=(net,),
        )

        # move the generated raymarch function to the correct device
        srn_raymarch_function.to(global_code.device)

        return (srn_raymarch_function,)

    def forward(
        self,
        ray_bundle: RayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        **kwargs,
    ):
        if global_code is None:
            raise ValueError("SRN Hypernetwork requires a non-trivial global code.")

        # The raymarching network is cached in case the function is called repeatedly
        # across LSTM iterations for the same global_code.
        if self.cached_srn_raymarch_function is None:
            # generate the raymarching network from the hypernet
            # pyre-fixme[16]: `SRNRaymarchHyperNet` has no attribute
            self.cached_srn_raymarch_function = self._run_hypernet(global_code)
        (srn_raymarch_function,) = cast(
            Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function
        )

        return srn_raymarch_function(
            ray_bundle=ray_bundle,
            fun_viewpool=fun_viewpool,
            camera=camera,
            global_code=None,  # the hypernetwork takes the global code
        )


@registry.register
# pyre-fixme[13]: Uninitialized attribute
class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
    raymarch_function: SRNRaymarchFunction
    pixel_generator: SRNPixelGenerator

    def __post_init__(self):
        super().__init__()
        run_auto_creation(self)

    def forward(
        self,
        ray_bundle: RayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        raymarch_features: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        predict_colors = raymarch_features is not None
        if predict_colors:
            return self.pixel_generator(
                raymarch_features=raymarch_features,
                ray_bundle=ray_bundle,
                camera=camera,
                **kwargs,
            )
        else:
            return self.raymarch_function(
                ray_bundle=ray_bundle,
                fun_viewpool=fun_viewpool,
                camera=camera,
                global_code=global_code,
                **kwargs,
            )


@registry.register
# pyre-fixme[13]: Uninitialized attribute
class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
    """
    This implicit function uses a hypernetwork to generate the
    SRNRaymarchingFunction, and this is cached. Whenever the
    global_code changes, `on_bind_args` must be called to clear
    the cache.
    """

    hypernet: SRNRaymarchHyperNet
    pixel_generator: SRNPixelGenerator

    def __post_init__(self):
        super().__init__()
        run_auto_creation(self)

    def forward(
        self,
        ray_bundle: RayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        raymarch_features: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        predict_colors = raymarch_features is not None
        if predict_colors:
            return self.pixel_generator(
                raymarch_features=raymarch_features,
                ray_bundle=ray_bundle,
                camera=camera,
                **kwargs,
            )
        else:
            return self.hypernet(
                ray_bundle=ray_bundle,
                fun_viewpool=fun_viewpool,
                camera=camera,
                global_code=global_code,
                **kwargs,
            )

    def on_bind_args(self):
        """
        The global_code may have changed, so we reset the hypernet.
        """
        self.hypernet.cached_srn_raymarch_function = None
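A small sketch (editor's illustration, not part of the commit) of the weight-init helper above; applied recursively, it only re-initializes Linear and LinearWithRepeat submodules:

import torch

head = torch.nn.Sequential(
    torch.nn.Linear(256, 128),
    torch.nn.ReLU(inplace=True),
    torch.nn.Linear(128, 3),
)
head.apply(_kaiming_normal_init)  # fan-in Kaiming-normal init on the two Linear layers only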
pytorch3d/implicitron/models/implicit_function/utils.py  (new file, 0 → 100644)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Callable, Optional

import torch
from pytorch3d.renderer.cameras import CamerasBase


def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
    """
    Expands the `global_code` of shape (minibatch, dim)
    so that it can be appended to `embeds` of shape (minibatch, ..., dim2),
    and appends to the last dimension of `embeds`.
    """
    bs = embeds.shape[0]
    global_code_broadcast = global_code.view(bs, *([1] * (embeds.ndim - 2)), -1).expand(
        *embeds.shape[:-1],
        global_code.shape[-1],
    )
    return torch.cat([embeds, global_code_broadcast], dim=-1)


def create_embeddings_for_implicit_function(
    xyz_world: torch.Tensor,
    xyz_in_camera_coords: bool,
    global_code: Optional[torch.Tensor],
    camera: Optional[CamerasBase],
    fun_viewpool: Optional[Callable],
    xyz_embedding_function: Optional[Callable],
) -> torch.Tensor:

    bs, *spatial_size, pts_per_ray, _ = xyz_world.shape

    if xyz_in_camera_coords:
        if camera is None:
            raise ValueError("Camera must be given if xyz_in_camera_coords")

        ray_points_for_embed = (
            camera.get_world_to_view_transform()
            .transform_points(xyz_world.view(bs, -1, 3))
            .view(xyz_world.shape)
        )
    else:
        ray_points_for_embed = xyz_world

    if xyz_embedding_function is None:
        embeds = torch.empty(
            bs,
            1,
            math.prod(spatial_size),
            pts_per_ray,
            0,
            dtype=xyz_world.dtype,
            device=xyz_world.device,
        )
    else:
        embeds = xyz_embedding_function(ray_points_for_embed).reshape(
            bs,
            1,
            math.prod(spatial_size),
            pts_per_ray,
            -1,
        )  # flatten spatial, add n_src dim

    if fun_viewpool is not None:
        # viewpooling
        embeds_viewpooled = fun_viewpool(xyz_world.reshape(bs, -1, 3))
        embed_shape = (
            bs,
            embeds_viewpooled.shape[1],
            math.prod(spatial_size),
            pts_per_ray,
            -1,
        )
        embeds_viewpooled = embeds_viewpooled.reshape(*embed_shape)
        if embeds is not None:
            embeds = torch.cat([embeds.expand(*embed_shape), embeds_viewpooled], dim=-1)
        else:
            embeds = embeds_viewpooled

    if global_code is not None:
        # append the broadcasted global code to embeds
        embeds = broadcast_global_code(embeds, global_code)

    return embeds
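A quick shape check for broadcast_global_code (editor's illustration, not part of the commit), assuming the function above is importable:

import torch

embeds = torch.randn(2, 1, 64, 16, 39)        # (minibatch, n_src, n_rays, pts_per_ray, dim2)
global_code = torch.randn(2, 32)              # (minibatch, dim)
out = broadcast_global_code(embeds, global_code)
print(out.shape)                              # torch.Size([2, 1, 64, 16, 71])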
pytorch3d/implicitron/models/metrics.py  (new file, 0 → 100644)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Dict, Optional

import torch
from pytorch3d.implicitron.tools import metric_utils as utils
from pytorch3d.renderer import utils as rend_utils


class ViewMetrics(torch.nn.Module):
    def forward(
        self,
        image_sampling_grid: torch.Tensor,
        images: Optional[torch.Tensor] = None,
        images_pred: Optional[torch.Tensor] = None,
        depths: Optional[torch.Tensor] = None,
        depths_pred: Optional[torch.Tensor] = None,
        masks: Optional[torch.Tensor] = None,
        masks_pred: Optional[torch.Tensor] = None,
        masks_crop: Optional[torch.Tensor] = None,
        grad_theta: Optional[torch.Tensor] = None,
        density_grid: Optional[torch.Tensor] = None,
        keys_prefix: str = "loss_",
        mask_renders_by_pred: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Calculates various differentiable metrics useful for supervising
        differentiable rendering pipelines.

        Args:
            image_sampling_grid: A tensor of shape `(B, ..., 2)` containing 2D
                image locations at which the predictions are defined.
                All ground truth inputs are sampled at these
                locations in order to extract values that correspond
                to the predictions.
            images: A tensor of shape `(B, H, W, 3)` containing ground truth
                rgb values.
            images_pred: A tensor of shape `(B, ..., 3)` containing predicted
                rgb values.
            depths: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
                depth values.
            depths_pred: A tensor of shape `(B, ..., 1)` containing predicted
                depth values.
            masks: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
                foreground masks.
            masks_pred: A tensor of shape `(B, ..., 1)` containing predicted
                foreground masks.
            grad_theta: A tensor of shape `(B, ..., 3)` containing an evaluation
                of a gradient of a signed distance function w.r.t.
                input 3D coordinates used to compute the eikonal loss.
            density_grid: A tensor of shape `(B, Hg, Wg, Dg, 1)` containing a
                `Hg x Wg x Dg` voxel grid of density values.
            keys_prefix: A common prefix for all keys in the output dictionary
                containing all metrics.
            mask_renders_by_pred: If `True`, masks rendered images by the predicted
                `masks_pred` prior to computing all rgb metrics.

        Returns:
            metrics: A dictionary `{metric_name_i: metric_value_i}` keyed by the
                names of the output metrics `metric_name_i` with their corresponding
                values `metric_value_i` represented as 0-dimensional float tensors.

                The calculated metrics are:
                    rgb_huber: A robust huber loss between `image_pred` and `image`.
                    rgb_mse: Mean squared error between `image_pred` and `image`.
                    rgb_psnr: Peak signal-to-noise ratio between `image_pred` and `image`.
                    rgb_psnr_fg: Peak signal-to-noise ratio between the foreground
                        region of `image_pred` and `image` as defined by `mask`.
                    rgb_mse_fg: Mean squared error between the foreground
                        region of `image_pred` and `image` as defined by `mask`.
                    mask_neg_iou: (1 - intersection-over-union) between `mask_pred`
                        and `mask`.
                    mask_bce: Binary cross entropy between `mask_pred` and `mask`.
                    mask_beta_prior: A loss enforcing strictly binary values
                        of `mask_pred`: `log(mask_pred) + log(1-mask_pred)`
                    depth_abs: Mean per-pixel L1 distance between
                        `depth_pred` and `depth`.
                    depth_abs_fg: Mean per-pixel L1 distance between the foreground
                        region of `depth_pred` and `depth` as defined by `mask`.
                    eikonal: Eikonal regularizer `(||grad_theta|| - 1)**2`.
                    density_tv: The Total Variation regularizer of density
                        values in `density_grid` (sum of L1 distances of values
                        of all 4-neighbouring cells).
                    depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
                        predicted depth values.
        """

        # TODO: extract functions

        # reshape from B x ... x DIM to B x DIM x -1 x 1
        images_pred, masks_pred, depths_pred = [
            _reshape_nongrid_var(x) for x in [images_pred, masks_pred, depths_pred]
        ]
        # reshape the sampling grid as well
        # TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
        # now that we use rend_utils.ndc_grid_sample
        image_sampling_grid = image_sampling_grid.reshape(
            image_sampling_grid.shape[0], -1, 1, 2
        )

        # closure with the given image_sampling_grid
        def sample(tensor, mode):
            if tensor is None:
                return tensor
            return rend_utils.ndc_grid_sample(tensor, image_sampling_grid, mode=mode)

        # eval all results in this size
        images = sample(images, mode="bilinear")
        depths = sample(depths, mode="nearest")
        masks = sample(masks, mode="nearest")
        masks_crop = sample(masks_crop, mode="nearest")
        if masks_crop is None and images_pred is not None:
            masks_crop = torch.ones_like(images_pred[:, :1])
        if masks_crop is None and depths_pred is not None:
            masks_crop = torch.ones_like(depths_pred[:, :1])

        preds = {}
        if images is not None and images_pred is not None:
            # TODO: mask_renders_by_pred is always false; simplify
            preds.update(
                _rgb_metrics(
                    images,
                    images_pred,
                    masks,
                    masks_pred,
                    masks_crop,
                    mask_renders_by_pred,
                )
            )

        if masks_pred is not None:
            preds["mask_beta_prior"] = utils.beta_prior(masks_pred)
        if masks is not None and masks_pred is not None:
            preds["mask_neg_iou"] = utils.neg_iou_loss(masks_pred, masks, mask=masks_crop)
            preds["mask_bce"] = utils.calc_bce(masks_pred, masks, mask=masks_crop)

        if depths is not None and depths_pred is not None:
            assert masks_crop is not None
            _, abs_ = utils.eval_depth(
                depths_pred, depths, get_best_scale=True, mask=masks_crop, crop=0
            )
            preds["depth_abs"] = abs_.mean()

            if masks is not None:
                mask = masks * masks_crop
                _, abs_ = utils.eval_depth(
                    depths_pred, depths, get_best_scale=True, mask=mask, crop=0
                )
                preds["depth_abs_fg"] = abs_.mean()

        # regularizers
        if grad_theta is not None:
            preds["eikonal"] = _get_eikonal_loss(grad_theta)

        if density_grid is not None:
            preds["density_tv"] = _get_grid_tv_loss(density_grid)

        if depths_pred is not None:
            preds["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depths_pred)

        if keys_prefix is not None:
            preds = {(keys_prefix + k): v for k, v in preds.items()}

        return preds


def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop, mask_renders_by_pred):
    assert masks_crop is not None
    if mask_renders_by_pred:
        images = images[..., masks_pred.reshape(-1), :]
        masks_crop = masks_crop[..., masks_pred.reshape(-1), :]
        masks = masks is not None and masks[..., masks_pred.reshape(-1), :]
    rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
    rgb_loss = utils.huber(rgb_squared, scaling=0.03)
    crop_mass = masks_crop.sum().clamp(1.0)
    # print("IMAGE:", images.mean().item(), images_pred.mean().item()) # TEMP
    preds = {
        "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
        "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
        "rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
    }
    if masks is not None:
        masks = masks_crop * masks
        preds["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
        preds["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
    return preds


def _get_eikonal_loss(grad_theta):
    return ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()


def _get_grid_tv_loss(grid, log_domain: bool = True, eps: float = 1e-5):
    if log_domain:
        if (grid <= -eps).any():
            warnings.warn("Grid has negative values; this will produce NaN loss")
        grid = torch.log(grid + eps)

    # this is an isotropic version, note that it ignores last rows/cols
    return torch.mean(
        utils.safe_sqrt(
            (grid[..., :-1, :-1, 1:] - grid[..., :-1, :-1, :-1]) ** 2
            + (grid[..., :-1, 1:, :-1] - grid[..., :-1, :-1, :-1]) ** 2
            + (grid[..., 1:, :-1, :-1] - grid[..., :-1, :-1, :-1]) ** 2,
            eps=1e-5,
        )
    )


def _get_depth_neg_penalty_loss(depth):
    neg_penalty = depth.clamp(min=None, max=0.0) ** 2
    return torch.mean(neg_penalty)


def _reshape_nongrid_var(x):
    if x is None:
        return None

    ba, *_, dim = x.shape
    return x.reshape(ba, -1, 1, dim).permute(0, 3, 1, 2).contiguous()
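A shape sketch (editor's illustration, not part of the commit) for the reshape helper above, assuming it is importable:

import torch

x = torch.randn(4, 128, 3)                # B x ... x DIM predictions
print(_reshape_nongrid_var(x).shape)      # torch.Size([4, 3, 128, 1]), i.e. B x DIM x -1 x 1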
pytorch3d/implicitron/models/model_dbir.py  (new file, 0 → 100644)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List

import torch
from pytorch3d.implicitron.dataset.utils import is_known_frame
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
    NewViewSynthesisPrediction,
)
from pytorch3d.implicitron.tools.point_cloud_utils import (
    get_rgbd_point_cloud,
    render_point_cloud_pytorch3d,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds


class ModelDBIR(torch.nn.Module):
    """
    A simple depth-based image rendering model.
    """

    def __init__(
        self,
        image_size: int = 256,
        bg_color: float = 0.0,
        max_points: int = -1,
    ):
        """
        Initializes a simple DBIR model.

        Args:
            image_size: The size of the rendered rectangular images.
            bg_color: The color of the background.
            max_points: Maximum number of points in the point cloud
                formed by unprojecting all source view depths.
                If more points are present, they are randomly subsampled
                to `max_points` points without replacement.
        """
        super().__init__()
        self.image_size = image_size
        self.bg_color = bg_color
        self.max_points = max_points

    def forward(
        self,
        camera: CamerasBase,
        image_rgb: torch.Tensor,
        depth_map: torch.Tensor,
        fg_probability: torch.Tensor,
        frame_type: List[str],
        **kwargs,
    ) -> Dict[str, Any]:  # TODO: return a namedtuple or dataclass
        """
        Given a set of input source cameras images and depth maps, unprojects
        all RGBD maps to a colored point cloud and renders into the target views.

        Args:
            camera: A batch of `N` PyTorch3D cameras.
            image_rgb: A batch of `N` images of shape `(N, 3, H, W)`.
            depth_map: A batch of `N` depth maps of shape `(N, 1, H, W)`.
            fg_probability: A batch of `N` foreground probability maps
                of shape `(N, 1, H, W)`.
            frame_type: A list of `N` strings containing frame type indicators
                which specify target and source views.

        Returns:
            preds: A dict with the following fields:
                nvs_prediction: The rendered colors, depth and mask
                    of the target views.
                point_cloud: The point cloud of the scene. Its renders are
                    stored in `nvs_prediction`.
        """
        is_known = is_known_frame(frame_type)
        is_known_idx = torch.where(is_known)[0]

        mask_fg = (fg_probability > 0.5).type_as(image_rgb)

        point_cloud = get_rgbd_point_cloud(
            camera[is_known_idx],
            image_rgb[is_known_idx],
            depth_map[is_known_idx],
            mask_fg[is_known_idx],
        )

        pcl_size = int(point_cloud.num_points_per_cloud())
        if (self.max_points > 0) and (pcl_size > self.max_points):
            prm = torch.randperm(pcl_size)[: self.max_points]
            point_cloud = Pointclouds(
                point_cloud.points_padded()[:, prm, :],
                # pyre-fixme[16]: Optional type has no attribute `__getitem__`.
                features=point_cloud.features_padded()[:, prm, :],
            )

        is_target_idx = torch.where(~is_known)[0]

        depth_render, image_render, mask_render = [], [], []

        # render into target frames in a for loop to save memory
        for tgt_idx in is_target_idx:
            _image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
                camera[int(tgt_idx)],
                point_cloud,
                render_size=(self.image_size, self.image_size),
                point_radius=1e-2,
                topk=10,
                bg_color=self.bg_color,
            )
            _image_render = _image_render.clamp(0.0, 1.0)
            # the mask is the set of pixels with opacity bigger than eps
            _mask_render = (_mask_render > 1e-4).float()

            depth_render.append(_depth_render)
            image_render.append(_image_render)
            mask_render.append(_mask_render)

        nvs_prediction = NewViewSynthesisPrediction(
            **{
                k: torch.cat(v, dim=0)
                for k, v in zip(
                    ["depth_render", "image_render", "mask_render"],
                    [depth_render, image_render, mask_render],
                )
            }
        )

        preds = {
            "nvs_prediction": nvs_prediction,
            "point_cloud": point_cloud,
        }

        return preds
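A construction sketch (editor's illustration, not part of the commit); running the forward pass additionally needs a camera batch, RGB images, depth maps, foreground probabilities and frame-type labels, as documented above:

# Editor's sketch: default-ish hyperparameters, not taken from the commit.
dbir = ModelDBIR(image_size=128, bg_color=0.0, max_points=100_000)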
pytorch3d/implicitron/models/renderer/base.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional

import torch
from pytorch3d.implicitron.tools.config import ReplaceableBase


class EvaluationMode(Enum):
    TRAINING = "training"
    EVALUATION = "evaluation"


class RenderSamplingMode(Enum):
    MASK_SAMPLE = "mask_sample"
    FULL_GRID = "full_grid"


@dataclass
class RendererOutput:
    """
    A structure for storing the output of a renderer.

    Args:
        features: rendered features (usually RGB colors), (B, ..., C) tensor.
        depth: rendered ray-termination depth map, in NDC coordinates, (B, ..., 1) tensor.
        mask: rendered object mask, values in [0, 1], (B, ..., 1) tensor.
        prev_stage: for multi-pass renderers (e.g. in NeRF),
            a reference to the output of the previous stage.
        normals: surface normals, for renderers that estimate them; (B, ..., 3) tensor.
        points: ray-termination points in the world coordinates, (B, ..., 3) tensor.
        aux: dict for implementation-specific renderer outputs.
    """

    features: torch.Tensor
    depths: torch.Tensor
    masks: torch.Tensor
    prev_stage: Optional[RendererOutput] = None
    normals: Optional[torch.Tensor] = None
    points: Optional[torch.Tensor] = None  # TODO: redundant with depths
    aux: Dict[str, Any] = field(default_factory=lambda: {})


class ImplicitFunctionWrapper(torch.nn.Module):
    def __init__(self, fn: torch.nn.Module):
        super().__init__()
        self._fn = fn
        self.bound_args = {}

    def bind_args(self, **bound_args):
        self.bound_args = bound_args
        self._fn.on_bind_args()

    def unbind_args(self):
        self.bound_args = {}

    def forward(self, *args, **kwargs):
        return self._fn(*args, **{**kwargs, **self.bound_args})


class BaseRenderer(ABC, ReplaceableBase):
    """
    Base class for all Renderer implementations.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(
        self,
        ray_bundle,
        implicit_functions: List[ImplicitFunctionWrapper],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        **kwargs
    ) -> RendererOutput:
        """
        Each Renderer should implement its own forward function
        that returns an instance of RendererOutput.

        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape (minibatch, ..., 3) denoting
                    the origins of the rendering rays.
                directions: A tensor of shape (minibatch, ..., 3)
                    containing the direction vectors of rendering rays.
                lengths: A tensor of shape
                    (minibatch, ..., num_points_per_ray) containing the
                    lengths at which the ray points are sampled.
                    The coordinates of the points on the rays are thus computed
                    as `origins + lengths * directions`.
                xys: A tensor of shape
                    (minibatch, ..., 2) containing the
                    xy locations of each ray's pixel in the NDC screen space.
            implicit_functions: List of ImplicitFunctionWrappers which define the
                implicit function methods to be used. Most Renderers only allow
                a single implicit function. Currently, only the
                MultiPassEmissionAbsorptionRenderer allows specifying multiple
                values in the list.
            evaluation_mode: one of EvaluationMode.TRAINING or
                EvaluationMode.EVALUATION which determines the settings used for
                rendering.
            **kwargs: In addition to the named args, custom keyword args can be
                specified. For example in the SignedDistanceFunctionRenderer, an
                object_mask is required which needs to be passed via the kwargs.

        Returns:
            instance of RendererOutput
        """
        pass
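To make the `prev_stage` chaining concrete, here is a small sketch using the `RendererOutput` dataclass defined above; the tensors are dummy values and the shapes are illustrative only.

import torch

coarse = RendererOutput(
    features=torch.rand(2, 128, 3),
    depths=torch.rand(2, 128, 1),
    masks=torch.rand(2, 128, 1),
)
fine = RendererOutput(
    features=torch.rand(2, 128, 3),
    depths=torch.rand(2, 128, 1),
    masks=torch.rand(2, 128, 1),
    prev_stage=coarse,  # the coarse pass stays reachable for losses/metrics
)
assert fine.prev_stage is coarse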
pytorch3d/implicitron/models/renderer/lstm_renderer.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple

import torch
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer import RayBundle

from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput


@registry.register
class LSTMRenderer(BaseRenderer, torch.nn.Module):
    """
    Implements the learnable LSTM raymarching function from SRN [1].

    Settings:
        num_raymarch_steps: The number of LSTM raymarching steps.
        init_depth: Initializes the bias of the last raymarching LSTM layer so that
            the farthest point from the camera reaches a far z-plane that
            lies `init_depth` units from the camera plane.
        init_depth_noise_std: The standard deviation of the random normal noise
            added to the initial depth of each marched ray.
        hidden_size: The dimensionality of the LSTM's hidden state.
        n_feature_channels: The number of feature channels returned by the
            implicit_function evaluated at each raymarching step.
        verbose: If `True`, prints raymarching debug info.

    References:
        [1] Sitzmann, V. and Zollhöfer, M. and Wetzstein, G..
            "Scene representation networks: Continuous 3d-structure-aware
            neural scene representations." NeurIPS 2019.
    """

    num_raymarch_steps: int = 10
    init_depth: float = 17.0
    init_depth_noise_std: float = 5e-4
    hidden_size: int = 16
    n_feature_channels: int = 256
    verbose: bool = False

    def __post_init__(self):
        super().__init__()
        self._lstm = torch.nn.LSTMCell(
            input_size=self.n_feature_channels,
            hidden_size=self.hidden_size,
        )
        self._lstm.apply(_init_recurrent_weights)
        _lstm_forget_gate_init(self._lstm)
        self._out_layer = torch.nn.Linear(self.hidden_size, 1)
        one_step = self.init_depth / self.num_raymarch_steps
        self._out_layer.bias.data.fill_(one_step)
        self._out_layer.weight.data.normal_(mean=0.0, std=1e-3)

    def forward(
        self,
        ray_bundle: RayBundle,
        implicit_functions: List[ImplicitFunctionWrapper],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        **kwargs,
    ) -> RendererOutput:
        """
        Args:
            ray_bundle: A `RayBundle` object containing the parametrizations of the
                sampled rendering rays.
            implicit_functions: A single-element list of ImplicitFunctionWrappers which
                defines the implicit function to be used.
            evaluation_mode: one of EvaluationMode.TRAINING or
                EvaluationMode.EVALUATION which determines the settings used for
                rendering, specifically the RayPointRefiner and the density_noise_std.

        Returns:
            instance of RendererOutput
        """
        if len(implicit_functions) != 1:
            raise ValueError("LSTM renderer expects a single implicit function.")
        implicit_function = implicit_functions[0]

        if ray_bundle.lengths.shape[-1] != 1:
            raise ValueError(
                "LSTM renderer requires a ray-bundle with a single point per ray"
                + " which is the initial raymarching point."
            )

        # jitter the initial depths
        ray_bundle_t = ray_bundle._replace(
            lengths=ray_bundle.lengths
            + torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
        )

        states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
        signed_distance = torch.zeros_like(ray_bundle_t.lengths)
        raymarch_features = None
        for t in range(self.num_raymarch_steps + 1):
            # move signed_distance along each ray
            ray_bundle_t = ray_bundle_t._replace(
                lengths=ray_bundle_t.lengths + signed_distance
            )
            # eval the raymarching function
            raymarch_features, _ = implicit_function(
                ray_bundle_t,
                raymarch_features=None,
            )
            if self.verbose:
                # print some stats
                print(
                    f"{t}: mu={float(signed_distance.mean()):1.2e};"
                    + f" std={float(signed_distance.std()):1.2e};"
                    # pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
                    #  typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
                    #  param but got `Tensor`.
                    + f" mu_d={float(ray_bundle_t.lengths.mean()):1.2e};"
                    # pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
                    #  typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
                    #  param but got `Tensor`.
                    + f" std_d={float(ray_bundle_t.lengths.std()):1.2e};"
                )
            if t == self.num_raymarch_steps:
                break

            # run the lstm marcher
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
            state_h, state_c = self._lstm(
                raymarch_features.view(-1, raymarch_features.shape[-1]),
                states[-1],
            )
            if state_h.requires_grad:
                state_h.register_hook(lambda x: x.clamp(min=-10, max=10))
            # predict the next step size
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
            signed_distance = self._out_layer(state_h).view(ray_bundle_t.lengths.shape)
            # log the lstm states
            states.append((state_h, state_c))

        opacity_logits, features = implicit_function(
            raymarch_features=raymarch_features,
            ray_bundle=ray_bundle_t,
        )
        mask = torch.sigmoid(opacity_logits)
        depth = ray_bundle_t.lengths * ray_bundle_t.directions.norm(
            dim=-1, keepdim=True
        )

        return RendererOutput(
            features=features[..., 0, :],
            depths=depth,
            masks=mask[..., 0, :],
        )


def _init_recurrent_weights(self) -> None:
    # copied from SRN codebase
    for m in self.modules():
        if type(m) in [torch.nn.GRU, torch.nn.LSTM, torch.nn.RNN]:
            for name, param in m.named_parameters():
                if "weight_ih" in name:
                    torch.nn.init.kaiming_normal_(param.data)
                elif "weight_hh" in name:
                    torch.nn.init.orthogonal_(param.data)
                elif "bias" in name:
                    param.data.fill_(0)


def _lstm_forget_gate_init(lstm_layer) -> None:
    # copied from SRN codebase
    for name, parameter in lstm_layer.named_parameters():
        if "bias" not in name:
            continue
        n = parameter.size(0)
        start, end = n // 4, n // 2
        parameter.data[start:end].fill_(1.0)
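A note on why `_lstm_forget_gate_init` fills the slice `[n // 4 : n // 2]`: PyTorch's `LSTMCell` packs each bias vector as four consecutive blocks of size `hidden_size` in the gate order input, forget, cell, output, so that slice is exactly the forget-gate bias. A minimal check (not part of the committed file, values illustrative):

import torch

hidden_size = 16
cell = torch.nn.LSTMCell(input_size=256, hidden_size=hidden_size)
for name, parameter in cell.named_parameters():
    if "bias" in name:
        n = parameter.size(0)  # n == 4 * hidden_size
        parameter.data[n // 4 : n // 2].fill_(1.0)  # bias the forget gate open
        assert (parameter.data[hidden_size : 2 * hidden_size] == 1.0).all()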
pytorch3d/implicitron/models/renderer/multipass_ea.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from pytorch3d.implicitron.tools.config import registry

from .base import BaseRenderer, EvaluationMode, RendererOutput
from .ray_point_refiner import RayPointRefiner
from .raymarcher import GenericRaymarcher


@registry.register
class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
    """
    Implements the multi-pass rendering function, in particular,
    with emission-absorption ray marching used in NeRF [1]. First, it evaluates
    opacity-based ray-point weights and then optionally (in case more implicit
    functions are given) resamples points using importance sampling and evaluates
    new weights.

    During each ray marching pass, features, depth map, and masks
    are integrated: Let o_i be the opacity estimated by the implicit function,
    and d_i be the offset between points `i` and `i+1` along the respective ray.
    Ray marching is performed using the following equations:
    ```
    ray_opacity_n = cap_fn(sum_i=1^n cap_fn(d_i * o_i)),
    weight_n = weight_fn(cap_fn(d_i * o_i), 1 - ray_opacity_{n-1}),
    ```
    and the final rendered quantities are computed by a dot-product of ray values
    with the weights, e.g. `features = sum_n(weight_n * ray_features_n)`.
    See below for possible values of `cap_fn` and `weight_fn`.

    Settings:
        n_pts_per_ray_fine_training: The number of points sampled per ray for the
            fine rendering pass during training.
        n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the
            fine rendering pass during evaluation.
        stratified_sampling_coarse_training: Enable/disable stratified sampling during
            training.
        stratified_sampling_coarse_evaluation: Enable/disable stratified sampling during
            evaluation.
        append_coarse_samples_to_fine: Add the fine ray points to the coarse points
            after sampling.
        bg_color: The background color. A tuple of either 1 element or of D elements,
            where D matches the feature dimensionality; it is broadcasted when necessary.
        density_noise_std_train: Standard deviation of the noise added to the
            opacity field.
        capping_function: The capping function of the raymarcher.
            Options:
                - "exponential" (`cap_fn(x) = 1 - exp(-x)`)
                - "cap1" (`cap_fn(x) = min(x, 1)`)
            Set to "exponential" for the standard Emission Absorption raymarching.
        weight_function: The weighting function of the raymarcher.
            Options:
                - "product" (`weight_fn(w, x) = w * x`)
                - "minimum" (`weight_fn(w, x) = min(w, x)`)
            Set to "product" for the standard Emission Absorption raymarching.
        background_opacity: The raw opacity value (i.e. before exponentiation)
            of the background.
        blend_output: If `True`, alpha-blends the output renders with the
            background color using the rendered opacity mask.

    References:
        [1] Mildenhall, Ben, et al. "Nerf: Representing scenes as neural radiance
            fields for view synthesis." ECCV 2020.
    """

    n_pts_per_ray_fine_training: int = 64
    n_pts_per_ray_fine_evaluation: int = 64
    stratified_sampling_coarse_training: bool = True
    stratified_sampling_coarse_evaluation: bool = False
    append_coarse_samples_to_fine: bool = True
    bg_color: Tuple[float, ...] = (0.0,)
    density_noise_std_train: float = 0.0
    capping_function: str = "exponential"  # exponential | cap1
    weight_function: str = "product"  # product | minimum
    background_opacity: float = 1e10
    blend_output: bool = False

    def __post_init__(self):
        super().__init__()
        self._refiners = {
            EvaluationMode.TRAINING: RayPointRefiner(
                n_pts_per_ray=self.n_pts_per_ray_fine_training,
                random_sampling=self.stratified_sampling_coarse_training,
                add_input_samples=self.append_coarse_samples_to_fine,
            ),
            EvaluationMode.EVALUATION: RayPointRefiner(
                n_pts_per_ray=self.n_pts_per_ray_fine_evaluation,
                random_sampling=self.stratified_sampling_coarse_evaluation,
                add_input_samples=self.append_coarse_samples_to_fine,
            ),
        }
        self._raymarcher = GenericRaymarcher(
            1,
            self.bg_color,
            capping_function=self.capping_function,
            weight_function=self.weight_function,
            background_opacity=self.background_opacity,
            blend_output=self.blend_output,
        )

    def forward(
        self,
        ray_bundle,
        implicit_functions=[],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        **kwargs
    ) -> RendererOutput:
        """
        Args:
            ray_bundle: A `RayBundle` object containing the parametrizations of the
                sampled rendering rays.
            implicit_functions: List of ImplicitFunctionWrappers which
                define the implicit functions to be used sequentially in
                the raymarching step. The output of raymarching with
                implicit_functions[n-1] is refined, and then used as
                input for raymarching with implicit_functions[n].
            evaluation_mode: one of EvaluationMode.TRAINING or
                EvaluationMode.EVALUATION which determines the settings used for
                rendering.

        Returns:
            instance of RendererOutput
        """
        if not implicit_functions:
            raise ValueError("EA renderer expects implicit functions")

        return self._run_raymarcher(
            ray_bundle,
            implicit_functions,
            None,
            evaluation_mode,
        )

    def _run_raymarcher(
        self, ray_bundle, implicit_functions, prev_stage, evaluation_mode
    ):
        density_noise_std = (
            self.density_noise_std_train
            if evaluation_mode == EvaluationMode.TRAINING
            else 0.0
        )

        features, depth, mask, weights, aux = self._raymarcher(
            *implicit_functions[0](ray_bundle),
            ray_lengths=ray_bundle.lengths,
            density_noise_std=density_noise_std,
        )
        output = RendererOutput(
            features=features, depths=depth, masks=mask, aux=aux, prev_stage=prev_stage
        )

        # we may need to make a recursive call
        if len(implicit_functions) > 1:
            fine_ray_bundle = self._refiners[evaluation_mode](ray_bundle, weights)
            output = self._run_raymarcher(
                fine_ray_bundle,
                implicit_functions[1:],
                output,
                evaluation_mode,
            )

        return output
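For intuition about the weights that drive the coarse-to-fine refinement, here is a self-contained numerical sketch of the standard emission-absorption weighting (the default exponential capping with product weighting, as computed by GenericRaymarcher below); all shapes and values are illustrative.

import torch

o = torch.rand(4, 8)          # per-point raw opacities o_i for 4 rays
d = torch.full((4, 8), 0.1)   # offsets d_i between consecutive samples
capped = 1.0 - torch.exp(-d * o)                         # cap_fn(d_i * o_i)
ray_opacity = 1.0 - torch.exp(-torch.cumsum(d * o, dim=-1))
transmittance = torch.cat(                               # 1 - ray_opacity_{n-1}
    [torch.ones_like(ray_opacity[..., :1]), 1.0 - ray_opacity[..., :-1]], dim=-1
)
weights = capped * transmittance                         # weight_fn(..., ...)
features = (weights[..., None] * torch.rand(4, 8, 3)).sum(dim=-2)
assert (weights.sum(dim=-1) <= 1.0 + 1e-6).all()         # telescoping sum <= 1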
pytorch3d/implicitron/models/renderer/ray_point_refiner.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
from pytorch3d.renderer import RayBundle
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf


@expand_args_fields
# pyre-fixme[13]: Attribute `n_pts_per_ray` is never initialized.
# pyre-fixme[13]: Attribute `random_sampling` is never initialized.
class RayPointRefiner(Configurable, torch.nn.Module):
    """
    Implements the importance sampling of points along rays.
    The input is a `RayBundle` object with a `ray_weights` tensor
    which specifies the probabilities of sampling a point along each ray.

    This raysampler is used for the fine rendering pass of NeRF.
    As such, the forward pass accepts the RayBundle output by the
    raysampling of the coarse rendering pass. Hence, it does not
    take cameras as input.

    Args:
        n_pts_per_ray: The number of points to sample along each ray.
        random_sampling: If `False`, returns equispaced percentiles of the
            distribution defined by the input weights, otherwise performs
            sampling from that distribution.
        add_input_samples: Concatenates and returns the sampled values
            together with the input samples.
    """

    n_pts_per_ray: int
    random_sampling: bool
    add_input_samples: bool = True

    def __post_init__(self) -> None:
        super().__init__()

    def forward(
        self,
        input_ray_bundle: RayBundle,
        ray_weights: torch.Tensor,
        **kwargs,
    ) -> RayBundle:
        """
        Args:
            input_ray_bundle: An instance of `RayBundle` specifying the
                source rays for sampling of the probability distribution.
            ray_weights: A tensor of shape
                `(..., input_ray_bundle.lengths.shape[-1])` with non-negative
                elements defining the probability distribution to sample
                ray points from.

        Returns:
            ray_bundle: A new `RayBundle` instance containing the input ray
                points together with `n_pts_per_ray` additionally sampled
                points per ray. For each ray, the lengths are sorted.
        """
        z_vals = input_ray_bundle.lengths
        with torch.no_grad():
            z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
            z_samples = sample_pdf(
                z_vals_mid.view(-1, z_vals_mid.shape[-1]),
                ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
                self.n_pts_per_ray,
                det=not self.random_sampling,
            ).view(*z_vals.shape[:-1], self.n_pts_per_ray)

        if self.add_input_samples:
            # Add the new samples to the input ones.
            z_vals = torch.cat((z_vals, z_samples), dim=-1)
        else:
            z_vals = z_samples
        # Resort by depth.
        z_vals, _ = torch.sort(z_vals, dim=-1)

        return RayBundle(
            origins=input_ray_bundle.origins,
            directions=input_ray_bundle.directions,
            lengths=z_vals,
            xys=input_ray_bundle.xys,
        )
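A minimal usage sketch, assuming the class can be constructed directly from its fields as keyword arguments; the ray counts, depth range, and weights below are illustrative placeholders, not values from the committed file.

import torch
from pytorch3d.renderer import RayBundle

n_rays, n_coarse = 128, 64
coarse_bundle = RayBundle(
    origins=torch.rand(1, n_rays, 3),
    directions=torch.rand(1, n_rays, 3),
    lengths=torch.linspace(0.1, 8.0, n_coarse).repeat(1, n_rays, 1),
    xys=torch.rand(1, n_rays, 2),
)
coarse_weights = torch.rand(1, n_rays, n_coarse)  # e.g. from the coarse pass

refiner = RayPointRefiner(n_pts_per_ray=32, random_sampling=True)
fine_bundle = refiner(coarse_bundle, coarse_weights)
print(fine_bundle.lengths.shape)  # (1, 128, 96) with add_input_samples=True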
pytorch3d/implicitron/models/renderer/ray_sampler.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import field
from typing import Optional, Tuple

import torch
from pytorch3d.implicitron.tools import camera_utils
from pytorch3d.implicitron.tools.config import Configurable
from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle
from pytorch3d.renderer.cameras import CamerasBase

from .base import EvaluationMode, RenderSamplingMode


class RaySampler(Configurable, torch.nn.Module):
    """
    Samples a fixed number of points along rays which are in turn sampled for
    each camera in a batch.

    This class utilizes `NDCMultinomialRaysampler` which allows to either
    randomly sample rays from an input foreground saliency mask
    (`RenderSamplingMode.MASK_SAMPLE`), or on a rectangular image grid
    (`RenderSamplingMode.FULL_GRID`). The sampling mode can be set separately
    for training and evaluation by setting `self.sampling_mode_training`
    and `self.sampling_mode_evaluation` accordingly.

    The class allows two modes of sampling points along the rays:
        1) Sampling between fixed near and far z-planes:
            Active when `self.scene_extent <= 0`, samples points along each ray
            with approximately uniform spacing of z-coordinates between
            the minimum depth `self.min_depth` and the maximum depth `self.max_depth`.
            This sampling is useful for rendering scenes where the camera is
            at a constant distance from the focal point of the scene.
        2) Adaptive near/far plane estimation around the world scene center:
            Active when `self.scene_extent > 0`. Samples points on each
            ray between near and far planes whose depths are determined based on
            the distance from the camera center to a predefined scene center.
            More specifically,
            `min_depth = max(
                (self.scene_center-camera_center).norm() - self.scene_extent, eps
            )` and
            `max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
            This sampling is ideal for object-centric scenes whose contents are
            centered around a known `self.scene_center` and fit into a bounding sphere
            with a radius of `self.scene_extent`.

    Similar to the sampling mode, the sampling parameters can be set separately
    for training and evaluation.

    Settings:
        image_width: The horizontal size of the image grid.
        image_height: The vertical size of the image grid.
        scene_center: The xyz coordinates of the center of the scene used
            along with `scene_extent` to compute the min and max depth planes
            for sampling ray-points.
        scene_extent: The radius of the scene bounding sphere centered at `scene_center`.
            If `scene_extent <= 0`, the raysampler samples points between
            `self.min_depth` and `self.max_depth` depths instead.
        sampling_mode_training: The ray sampling mode for training. This should be a str
            option from the RenderSamplingMode Enum
        sampling_mode_evaluation: Same as above but for evaluation.
        n_pts_per_ray_training: The number of points sampled along each ray during training.
        n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
        n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid
        min_depth: The minimum depth of a ray-point. Active when `self.scene_extent > 0`.
        max_depth: The maximum depth of a ray-point. Active when `self.scene_extent > 0`.
        stratified_point_sampling_training: if set, performs stratified random sampling
            along the ray; otherwise takes ray points at deterministic offsets.
        stratified_point_sampling_evaluation: Same as above but for evaluation.
    """

    image_width: int = 400
    image_height: int = 400
    scene_center: Tuple[float, float, float] = field(
        default_factory=lambda: (0.0, 0.0, 0.0)
    )
    scene_extent: float = 0.0
    sampling_mode_training: str = "mask_sample"
    sampling_mode_evaluation: str = "full_grid"
    n_pts_per_ray_training: int = 64
    n_pts_per_ray_evaluation: int = 64
    n_rays_per_image_sampled_from_mask: int = 1024
    min_depth: float = 0.1
    max_depth: float = 8.0
    # stratified sampling vs taking points at deterministic offsets
    stratified_point_sampling_training: bool = True
    stratified_point_sampling_evaluation: bool = False

    def __post_init__(self):
        super().__init__()
        self.scene_center = torch.FloatTensor(self.scene_center)

        self._sampling_mode = {
            EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
            EvaluationMode.EVALUATION: RenderSamplingMode(
                self.sampling_mode_evaluation
            ),
        }

        self._raysamplers = {
            EvaluationMode.TRAINING: NDCMultinomialRaysampler(
                image_width=self.image_width,
                image_height=self.image_height,
                n_pts_per_ray=self.n_pts_per_ray_training,
                min_depth=self.min_depth,
                max_depth=self.max_depth,
                n_rays_per_image=self.n_rays_per_image_sampled_from_mask
                if self._sampling_mode[EvaluationMode.TRAINING]
                == RenderSamplingMode.MASK_SAMPLE
                else None,
                unit_directions=True,
                stratified_sampling=self.stratified_point_sampling_training,
            ),
            EvaluationMode.EVALUATION: NDCMultinomialRaysampler(
                image_width=self.image_width,
                image_height=self.image_height,
                n_pts_per_ray=self.n_pts_per_ray_evaluation,
                min_depth=self.min_depth,
                max_depth=self.max_depth,
                n_rays_per_image=self.n_rays_per_image_sampled_from_mask
                if self._sampling_mode[EvaluationMode.EVALUATION]
                == RenderSamplingMode.MASK_SAMPLE
                else None,
                unit_directions=True,
                stratified_sampling=self.stratified_point_sampling_evaluation,
            ),
        }

    def forward(
        self,
        cameras: CamerasBase,
        evaluation_mode: EvaluationMode,
        mask: Optional[torch.Tensor] = None,
    ) -> RayBundle:
        """
        Args:
            cameras: A batch of `batch_size` cameras from which the rays are emitted.
            evaluation_mode: one of `EvaluationMode.TRAINING` or
                `EvaluationMode.EVALUATION` which determines the sampling mode
                that is used.
            mask: Active for the `RenderSamplingMode.MASK_SAMPLE` sampling mode.
                Defines a non-negative mask of shape
                `(batch_size, image_height, image_width)` where each per-pixel
                value is proportional to the probability of sampling the
                corresponding pixel's ray.

        Returns:
            ray_bundle: A `RayBundle` object containing the parametrizations of the
                sampled rendering rays.
        """
        sample_mask = None
        if (
            # pyre-fixme[29]
            self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE
            and mask is not None
        ):
            sample_mask = torch.nn.functional.interpolate(
                mask,
                # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
                #  `List[int]`.
                size=[self.image_height, self.image_width],
                mode="nearest",
            )[:, 0]

        if self.scene_extent > 0.0:
            # Override the min/max depth set in initialization based on the
            # input cameras.
            min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
                cameras, self.scene_center, self.scene_extent
            )

        # pyre-fixme[29]:
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
        #  torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
        #  torch.Tensor, torch.nn.Module]` is not a function.
        ray_bundle = self._raysamplers[evaluation_mode](
            cameras=cameras,
            mask=sample_mask,
            min_depth=float(min_depth[0]) if self.scene_extent > 0.0 else None,
            max_depth=float(max_depth[0]) if self.scene_extent > 0.0 else None,
        )

        return ray_bundle
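The adaptive near/far rule quoted in the docstring above can be exercised on its own; this standalone sketch uses hand-picked values for the camera center, scene center, and scene extent purely for illustration.

import torch

scene_center = torch.tensor([0.0, 0.0, 0.0])
scene_extent = 1.5
camera_center = torch.tensor([0.0, 0.0, 4.0])
eps = 1e-3

dist = (scene_center - camera_center).norm()
min_depth = torch.clamp(dist - scene_extent, min=eps)  # near plane
max_depth = dist + scene_extent                        # far plane
print(float(min_depth), float(max_depth))              # 2.5 5.5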
pytorch3d/implicitron/models/renderer/ray_tracing.py
0 → 100644
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/lioryariv/idr
# Copyright (c) 2020 Lior Yariv

from typing import Any, Callable, Tuple

import torch
import torch.nn as nn
from pytorch3d.implicitron.tools.config import Configurable


class RayTracing(Configurable, nn.Module):
    """
    Finds the intersection points of rays with the implicit surface defined
    by a signed distance function (SDF). The algorithm follows the pipeline:
        1. Initialise start and end points on rays by the intersections with
            the circumscribing sphere.
        2. Run sphere tracing from both ends.
        3. Divide the untraced segments of non-convergent rays into uniform
            intervals and find the one with the sign transition.
        4. Run the secant method to estimate the point of the sign transition.

    Args:
        object_bounding_sphere: The radius of the initial sphere circumscribing
            the object.
        sdf_threshold: Absolute SDF value small enough for the sphere tracer
            to consider it a surface.
        line_search_step: Length of the backward correction on sphere tracing
            iterations.
        line_step_iters: Number of backward correction iterations.
        sphere_tracing_iters: Maximum number of sphere tracing iterations
            (the actual number of iterations may be smaller if all ray
            intersections are found).
        n_steps: Number of intervals sampled for non-convergent rays.
        n_secant_steps: Number of iterations in the secant algorithm.
    """

    object_bounding_sphere: float = 1.0
    sdf_threshold: float = 5.0e-5
    line_search_step: float = 0.5
    line_step_iters: int = 1
    sphere_tracing_iters: int = 10
    n_steps: int = 100
    n_secant_steps: int = 8

    def __post_init__(self):
        super().__init__()

    def forward(
        self,
        sdf: Callable[[torch.Tensor], torch.Tensor],
        cam_loc: torch.Tensor,
        object_mask: torch.BoolTensor,
        ray_directions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            sdf: A callable that takes a (N, 3) tensor of points and returns
                a tensor of (N,) SDF values.
            cam_loc: A tensor of (B, N, 3) ray origins.
            object_mask: A (N, 3) tensor of indicators whether a sampled pixel
                corresponds to the rendered object or background.
            ray_directions: A tensor of (B, N, 3) ray directions.

        Returns:
            curr_start_points: A tensor of (B*N, 3) found intersection points
                with the implicit surface.
            network_object_mask: A tensor of (B*N,) indicators denoting whether
                intersections were found.
            acc_start_dis: A tensor of (B*N,) distances from the ray origins
                to intersection points.
        """
        batch_size, num_pixels, _ = ray_directions.shape
        device = cam_loc.device

        sphere_intersections, mask_intersect = _get_sphere_intersection(
            cam_loc, ray_directions, r=self.object_bounding_sphere
        )

        (
            curr_start_points,
            unfinished_mask_start,
            acc_start_dis,
            acc_end_dis,
            min_dis,
            max_dis,
        ) = self.sphere_tracing(
            batch_size,
            num_pixels,
            sdf,
            cam_loc,
            ray_directions,
            mask_intersect,
            sphere_intersections,
        )

        network_object_mask = acc_start_dis < acc_end_dis

        # The non convergent rays should be handled by the sampler
        sampler_mask = unfinished_mask_start
        sampler_net_obj_mask = torch.zeros_like(
            sampler_mask, dtype=torch.bool, device=device
        )
        if sampler_mask.sum() > 0:
            sampler_min_max = torch.zeros((batch_size, num_pixels, 2), device=device)
            sampler_min_max.reshape(-1, 2)[sampler_mask, 0] = acc_start_dis[
                sampler_mask
            ]
            sampler_min_max.reshape(-1, 2)[sampler_mask, 1] = acc_end_dis[sampler_mask]

            sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(
                sdf, cam_loc, object_mask, ray_directions, sampler_min_max, sampler_mask
            )

            curr_start_points[sampler_mask] = sampler_pts[sampler_mask]
            acc_start_dis[sampler_mask] = sampler_dists[sampler_mask]
            network_object_mask[sampler_mask] = sampler_net_obj_mask[sampler_mask]

        if not self.training:
            return curr_start_points, network_object_mask, acc_start_dis

        # in case we are training, we are updating curr_start_points and acc_start_dis
        ray_directions = ray_directions.reshape(-1, 3)
        mask_intersect = mask_intersect.reshape(-1)
        object_mask = object_mask.reshape(-1)

        in_mask = ~network_object_mask & object_mask & ~sampler_mask
        out_mask = ~object_mask & ~sampler_mask

        # pyre-fixme[16]: `Tensor` has no attribute `__invert__`.
        mask_left_out = (in_mask | out_mask) & ~mask_intersect
        if mask_left_out.sum() > 0:
            # project the origin to the not intersect points on the sphere
            cam_left_out = cam_loc.reshape(-1, 3)[mask_left_out]
            rays_left_out = ray_directions[mask_left_out]
            acc_start_dis[mask_left_out] = -torch.bmm(
                rays_left_out.view(-1, 1, 3), cam_left_out.view(-1, 3, 1)
            ).squeeze()
            curr_start_points[mask_left_out] = (
                cam_left_out + acc_start_dis[mask_left_out].unsqueeze(1) * rays_left_out
            )

        mask = (in_mask | out_mask) & mask_intersect

        if mask.sum() > 0:
            min_dis[network_object_mask & out_mask] = acc_start_dis[
                network_object_mask & out_mask
            ]

            min_mask_points, min_mask_dist = self.minimal_sdf_points(
                sdf, cam_loc, ray_directions, mask, min_dis, max_dis
            )

            curr_start_points[mask] = min_mask_points
            acc_start_dis[mask] = min_mask_dist

        return curr_start_points, network_object_mask, acc_start_dis

    def sphere_tracing(
        self,
        batch_size: int,
        num_pixels: int,
        sdf: Callable[[torch.Tensor], torch.Tensor],
        cam_loc: torch.Tensor,
        ray_directions: torch.Tensor,
        mask_intersect: torch.Tensor,
        sphere_intersections: torch.Tensor,
    ) -> Tuple[Any, Any, Any, Any, Any, Any]:
        """
        Run sphere tracing algorithm for max iterations
        from both sides of unit sphere intersection

        Args:
            batch_size:
            num_pixels:
            sdf:
            cam_loc:
            ray_directions:
            mask_intersect:
            sphere_intersections:

        Returns:
            curr_start_points:
            unfinished_mask_start:
            acc_start_dis:
            acc_end_dis:
            min_dis:
            max_dis:
        """
        device = cam_loc.device
        sphere_intersections_points = (
            cam_loc[..., None, :]
            + sphere_intersections[..., None] * ray_directions[..., None, :]
        )
        unfinished_mask_start = mask_intersect.reshape(-1).clone()
        unfinished_mask_end = mask_intersect.reshape(-1).clone()

        # Initialize start current points
        curr_start_points = torch.zeros(batch_size * num_pixels, 3, device=device)
        curr_start_points[unfinished_mask_start] = sphere_intersections_points[
            :, :, 0, :
        ].reshape(-1, 3)[unfinished_mask_start]
        acc_start_dis = torch.zeros(batch_size * num_pixels, device=device)
        acc_start_dis[unfinished_mask_start] = sphere_intersections.reshape(-1, 2)[
            unfinished_mask_start, 0
        ]

        # Initialize end current points
        curr_end_points = torch.zeros(batch_size * num_pixels, 3, device=device)
        curr_end_points[unfinished_mask_end] = sphere_intersections_points[
            :, :, 1, :
        ].reshape(-1, 3)[unfinished_mask_end]
        acc_end_dis = torch.zeros(batch_size * num_pixels, device=device)
        acc_end_dis[unfinished_mask_end] = sphere_intersections.reshape(-1, 2)[
            unfinished_mask_end, 1
        ]

        # Initialise min and max depth
        min_dis = acc_start_dis.clone()
        max_dis = acc_end_dis.clone()

        # Iterate on the rays (from both sides) till finding a surface
        iters = 0

        # TODO: sdf should also pass info about batches
        next_sdf_start = torch.zeros_like(acc_start_dis)
        next_sdf_start[unfinished_mask_start] = sdf(
            curr_start_points[unfinished_mask_start]
        )

        next_sdf_end = torch.zeros_like(acc_end_dis)
        next_sdf_end[unfinished_mask_end] = sdf(curr_end_points[unfinished_mask_end])

        while True:
            # Update sdf
            curr_sdf_start = torch.zeros_like(acc_start_dis)
            curr_sdf_start[unfinished_mask_start] = next_sdf_start[
                unfinished_mask_start
            ]
            curr_sdf_start[curr_sdf_start <= self.sdf_threshold] = 0

            curr_sdf_end = torch.zeros_like(acc_end_dis)
            curr_sdf_end[unfinished_mask_end] = next_sdf_end[unfinished_mask_end]
            curr_sdf_end[curr_sdf_end <= self.sdf_threshold] = 0

            # Update masks
            unfinished_mask_start = unfinished_mask_start & (
                curr_sdf_start > self.sdf_threshold
            )
            unfinished_mask_end = unfinished_mask_end & (
                curr_sdf_end > self.sdf_threshold
            )

            if (
                unfinished_mask_start.sum() == 0 and unfinished_mask_end.sum() == 0
            ) or iters == self.sphere_tracing_iters:
                break
            iters += 1

            # Make step
            # Update distance
            acc_start_dis = acc_start_dis + curr_sdf_start
            acc_end_dis = acc_end_dis - curr_sdf_end

            # Update points
            curr_start_points = (
                cam_loc
                + acc_start_dis.reshape(batch_size, num_pixels, 1) * ray_directions
            ).reshape(-1, 3)
            curr_end_points = (
                cam_loc
                + acc_end_dis.reshape(batch_size, num_pixels, 1) * ray_directions
            ).reshape(-1, 3)

            # Fix points which wrongly crossed the surface
            next_sdf_start = torch.zeros_like(acc_start_dis)
            next_sdf_start[unfinished_mask_start] = sdf(
                curr_start_points[unfinished_mask_start]
            )

            next_sdf_end = torch.zeros_like(acc_end_dis)
            next_sdf_end[unfinished_mask_end] = sdf(
                curr_end_points[unfinished_mask_end]
            )

            not_projected_start = next_sdf_start < 0
            not_projected_end = next_sdf_end < 0
            not_proj_iters = 0
            while (
                not_projected_start.sum() > 0 or not_projected_end.sum() > 0
            ) and not_proj_iters < self.line_step_iters:
                # Step backwards
                acc_start_dis[not_projected_start] -= (
                    (1 - self.line_search_step) / (2 ** not_proj_iters)
                ) * curr_sdf_start[not_projected_start]
                curr_start_points[not_projected_start] = (
                    cam_loc
                    + acc_start_dis.reshape(batch_size, num_pixels, 1) * ray_directions
                ).reshape(-1, 3)[not_projected_start]

                acc_end_dis[not_projected_end] += (
                    (1 - self.line_search_step) / (2 ** not_proj_iters)
                ) * curr_sdf_end[not_projected_end]
                curr_end_points[not_projected_end] = (
                    cam_loc
                    + acc_end_dis.reshape(batch_size, num_pixels, 1) * ray_directions
                ).reshape(-1, 3)[not_projected_end]

                # Calc sdf
                next_sdf_start[not_projected_start] = sdf(
                    curr_start_points[not_projected_start]
                )
                next_sdf_end[not_projected_end] = sdf(
                    curr_end_points[not_projected_end]
                )

                # Update mask
                not_projected_start = next_sdf_start < 0
                not_projected_end = next_sdf_end < 0
                not_proj_iters += 1

            unfinished_mask_start = unfinished_mask_start & (
                acc_start_dis < acc_end_dis
            )
            unfinished_mask_end = unfinished_mask_end & (acc_start_dis < acc_end_dis)

        return (
            curr_start_points,
            unfinished_mask_start,
            acc_start_dis,
            acc_end_dis,
            min_dis,
            max_dis,
        )

    def ray_sampler(
        self,
        sdf: Callable[[torch.Tensor], torch.Tensor],
        cam_loc: torch.Tensor,
        object_mask: torch.Tensor,
        ray_directions: torch.Tensor,
        sampler_min_max: torch.Tensor,
        sampler_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample the ray in a given range and run secant on rays which have sign transition.

        Args:
            sdf:
            cam_loc:
            object_mask:
            ray_directions:
            sampler_min_max:
            sampler_mask:

        Returns:
        """
        batch_size, num_pixels, _ = ray_directions.shape
        device = cam_loc.device
        n_total_pxl = batch_size * num_pixels
        sampler_pts = torch.zeros(n_total_pxl, 3, device=device)
        sampler_dists = torch.zeros(n_total_pxl, device=device)

        intervals_dist = torch.linspace(0, 1, steps=self.n_steps, device=device).view(
            1, 1, -1
        )

        pts_intervals = sampler_min_max[:, :, 0].unsqueeze(-1) + intervals_dist * (
            sampler_min_max[:, :, 1] - sampler_min_max[:, :, 0]
        ).unsqueeze(-1)
        points = (
            cam_loc[..., None, :]
            + pts_intervals[..., None] * ray_directions[..., None, :]
        )

        # Get the non convergent rays
        mask_intersect_idx = torch.nonzero(sampler_mask).flatten()
        points = points.reshape((-1, self.n_steps, 3))[sampler_mask, :, :]
        pts_intervals = pts_intervals.reshape((-1, self.n_steps))[sampler_mask]

        sdf_val_all = []
        for pnts in torch.split(points.reshape(-1, 3), 100000, dim=0):
            sdf_val_all.append(sdf(pnts))
        sdf_val = torch.cat(sdf_val_all).reshape(-1, self.n_steps)

        tmp = torch.sign(sdf_val) * torch.arange(
            self.n_steps, 0, -1, device=device, dtype=torch.float32
        ).reshape(1, self.n_steps)
        # Force argmin to return the first min value
        sampler_pts_ind = torch.argmin(tmp, -1)
        sampler_pts[mask_intersect_idx] = points[
            torch.arange(points.shape[0]), sampler_pts_ind, :
        ]
        sampler_dists[mask_intersect_idx] = pts_intervals[
            torch.arange(pts_intervals.shape[0]), sampler_pts_ind
        ]

        true_surface_pts = object_mask.reshape(-1)[sampler_mask]
        net_surface_pts = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind] < 0

        # take points with minimal SDF value for P_out pixels
        p_out_mask = ~(true_surface_pts & net_surface_pts)
        n_p_out = p_out_mask.sum()
        if n_p_out > 0:
            out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1)
            sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][
                torch.arange(n_p_out), out_pts_idx, :
            ]
            sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[
                p_out_mask, :
            ][torch.arange(n_p_out), out_pts_idx]

        # Get Network object mask
        sampler_net_obj_mask = sampler_mask.clone()
        sampler_net_obj_mask[mask_intersect_idx[~net_surface_pts]] = False

        # Run Secant method
        secant_pts = (
            net_surface_pts & true_surface_pts if self.training else net_surface_pts
        )
        n_secant_pts = secant_pts.sum()
        if n_secant_pts > 0:
            # Get secant z predictions
            z_high = pts_intervals[
                torch.arange(pts_intervals.shape[0]), sampler_pts_ind
            ][secant_pts]
            sdf_high = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind][
                secant_pts
            ]
            z_low = pts_intervals[secant_pts][
                torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1
            ]
            sdf_low = sdf_val[secant_pts][
                torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1
            ]
            cam_loc_secant = cam_loc.reshape(-1, 3)[mask_intersect_idx[secant_pts]]
            ray_directions_secant = ray_directions.reshape((-1, 3))[
                mask_intersect_idx[secant_pts]
            ]
            z_pred_secant = self.secant(
                sdf_low,
                sdf_high,
                z_low,
                z_high,
                cam_loc_secant,
                ray_directions_secant,
                # pyre-fixme[6]: For 7th param expected `Module` but got `(Tensor)
                #  -> Tensor`.
                sdf,
            )

            # Get points
            sampler_pts[mask_intersect_idx[secant_pts]] = (
                cam_loc_secant + z_pred_secant.unsqueeze(-1) * ray_directions_secant
            )
            sampler_dists[mask_intersect_idx[secant_pts]] = z_pred_secant

        return sampler_pts, sampler_net_obj_mask, sampler_dists

    def secant(
        self,
        sdf_low: torch.Tensor,
        sdf_high: torch.Tensor,
        z_low: torch.Tensor,
        z_high: torch.Tensor,
        cam_loc: torch.Tensor,
        ray_directions: torch.Tensor,
        sdf: nn.Module,
    ) -> torch.Tensor:
        """
        Runs the secant method for interval [z_low, z_high] for n_secant_steps
        """
        z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
        for _ in range(self.n_secant_steps):
            p_mid = cam_loc + z_pred.unsqueeze(-1) * ray_directions
            sdf_mid = sdf(p_mid)
            ind_low = sdf_mid > 0
            if ind_low.sum() > 0:
                z_low[ind_low] = z_pred[ind_low]
                sdf_low[ind_low] = sdf_mid[ind_low]
            ind_high = sdf_mid < 0
            if ind_high.sum() > 0:
                z_high[ind_high] = z_pred[ind_high]
                sdf_high[ind_high] = sdf_mid[ind_high]

            z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low

        return z_pred

    def minimal_sdf_points(
        self,
        sdf: Callable[[torch.Tensor], torch.Tensor],
        cam_loc: torch.Tensor,
        ray_directions: torch.Tensor,
        mask: torch.Tensor,
        min_dis: torch.Tensor,
        max_dis: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find points with minimal SDF value on rays for P_out pixels
        """
        n_mask_points = mask.sum()

        n = self.n_steps
        steps = torch.empty(n, device=cam_loc.device).uniform_(0.0, 1.0)
        mask_max_dis = max_dis[mask].unsqueeze(-1)
        mask_min_dis = min_dis[mask].unsqueeze(-1)
        steps = (
            steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis)
            + mask_min_dis
        )

        mask_points = cam_loc.reshape(-1, 3)[mask]
        mask_rays = ray_directions[mask, :]

        mask_points_all = mask_points.unsqueeze(1).repeat(1, n, 1) + steps.unsqueeze(
            -1
        ) * mask_rays.unsqueeze(1).repeat(1, n, 1)
        points = mask_points_all.reshape(-1, 3)

        mask_sdf_all = []
        for pnts in torch.split(points, 100000, dim=0):
            mask_sdf_all.append(sdf(pnts))

        mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n)
        min_vals, min_idx = mask_sdf_all.min(-1)
        min_mask_points = mask_points_all.reshape(-1, n, 3)[
            torch.arange(0, n_mask_points), min_idx
        ]
        min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx]

        return min_mask_points, min_mask_dist


# TODO: support variable origins
def _get_sphere_intersection(
    cam_loc: torch.Tensor, ray_directions: torch.Tensor, r: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Input: n_images x 3 ; n_images x n_rays x 3
    # Output: n_images * n_rays x 2 (close and far) ; n_images * n_rays

    n_imgs, n_pix, _ = ray_directions.shape
    device = cam_loc.device

    # cam_loc = cam_loc.unsqueeze(-1)
    # ray_cam_dot = torch.bmm(ray_directions, cam_loc).squeeze()
    ray_cam_dot = (ray_directions * cam_loc).sum(-1)  # n_images x n_rays
    under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, dim=-1) ** 2 - r ** 2)

    under_sqrt = under_sqrt.reshape(-1)
    mask_intersect = under_sqrt > 0

    sphere_intersections = torch.zeros(n_imgs * n_pix, 2, device=device)
    sphere_intersections[mask_intersect] = torch.sqrt(
        under_sqrt[mask_intersect]
    ).unsqueeze(-1) * torch.tensor([-1.0, 1.0], device=device)
    sphere_intersections[mask_intersect] -= ray_cam_dot.reshape(-1)[
        mask_intersect
    ].unsqueeze(-1)

    sphere_intersections = sphere_intersections.reshape(n_imgs, n_pix, 2)
    sphere_intersections = sphere_intersections.clamp_min(0.0)
    mask_intersect = mask_intersect.reshape(n_imgs, n_pix)

    return sphere_intersections, mask_intersect
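The secant update used above is easy to verify on a single ray against an analytic SDF. This is a standalone sketch (not part of the committed file): a unit-sphere SDF, a camera at z = 3 looking towards the origin, and hand-picked bracketing depths with opposite SDF signs.

import torch

def sdf(p):
    # signed distance to a unit sphere centered at the origin
    return p.norm(dim=-1) - 1.0

cam_loc = torch.tensor([[0.0, 0.0, 3.0]])
ray_dir = torch.tensor([[0.0, 0.0, -1.0]])
z_low, z_high = torch.tensor([1.0]), torch.tensor([3.0])  # sdf > 0 / sdf < 0
sdf_low = sdf(cam_loc + z_low.unsqueeze(-1) * ray_dir)
sdf_high = sdf(cam_loc + z_high.unsqueeze(-1) * ray_dir)
z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
for _ in range(8):  # n_secant_steps
    sdf_mid = sdf(cam_loc + z_pred.unsqueeze(-1) * ray_dir)
    z_low = torch.where(sdf_mid > 0, z_pred, z_low)
    sdf_low = torch.where(sdf_mid > 0, sdf_mid, sdf_low)
    z_high = torch.where(sdf_mid < 0, z_pred, z_high)
    sdf_high = torch.where(sdf_mid < 0, sdf_mid, sdf_high)
    z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
print(float(z_pred))  # ~2.0, the distance from the camera to the sphere surface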
pytorch3d/implicitron/models/renderer/raymarcher.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, Tuple, Union

import torch
from pytorch3d.renderer.implicit.raymarching import _check_raymarcher_inputs

_TTensor = torch.Tensor


class GenericRaymarcher(torch.nn.Module):
    """
    This generalizes the `pytorch3d.renderer.EmissionAbsorptionRaymarcher`
    and NeuralVolumes' Accumulative ray marcher. It additionally returns
    the rendering weights that can be used in the NVS pipeline to carry out
    the importance ray-sampling in the refining pass.
    Different from `EmissionAbsorptionRaymarcher`, it takes raw
    (non-exponentiated) densities.

    Args:
        bg_color: background_color. Must be of shape (1,) or (feature_dim,)
    """

    def __init__(
        self,
        surface_thickness: int = 1,
        bg_color: Union[Tuple[float, ...], _TTensor] = (0.0,),
        capping_function: str = "exponential",  # exponential | cap1
        weight_function: str = "product",  # product | minimum
        background_opacity: float = 0.0,
        density_relu: bool = True,
        blend_output: bool = True,
    ):
        """
        Args:
            surface_thickness: Denotes the overlap between the absorption
                function and the density function.
        """
        super().__init__()
        self.surface_thickness = surface_thickness
        self.density_relu = density_relu
        self.background_opacity = background_opacity
        self.blend_output = blend_output
        if not isinstance(bg_color, torch.Tensor):
            bg_color = torch.tensor(bg_color)

        if bg_color.ndim != 1:
            raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")

        self.register_buffer("_bg_color", bg_color, persistent=False)

        self._capping_function: Callable[[_TTensor], _TTensor] = {
            "exponential": lambda x: 1.0 - torch.exp(-x),
            "cap1": lambda x: x.clamp(max=1.0),
        }[capping_function]

        self._weight_function: Callable[[_TTensor, _TTensor], _TTensor] = {
            "product": lambda curr, acc: curr * acc,
            "minimum": lambda curr, acc: torch.minimum(curr, acc),
        }[weight_function]

    def forward(
        self,
        rays_densities: torch.Tensor,
        rays_features: torch.Tensor,
        aux: Dict[str, Any],
        ray_lengths: torch.Tensor,
        density_noise_std: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """
        Args:
            rays_densities: Per-ray density values represented with a tensor
                of shape `(..., n_points_per_ray, 1)`.
            rays_features: Per-ray feature values represented with a tensor
                of shape `(..., n_points_per_ray, feature_dim)`.
            aux: a dictionary with extra information.
            ray_lengths: Per-ray depth values represented with a tensor
                of shape `(..., n_points_per_ray)`.
            density_noise_std: the magnitude of the noise added to densities.

        Returns:
            features: A tensor of shape `(..., feature_dim)` containing
                the rendered features for each ray.
            depth: A tensor of shape `(..., 1)` containing estimated depth.
            opacities: A tensor of shape `(..., 1)` containing rendered opacities.
            weights: A tensor of shape `(..., n_points_per_ray)` containing
                the ray-specific non-negative opacity weights. In general, they
                don't sum to 1 but never exceed it, i.e.
                `(weights.sum(dim=-1) <= 1.0).all()` holds.
        """
        _check_raymarcher_inputs(
            rays_densities,
            rays_features,
            ray_lengths,
            z_can_be_none=True,
            features_can_be_none=False,
            density_1d=True,
        )
        deltas = torch.cat(
            (
                ray_lengths[..., 1:] - ray_lengths[..., :-1],
                self.background_opacity * torch.ones_like(ray_lengths[..., :1]),
            ),
            dim=-1,
        )

        rays_densities = rays_densities[..., 0]

        if density_noise_std > 0.0:
            rays_densities = (
                rays_densities + torch.randn_like(rays_densities) * density_noise_std
            )
        if self.density_relu:
            rays_densities = torch.relu(rays_densities)

        weighted_densities = deltas * rays_densities
        capped_densities = self._capping_function(weighted_densities)

        rays_opacities = self._capping_function(
            torch.cumsum(weighted_densities, dim=-1)
        )
        opacities = rays_opacities[..., -1:]
        absorption_shifted = (-rays_opacities + 1.0).roll(
            self.surface_thickness, dims=-1
        )
        absorption_shifted[..., : self.surface_thickness] = 1.0

        weights = self._weight_function(capped_densities, absorption_shifted)
        features = (weights[..., None] * rays_features).sum(dim=-2)
        depth = (weights * ray_lengths)[..., None].sum(dim=-2)

        alpha = opacities if self.blend_output else 1
        if self._bg_color.shape[-1] not in [1, features.shape[-1]]:
            raise ValueError("Wrong number of background color channels.")
        features = alpha * features + (1 - opacities) * self._bg_color

        return features, depth, opacities, weights, aux
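A minimal direct-invocation sketch (not part of the committed file), assuming random densities and features with the shapes documented above; the last delta along each ray is filled with `background_opacity`, and the returned weights sum to at most one per ray.

import torch

raymarcher = GenericRaymarcher(bg_color=(0.0,), background_opacity=1e10)
rays_densities = torch.rand(2, 16, 8, 1)  # (..., n_points_per_ray, 1), raw densities
rays_features = torch.rand(2, 16, 8, 3)   # (..., n_points_per_ray, feature_dim)
ray_lengths = torch.linspace(0.1, 8.0, 8).expand(2, 16, 8)
features, depth, opacities, weights, aux = raymarcher(
    rays_densities, rays_features, {}, ray_lengths
)
assert (weights.sum(dim=-1) <= 1.0 + 1e-6).all()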
pytorch3d/implicitron/models/renderer/rgb_net.py
0 → 100644
# @lint-ignore-every LICENSELINT
# Adapted from RenderingNetwork from IDR
# https://github.com/lioryariv/idr/
# Copyright (c) 2020 Lior Yariv

import torch
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
from torch import nn


class RayNormalColoringNetwork(torch.nn.Module):
    def __init__(
        self,
        feature_vector_size=3,
        mode="idr",
        d_in=9,
        d_out=3,
        dims=(512, 512, 512, 512),
        weight_norm=True,
        n_harmonic_functions_dir=0,
        pooled_feature_dim=0,
    ):
        super().__init__()
        self.mode = mode
        self.output_dimensions = d_out
        dims = [d_in + feature_vector_size] + list(dims) + [d_out]

        self.embedview_fn = None
        if n_harmonic_functions_dir > 0:
            self.embedview_fn = HarmonicEmbedding(
                n_harmonic_functions_dir, append_input=True
            )
            dims[0] += self.embedview_fn.get_output_dim() - 3

        if pooled_feature_dim > 0:
            print("Pooled features in rendering network.")
            dims[0] += pooled_feature_dim

        self.num_layers = len(dims)
        layers = []
        for layer_idx in range(self.num_layers - 1):
            out_dim = dims[layer_idx + 1]
            lin = nn.Linear(dims[layer_idx], out_dim)
            if weight_norm:
                lin = nn.utils.weight_norm(lin)
            layers.append(lin)
        self.linear_layers = torch.nn.ModuleList(layers)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(
        self,
        feature_vectors: torch.Tensor,
        points,
        normals,
        ray_bundle: RayBundle,
        masks=None,
        pooling_fn=None,
    ):
        if masks is not None and not masks.any():
            return torch.zeros_like(normals)

        view_dirs = ray_bundle.directions
        if masks is not None:
            # in case of IDR, other outputs are passed here after applying the mask
            view_dirs = view_dirs.reshape(view_dirs.shape[0], -1, 3)[
                :, masks.reshape(-1)
            ]

        if self.embedview_fn is not None:
            view_dirs = self.embedview_fn(view_dirs)

        if self.mode == "idr":
            rendering_input = torch.cat(
                [points, view_dirs, normals, feature_vectors], dim=-1
            )
        elif self.mode == "no_view_dir":
            rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
        elif self.mode == "no_normal":
            rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1)
        else:
            raise ValueError(f"Unsupported rendering mode: {self.mode}")

        if pooling_fn is not None:
            featspool = pooling_fn(points[None])[0]
            rendering_input = torch.cat((rendering_input, featspool), dim=-1)

        x = rendering_input
        for layer_idx in range(self.num_layers - 1):
            x = self.linear_layers[layer_idx](x)
            if layer_idx < self.num_layers - 2:
                x = self.relu(x)
        x = self.tanh(x)
        return x
pytorch3d/implicitron/models/renderer/sdf_renderer.py
0 → 100644
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/lioryariv/idr/blob/main/code/model/
# implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv
import
functools
import
math
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
omegaconf
import
DictConfig
from
pytorch3d.implicitron.tools.config
import
get_default_args_field
,
registry
from
pytorch3d.implicitron.tools.utils
import
evaluating
from
pytorch3d.renderer
import
RayBundle
from
.base
import
BaseRenderer
,
EvaluationMode
,
ImplicitFunctionWrapper
,
RendererOutput
from
.ray_tracing
import
RayTracing
from
.rgb_net
import
RayNormalColoringNetwork
@
registry
.
register
class
SignedDistanceFunctionRenderer
(
BaseRenderer
,
torch
.
nn
.
Module
):
render_features_dimensions
:
int
=
3
ray_tracer_args
:
DictConfig
=
get_default_args_field
(
RayTracing
)
ray_normal_coloring_network_args
:
DictConfig
=
get_default_args_field
(
RayNormalColoringNetwork
)
bg_color
:
Tuple
[
float
,
...]
=
(
0.0
,)
soft_mask_alpha
:
float
=
50.0
def
__post_init__
(
self
,
):
super
().
__init__
()
render_features_dimensions
=
self
.
render_features_dimensions
if
len
(
self
.
bg_color
)
not
in
[
1
,
render_features_dimensions
]:
raise
ValueError
(
f
"Background color should have
{
render_features_dimensions
}
entries."
)
self
.
ray_tracer
=
RayTracing
(
**
self
.
ray_tracer_args
)
self
.
object_bounding_sphere
=
self
.
ray_tracer_args
.
get
(
"object_bounding_sphere"
)
self
.
ray_normal_coloring_network_args
[
"feature_vector_size"
]
=
render_features_dimensions
self
.
_rgb_network
=
RayNormalColoringNetwork
(
**
self
.
ray_normal_coloring_network_args
)
self
.
register_buffer
(
"_bg_color"
,
torch
.
tensor
(
self
.
bg_color
),
persistent
=
False
)
def
forward
(
self
,
ray_bundle
:
RayBundle
,
implicit_functions
:
List
[
ImplicitFunctionWrapper
],
evaluation_mode
:
EvaluationMode
=
EvaluationMode
.
EVALUATION
,
object_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
RendererOutput
:
"""
Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
implicit_functions: single element list of ImplicitFunctionWrappers which
defines the implicit function to be used.
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering.
kwargs:
object_mask: BoolTensor, denoting the silhouette of the object.
This is a required keyword argument for SignedDistanceFunctionRenderer
Returns:
instance of RendererOutput
"""
if
len
(
implicit_functions
)
!=
1
:
raise
ValueError
(
"SignedDistanceFunctionRenderer supports only single pass."
)
if
object_mask
is
None
:
raise
ValueError
(
"Expected object_mask to be provided in the kwargs"
)
object_mask
=
object_mask
.
bool
()
implicit_function
=
implicit_functions
[
0
]
implicit_function_gradient
=
functools
.
partial
(
gradient
,
implicit_function
)
# object_mask: silhouette of the object
batch_size
,
*
spatial_size
,
_
=
ray_bundle
.
lengths
.
shape
num_pixels
=
math
.
prod
(
spatial_size
)
cam_loc
=
ray_bundle
.
origins
.
reshape
(
batch_size
,
-
1
,
3
)
ray_dirs
=
ray_bundle
.
directions
.
reshape
(
batch_size
,
-
1
,
3
)
object_mask
=
object_mask
.
reshape
(
batch_size
,
-
1
)
with
torch
.
no_grad
(),
evaluating
(
implicit_function
):
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
points
,
network_object_mask
,
dists
=
self
.
ray_tracer
(
sdf
=
lambda
x
:
implicit_function
(
x
)[
:,
0
],
# TODO: get rid of this wrapper
cam_loc
=
cam_loc
,
object_mask
=
object_mask
,
                ray_directions=ray_dirs,
            )

        # TODO: below, cam_loc might as well be different
        depth = dists.reshape(batch_size, num_pixels, 1)
        points = (cam_loc + depth * ray_dirs).reshape(-1, 3)

        sdf_output = implicit_function(points)[:, 0:1]
        # NOTE most of the intermediate variables are flattened for
        # no apparent reason (here and in the ray tracer)
        ray_dirs = ray_dirs.reshape(-1, 3)
        object_mask = object_mask.reshape(-1)

        # TODO: move it to loss computation
        if evaluation_mode == EvaluationMode.TRAINING:
            surface_mask = network_object_mask & object_mask
            surface_points = points[surface_mask]
            surface_dists = dists[surface_mask].unsqueeze(-1)
            surface_ray_dirs = ray_dirs[surface_mask]
            surface_cam_loc = cam_loc.reshape(-1, 3)[surface_mask]
            surface_output = sdf_output[surface_mask]
            N = surface_points.shape[0]

            # Sample points for the eikonal loss
            # pyre-fixme[9]
            eik_bounding_box: float = self.object_bounding_sphere
            n_eik_points = batch_size * num_pixels // 2
            eikonal_points = torch.empty(
                n_eik_points, 3, device=self._bg_color.device
            ).uniform_(-eik_bounding_box, eik_bounding_box)
            eikonal_pixel_points = points.clone()
            eikonal_pixel_points = eikonal_pixel_points.detach()
            eikonal_points = torch.cat([eikonal_points, eikonal_pixel_points], 0)

            points_all = torch.cat([surface_points, eikonal_points], dim=0)

            output = implicit_function(surface_points)
            surface_sdf_values = output[
                :N, 0:1
            ].detach()  # how is it different from sdf_output?

            g = implicit_function_gradient(points_all)
            surface_points_grad = g[:N, 0, :].clone().detach()
            grad_theta = g[N:, 0, :]

            differentiable_surface_points = _sample_network(
                surface_output,
                surface_sdf_values,
                surface_points_grad,
                surface_dists,
                surface_cam_loc,
                surface_ray_dirs,
            )
        else:
            surface_mask = network_object_mask
            differentiable_surface_points = points[surface_mask]
            grad_theta = None

        empty_render = differentiable_surface_points.shape[0] == 0
        features = implicit_function(differentiable_surface_points)[None, :, 1:]
        normals_full = features.new_zeros(
            batch_size, *spatial_size, 3, requires_grad=empty_render
        )
        render_full = (
            features.new_ones(
                batch_size,
                *spatial_size,
                self.render_features_dimensions,
                requires_grad=empty_render,
            )
            * self._bg_color
        )
        mask_full = features.new_ones(
            batch_size, *spatial_size, 1, requires_grad=empty_render
        )
        if not empty_render:
            normals = implicit_function_gradient(differentiable_surface_points)[
                None, :, 0, :
            ]
            normals_full.view(-1, 3)[surface_mask] = normals
            render_full.view(-1, self.render_features_dimensions)[
                surface_mask
            ] = self._rgb_network(  # pyre-fixme[29]:
                features,
                differentiable_surface_points[None],
                normals,
                ray_bundle,
                surface_mask[None, :, None],
                pooling_fn=None,  # TODO
            )
            mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
                -self.soft_mask_alpha * sdf_output[~surface_mask]
            )

        # scatter points with surface_mask
        points_full = ray_bundle.origins.detach().clone()
        points_full.view(-1, 3)[surface_mask] = differentiable_surface_points

        # TODO: it is sparse here but otherwise dense
        return RendererOutput(
            features=render_full,
            normals=normals_full,
            depths=depth.reshape(batch_size, *spatial_size, 1),
            masks=mask_full,  # this is a differentiable approximation, see (7) in the paper
            points=points_full,
            aux={"grad_theta": grad_theta},  # TODO: will be moved to eikonal loss
            # TODO: do we need sdf_output, grad_theta? Only for loss probably
        )


def _sample_network(
    surface_output,
    surface_sdf_values,
    surface_points_grad,
    surface_dists,
    surface_cam_loc,
    surface_ray_dirs,
    eps=1e-4,
):
    # t -> t(theta)
    surface_ray_dirs_0 = surface_ray_dirs.detach()
    surface_points_dot = torch.bmm(
        surface_points_grad.view(-1, 1, 3),
        surface_ray_dirs_0.view(-1, 3, 1),
    ).squeeze(-1)
    dot_sign = (surface_points_dot >= 0).to(surface_points_dot) * 2 - 1
    surface_dists_theta = surface_dists - (surface_output - surface_sdf_values) / (
        surface_points_dot.abs().clip(eps) * dot_sign
    )

    # t(theta) -> x(theta,c,v)
    surface_points_theta_c_v = surface_cam_loc + surface_dists_theta * surface_ray_dirs
    return surface_points_theta_c_v


@torch.enable_grad()
def gradient(module, x):
    x.requires_grad_(True)
    y = module.forward(x)[:, :1]
    d_output = torch.ones_like(y, requires_grad=False, device=y.device)
    gradients = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=d_output,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    return gradients.unsqueeze(1)
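The `gradient` helper above returns d(SDF)/d(point) for every query point, which is what the eikonal loss and `_sample_network` consume. A minimal sketch of calling it in isolation, not part of the commit; the `SphereSDF` module is a made-up stand-in for a real implicit function:
```
import torch

class SphereSDF(torch.nn.Module):
    def forward(self, x):
        # signed distance to the unit sphere; shape (N, 1)
        return x.norm(dim=-1, keepdim=True) - 1.0

pts = torch.randn(8, 3)
g = gradient(SphereSDF(), pts)  # -> (8, 1, 3); here these are the sphere's unit normals
print(g.shape)
```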
pytorch3d/implicitron/models/resnet_feature_extractor.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import copy
import logging
import math
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as Fu
import torchvision
from pytorch3d.implicitron.tools.config import Configurable


logger = logging.getLogger(__name__)

MASK_FEATURE_NAME = "mask"
IMAGE_FEATURE_NAME = "image"

_FEAT_DIMS = {
    "resnet18": (64, 128, 256, 512),
    "resnet34": (64, 128, 256, 512),
    "resnet50": (256, 512, 1024, 2048),
    "resnet101": (256, 512, 1024, 2048),
    "resnet152": (256, 512, 1024, 2048),
}

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


class ResNetFeatureExtractor(Configurable, torch.nn.Module):
    """
    Implements an image feature extractor. Depending on the settings allows
    to extract:
        - deep features: A CNN ResNet backbone from torchvision (with/without
            pretrained weights) which extracts deep features.
        - masks: Segmentation masks.
        - images: Raw input RGB images.

    Settings:
        name: name of the resnet backbone (from torchvision)
        pretrained: If true, will load the pretrained weights
        stages: List of stages from which to extract features.
            Features from each stage are returned as key value
            pairs in the forward function
        normalize_image: If set will normalize the RGB values of
            the image based on the Resnet mean/std
        image_rescale: If not 1.0, this rescale factor will be
            used to resize the image
        first_max_pool: If set, a max pool layer is added after the first
            convolutional layer
        proj_dim: The number of output channels for the convolutional layers
        l2_norm: If set, l2 normalization is applied to the extracted features
        add_masks: If set, the masks will be saved in the output dictionary
        add_images: If set, the images will be saved in the output dictionary
        global_average_pool: If set, global average pooling step is performed
        feature_rescale: If not 1.0, this rescale factor will be used to
            rescale the output features
    """

    name: str = "resnet34"
    pretrained: bool = True
    stages: Tuple[int, ...] = (1, 2, 3, 4)
    normalize_image: bool = True
    image_rescale: float = 128 / 800.0
    first_max_pool: bool = True
    proj_dim: int = 32
    l2_norm: bool = True
    add_masks: bool = True
    add_images: bool = True
    global_average_pool: bool = False  # this can simulate global/non-spatial features
    feature_rescale: float = 1.0

    def __post_init__(self):
        super().__init__()
        if self.normalize_image:
            # register buffers needed to normalize the image
            for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
                self.register_buffer(
                    k,
                    torch.FloatTensor(v).view(1, 3, 1, 1),
                    persistent=False,
                )

        self._feat_dim = {}

        if len(self.stages) == 0:
            # do not extract any resnet features
            pass
        else:
            net = getattr(torchvision.models, self.name)(pretrained=self.pretrained)
            if self.first_max_pool:
                self.stem = torch.nn.Sequential(
                    net.conv1, net.bn1, net.relu, net.maxpool
                )
            else:
                self.stem = torch.nn.Sequential(net.conv1, net.bn1, net.relu)
            self.max_stage = max(self.stages)
            self.layers = torch.nn.ModuleList()
            self.proj_layers = torch.nn.ModuleList()
            for stage in range(self.max_stage):
                stage_name = f"layer{stage + 1}"
                feature_name = self._get_resnet_stage_feature_name(stage)
                if (stage + 1) in self.stages:
                    if (
                        self.proj_dim > 0
                        and _FEAT_DIMS[self.name][stage] > self.proj_dim
                    ):
                        proj = torch.nn.Conv2d(
                            _FEAT_DIMS[self.name][stage],
                            self.proj_dim,
                            1,
                            1,
                            bias=True,
                        )
                        self._feat_dim[feature_name] = self.proj_dim
                    else:
                        proj = torch.nn.Identity()
                        self._feat_dim[feature_name] = _FEAT_DIMS[self.name][stage]
                else:
                    proj = torch.nn.Identity()
                self.proj_layers.append(proj)
                self.layers.append(getattr(net, stage_name))

        if self.add_masks:
            self._feat_dim[MASK_FEATURE_NAME] = 1

        if self.add_images:
            self._feat_dim[IMAGE_FEATURE_NAME] = 3

        logger.info(f"Feat extractor total dim = {self.get_feat_dims()}")
        self.stages = set(self.stages)  # convert to set for faster "in"

    def _get_resnet_stage_feature_name(self, stage) -> str:
        return f"res_layer_{stage + 1}"

    def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
        return (img - self._resnet_mean) / self._resnet_std

    def get_feat_dims(self, size_dict: bool = False):
        if size_dict:
            return copy.deepcopy(self._feat_dim)
        # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
        #  `values`.
        return sum(self._feat_dim.values())

    def forward(
        self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
    ) -> Dict[Any, torch.Tensor]:
        """
        Args:
            imgs: A batch of input images of shape `(B, 3, H, W)`.
            masks: A batch of input masks of shape `(B, 3, H, W)`.

        Returns:
            out_feats: A dict `{f_i: t_i}` keyed by predicted feature names `f_i`
                and their corresponding tensors `t_i` of shape `(B, dim_i, H_i, W_i)`.
        """
        out_feats = {}

        imgs_input = imgs
        if self.image_rescale != 1.0:
            imgs_resized = Fu.interpolate(
                imgs_input,
                # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but
                #  got `float`.
                scale_factor=self.image_rescale,
                mode="bilinear",
            )
        else:
            imgs_resized = imgs_input

        if self.normalize_image:
            imgs_normed = self._resnet_normalize_image(imgs_resized)
        else:
            imgs_normed = imgs_resized

        if len(self.stages) > 0:
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
            #  is not a function.
            feats = self.stem(imgs_normed)
            # pyre-fixme[6]: For 1st param expected `Iterable[Variable[_T1]]` but
            #  got `Union[Tensor, Module]`.
            # pyre-fixme[6]: For 2nd param expected `Iterable[Variable[_T2]]` but
            #  got `Union[Tensor, Module]`.
            for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
                feats = layer(feats)
                # just a sanity check below
                assert feats.shape[1] == _FEAT_DIMS[self.name][stage]
                if (stage + 1) in self.stages:
                    f = proj(feats)
                    if self.global_average_pool:
                        f = f.mean(dims=(2, 3))
                    if self.l2_norm:
                        normfac = 1.0 / math.sqrt(len(self.stages))
                        f = Fu.normalize(f, dim=1) * normfac
                    feature_name = self._get_resnet_stage_feature_name(stage)
                    out_feats[feature_name] = f

        if self.add_masks:
            assert masks is not None
            out_feats[MASK_FEATURE_NAME] = masks

        if self.add_images:
            assert imgs_input is not None
            out_feats[IMAGE_FEATURE_NAME] = imgs_resized

        if self.feature_rescale != 1.0:
            out_feats = {k: self.feature_rescale * f for k, f in out_feats.items()}

        # pyre-fixme[7]: Incompatible return type, expected `Dict[typing.Any, Tensor]`
        #  but got `Dict[typing.Any, float]`
        return out_feats
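A hedged usage sketch, not part of the commit: implicitron `Configurable` classes are dataclass-like and are normally expanded with `expand_args_fields` before direct instantiation; that call and the tensor sizes below are assumptions for illustration only.
```
import torch
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.models.resnet_feature_extractor import ResNetFeatureExtractor

expand_args_fields(ResNetFeatureExtractor)
extractor = ResNetFeatureExtractor(name="resnet34", stages=(1, 2), add_masks=True)

imgs = torch.rand(2, 3, 800, 800)   # (B, 3, H, W)
masks = torch.ones(2, 1, 800, 800)  # (B, 1, H, W)
feats = extractor(imgs, masks=masks)
# dict keyed by "res_layer_1", "res_layer_2", "mask", "image"
print({k: tuple(v.shape) for k, v in feats.items()})
```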
pytorch3d/implicitron/models/view_pooling/feature_aggregation.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from pytorch3d.implicitron.models.view_pooling.view_sampling import (
    cameras_points_cartesian_product,
)
from pytorch3d.implicitron.tools.config import ReplaceableBase, registry
from pytorch3d.ops import wmean
from pytorch3d.renderer.cameras import CamerasBase


class ReductionFunction(Enum):
    AVG = "avg"  # simple average
    MAX = "max"  # maximum
    STD = "std"  # standard deviation
    STD_AVG = "std_avg"  # average of per-dimension standard deviations


class FeatureAggregatorBase(ABC, ReplaceableBase):
    """
    Base class for aggregating features.

    Typically, the aggregated features and their masks are output by `ViewSampler`
    which samples feature tensors extracted from a set of source images.

    Settings:
        exclude_target_view: If `True`/`False`, enables/disables pooling
            from target view to itself.
        exclude_target_view_mask_features: If `True`,
            mask the features from the target view before aggregation
        concatenate_output: If `True`,
            concatenate the aggregated features into a single tensor,
            otherwise return a dictionary mapping feature names to tensors.
    """

    exclude_target_view: bool = True
    exclude_target_view_mask_features: bool = True
    concatenate_output: bool = True

    @abstractmethod
    def forward(
        self,
        feats_sampled: Dict[str, torch.Tensor],
        masks_sampled: torch.Tensor,
        camera: Optional[CamerasBase] = None,
        pts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
                where each `t_i` is a tensor of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
            masks_sampled: A binary mask represented as a tensor of shape
                `(minibatch, n_source_views, n_samples, 1)` denoting valid
                sampled features.
            camera: A batch of `n_source_views` `CamerasBase` objects corresponding
                to the source view cameras.
            pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
                3D points whose 2D projections to source views were sampled in
                order to generate `feats_sampled` and `masks_sampled`.

        Returns:
            feats_aggregated: If `concatenate_output==True`, a tensor
                of shape `(minibatch, reduce_dim, n_samples, sum(dim_1, ... dim_N))`
                containing the concatenation of the aggregated features `feats_sampled`.
                `reduce_dim` depends on the specific feature aggregator
                implementation and typically equals 1 or `n_source_views`.
                If `concatenate_output==False`, the aggregator does not concatenate
                the aggregated features and returns a dictionary of per-feature
                aggregations `{f_i: t_i_aggregated}` instead. Each `t_i_aggregated`
                is of shape `(minibatch, reduce_dim, n_samples, aggr_dim_i)`.
        """
        raise NotImplementedError()


@registry.register
class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
    """
    This aggregator does not perform any feature aggregation. Depending on the
    settings the aggregator allows to mask target view features and concatenate
    the outputs.
    """

    def __post_init__(self):
        super().__init__()

    def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
        return _get_reduction_aggregator_feature_dim(feats, [])

    def forward(
        self,
        feats_sampled: Dict[str, torch.Tensor],
        masks_sampled: torch.Tensor,
        camera: Optional[CamerasBase] = None,
        pts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
                where each `t_i` is a tensor of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
            masks_sampled: A binary mask represented as a tensor of shape
                `(minibatch, n_source_views, n_samples, 1)` denoting valid
                sampled features.
            camera: A batch of `n_source_views` `CamerasBase` objects
                corresponding to the source view cameras.
            pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
                3D points whose 2D projections to source views were sampled in
                order to generate `feats_sampled` and `masks_sampled`.

        Returns:
            feats_aggregated: If `concatenate_output==True`, a tensor
                of shape `(minibatch, 1, n_samples, sum(dim_1, ... dim_N))`.
                If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
                with each `t_i_aggregated` of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
        """
        if self.exclude_target_view_mask_features:
            feats_sampled = _mask_target_view_features(feats_sampled)
        feats_aggregated = feats_sampled
        if self.concatenate_output:
            feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
        return feats_aggregated


@registry.register
class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
    """
    Aggregates using a set of predefined `reduction_functions` and concatenates
    the results of each aggregation function along the
    channel dimension. The reduction functions singularize the second dimension
    of the sampled features which stacks the source views.

    Settings:
        reduction_functions: A list of `ReductionFunction`s that reduce the
            stack of source-view-specific features to a single feature.
    """

    reduction_functions: Sequence[ReductionFunction] = (
        ReductionFunction.AVG,
        ReductionFunction.STD,
    )

    def __post_init__(self):
        super().__init__()

    def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
        return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)

    def forward(
        self,
        feats_sampled: Dict[str, torch.Tensor],
        masks_sampled: torch.Tensor,
        camera: Optional[CamerasBase] = None,
        pts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
                where each `t_i` is a tensor of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
            masks_sampled: A binary mask represented as a tensor of shape
                `(minibatch, n_source_views, n_samples, 1)` denoting valid
                sampled features.
            camera: A batch of `n_source_views` `CamerasBase` objects corresponding
                to the source view cameras.
            pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
                3D points whose 2D projections to source views were sampled in
                order to generate `feats_sampled` and `masks_sampled`.

        Returns:
            feats_aggregated: If `concatenate_output==True`, a tensor
                of shape `(minibatch, 1, n_samples, sum(dim_1, ... dim_N))`.
                If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
                with each `t_i_aggregated` of shape `(minibatch, 1, n_samples, aggr_dim_i)`.
        """
        pts_batch, n_cameras = masks_sampled.shape[:2]
        if self.exclude_target_view_mask_features:
            feats_sampled = _mask_target_view_features(feats_sampled)
        sampling_mask = _get_view_sampling_mask(
            n_cameras,
            pts_batch,
            masks_sampled.device,
            self.exclude_target_view,
        )
        aggr_weigths = masks_sampled * sampling_mask
        feats_aggregated = {
            k: _avgmaxstd_reduction_function(
                f,
                aggr_weigths,
                dim=1,
                reduction_functions=self.reduction_functions,
            )
            for k, f in feats_sampled.items()
        }
        if self.concatenate_output:
            feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
        return feats_aggregated


@registry.register
class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
    """
    Performs a weighted aggregation using a set of predefined `reduction_functions`
    and concatenates the results of each aggregation function along the
    channel dimension. The weights are proportional to the cosine of the
    angle between the target ray and the source ray:
    ```
    weight = (
        dot(target_ray, source_ray) * 0.5 + 0.5 + self.min_ray_angle_weight
    )**self.weight_by_ray_angle_gamma
    ```

    The reduction functions singularize the second dimension
    of the sampled features which stacks the source views.

    Settings:
        reduction_functions: A list of `ReductionFunction`s that reduce the
            stack of source-view-specific features to a single feature.
        min_ray_angle_weight: The minimum possible aggregation weight
            before raising to the power of `self.weight_by_ray_angle_gamma`.
        weight_by_ray_angle_gamma: The exponent of the cosine of the ray angles
            used when calculating the angle-based aggregation weights.
    """

    reduction_functions: Sequence[ReductionFunction] = (
        ReductionFunction.AVG,
        ReductionFunction.STD,
    )
    weight_by_ray_angle_gamma: float = 1.0
    min_ray_angle_weight: float = 0.1

    def __post_init__(self):
        super().__init__()

    def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
        return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)

    def forward(
        self,
        feats_sampled: Dict[str, torch.Tensor],
        masks_sampled: torch.Tensor,
        camera: Optional[CamerasBase] = None,
        pts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
                where each `t_i` is a tensor of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
            masks_sampled: A binary mask represented as a tensor of shape
                `(minibatch, n_source_views, n_samples, 1)` denoting valid
                sampled features.
            camera: A batch of `n_source_views` `CamerasBase` objects
                corresponding to the source view cameras.
            pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
                3D points whose 2D projections to source views were sampled in
                order to generate `feats_sampled` and `masks_sampled`.

        Returns:
            feats_aggregated: If `concatenate_output==True`, a tensor
                of shape `(minibatch, 1, n_samples, sum(dim_1, ... dim_N))`.
                If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
                with each `t_i_aggregated` of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
        """
        if camera is None:
            raise ValueError("camera cannot be None for angle weighted aggregation")
        if pts is None:
            raise ValueError("Points cannot be None for angle weighted aggregation")
        pts_batch, n_cameras = masks_sampled.shape[:2]
        if self.exclude_target_view_mask_features:
            feats_sampled = _mask_target_view_features(feats_sampled)
        view_sampling_mask = _get_view_sampling_mask(
            n_cameras,
            pts_batch,
            masks_sampled.device,
            self.exclude_target_view,
        )
        aggr_weights = _get_angular_reduction_weights(
            view_sampling_mask,
            masks_sampled,
            camera,
            pts,
            self.min_ray_angle_weight,
            self.weight_by_ray_angle_gamma,
        )
        assert torch.isfinite(aggr_weights).all()
        feats_aggregated = {
            k: _avgmaxstd_reduction_function(
                f,
                aggr_weights,
                dim=1,
                reduction_functions=self.reduction_functions,
            )
            for k, f in feats_sampled.items()
        }
        if self.concatenate_output:
            feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
        return feats_aggregated


@registry.register
class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
    """
    This aggregator does not perform any feature aggregation. It only weights
    the features by the weights proportional to the cosine of the
    angle between the target ray and the source ray:
    ```
    weight = (
        dot(target_ray, source_ray) * 0.5 + 0.5 + self.min_ray_angle_weight
    )**self.weight_by_ray_angle_gamma
    ```

    Settings:
        min_ray_angle_weight: The minimum possible aggregation weight
            before raising to the power of `self.weight_by_ray_angle_gamma`.
        weight_by_ray_angle_gamma: The exponent of the cosine of the ray angles
            used when calculating the angle-based aggregation weights.

    Additionally the aggregator allows to mask target view features and to concatenate
    the outputs.
    """

    weight_by_ray_angle_gamma: float = 1.0
    min_ray_angle_weight: float = 0.1

    def __post_init__(self):
        super().__init__()

    def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
        return _get_reduction_aggregator_feature_dim(feats, [])

    def forward(
        self,
        feats_sampled: Dict[str, torch.Tensor],
        masks_sampled: torch.Tensor,
        camera: Optional[CamerasBase] = None,
        pts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
                where each `t_i` is a tensor of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
            masks_sampled: A binary mask represented as a tensor of shape
                `(minibatch, n_source_views, n_samples, 1)` denoting valid
                sampled features.
            camera: A batch of `n_source_views` `CamerasBase` objects corresponding
                to the source view cameras.
            pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
                3D points whose 2D projections to source views were sampled in
                order to generate `feats_sampled` and `masks_sampled`.

        Returns:
            feats_aggregated: If `concatenate_output==True`, a tensor
                of shape `(minibatch, n_source_views, n_samples, sum(dim_1, ... dim_N))`.
                If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
                with each `t_i_aggregated` of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
        """
        if camera is None:
            raise ValueError("camera cannot be None for angle weighted aggregation")
        if pts is None:
            raise ValueError("Points cannot be None for angle weighted aggregation")
        pts_batch, n_cameras = masks_sampled.shape[:2]
        if self.exclude_target_view_mask_features:
            feats_sampled = _mask_target_view_features(feats_sampled)
        view_sampling_mask = _get_view_sampling_mask(
            n_cameras,
            pts_batch,
            masks_sampled.device,
            self.exclude_target_view,
        )
        aggr_weights = _get_angular_reduction_weights(
            view_sampling_mask,
            masks_sampled,
            camera,
            pts,
            self.min_ray_angle_weight,
            self.weight_by_ray_angle_gamma,
        )
        feats_aggregated = {
            k: f * aggr_weights[..., None] for k, f in feats_sampled.items()
        }
        if self.concatenate_output:
            feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
        return feats_aggregated


def _get_reduction_aggregator_feature_dim(
    feats_or_feats_dim: Union[Dict[str, torch.Tensor], int],
    reduction_functions: Sequence[ReductionFunction],
):
    if isinstance(feats_or_feats_dim, int):
        feat_dim = feats_or_feats_dim
    else:
        feat_dim = int(sum(f.shape[1] for f in feats_or_feats_dim.values()))
    if len(reduction_functions) == 0:
        return feat_dim
    return sum(
        _get_reduction_function_output_dim(
            reduction_function,
            feat_dim,
        )
        for reduction_function in reduction_functions
    )


def _get_reduction_function_output_dim(
    reduction_function: ReductionFunction,
    feat_dim: int,
) -> int:
    if reduction_function == ReductionFunction.STD_AVG:
        return 1
    else:
        return feat_dim


def _get_view_sampling_mask(
    n_cameras: int,
    pts_batch: int,
    device: Union[str, torch.device],
    exclude_target_view: bool,
):
    return (
        -torch.eye(n_cameras, device=device, dtype=torch.float32)
        * float(exclude_target_view)
        + 1.0
    )[:pts_batch]


def _mask_target_view_features(
    feats_sampled: Dict[str, torch.Tensor],
):
    # mask out the sampled features to be sure we don't use them
    # anywhere later
    one_feature_sampled = next(iter(feats_sampled.values()))
    pts_batch, n_cameras = one_feature_sampled.shape[:2]
    view_sampling_mask = _get_view_sampling_mask(
        n_cameras,
        pts_batch,
        one_feature_sampled.device,
        True,
    )
    view_sampling_mask = view_sampling_mask.view(
        pts_batch, n_cameras, *([1] * (one_feature_sampled.ndim - 2))
    )
    return {k: f * view_sampling_mask for k, f in feats_sampled.items()}


def _get_angular_reduction_weights(
    view_sampling_mask: torch.Tensor,
    masks_sampled: torch.Tensor,
    camera: CamerasBase,
    pts: torch.Tensor,
    min_ray_angle_weight: float,
    weight_by_ray_angle_gamma: float,
):
    aggr_weights = masks_sampled.clone()[..., 0]
    assert not any(v is None for v in [camera, pts])
    angle_weight = _get_ray_angle_weights(
        camera,
        pts,
        min_ray_angle_weight,
        weight_by_ray_angle_gamma,
    )
    assert torch.isfinite(angle_weight).all()
    # multiply the final aggr weights with ray angles
    view_sampling_mask = view_sampling_mask.view(
        *view_sampling_mask.shape[:2], *([1] * (aggr_weights.ndim - 2))
    )
    aggr_weights = (
        aggr_weights * angle_weight.reshape_as(aggr_weights) * view_sampling_mask
    )
    return aggr_weights


def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor):
    n_cameras = camera.R.shape[0]
    pts_batch = pts.shape[0]

    camera_rep, pts_rep = cameras_points_cartesian_product(camera, pts)

    # does not produce nans randomly unlike get_camera_center() below
    cam_centers_rep = -torch.bmm(
        # pyre-fixme[29]:
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
        #  torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
        #  torch.Tensor, torch.nn.modules.module.Module]` is not a function.
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N...
        camera_rep.T[:, None],
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N...
        camera_rep.R.permute(0, 2, 1),
    ).reshape(-1, *([1] * (pts.ndim - 2)), 3)
    # cam_centers_rep = camera_rep.get_camera_center().reshape(
    #     -1, *([1]*(pts.ndim - 2)), 3
    # )

    ray_dirs = F.normalize(pts_rep - cam_centers_rep, dim=-1)
    # camera_rep = [                 pts_rep = [
    #     camera[0]                      pts[0],
    #     camera[0]                      pts[1],
    #     camera[0]                      ...,
    #     ...                            pts[batch_pts-1],
    #     camera[1]                      pts[0],
    #     camera[1]                      pts[1],
    #     camera[1]                      ...,
    #     ...                            pts[batch_pts-1],
    #     ...                            ...,
    #     camera[n_cameras-1]            pts[0],
    #     camera[n_cameras-1]            pts[1],
    #     camera[n_cameras-1]            ...,
    #     ...                            pts[batch_pts-1],
    # ]                              ]

    ray_dirs_reshape = ray_dirs.view(n_cameras, pts_batch, -1, 3)
    # [
    #   [pts_0 in cam_0, pts_1 in cam_0, ..., pts_m in cam_0],
    #   [pts_0 in cam_1, pts_1 in cam_1, ..., pts_m in cam_1],
    #   ...
    #   [pts_0 in cam_n, pts_1 in cam_n, ..., pts_m in cam_n],
    # ]

    ray_dirs_pts = torch.stack([ray_dirs_reshape[i, i] for i in range(pts_batch)])
    ray_dir_dot_prods = (ray_dirs_pts[None] * ray_dirs_reshape).sum(
        dim=-1
    )  # pts_batch x n_cameras x n_pts

    return ray_dir_dot_prods.transpose(0, 1)


def _get_ray_angle_weights(
    camera: CamerasBase,
    pts: torch.Tensor,
    min_ray_angle_weight: float,
    weight_by_ray_angle_gamma: float,
):
    ray_dir_dot_prods = _get_ray_dir_dot_prods(
        camera, pts
    )  # pts_batch x n_cameras x ... x 3
    angle_weight_01 = ray_dir_dot_prods * 0.5 + 0.5  # [-1, 1] to [0, 1]
    angle_weight = (angle_weight_01 + min_ray_angle_weight) ** weight_by_ray_angle_gamma
    return angle_weight


def _avgmaxstd_reduction_function(
    x: torch.Tensor,
    w: torch.Tensor,
    reduction_functions: Sequence[ReductionFunction],
    dim: int = 1,
):
    """
    Args:
        x: Features to aggregate. Tensor of shape `(batch, n_views, ..., dim)`.
        w: Aggregation weights. Tensor of shape `(batch, n_views, ...,)`.
        dim: the dimension along which to aggregate.
        reduction_functions: The set of reduction functions.

    Returns:
        x_aggr: Aggregation of `x` to a tensor of shape `(batch, 1, ..., dim_aggregate)`.
    """
    pooled_features = []

    mu = None
    std = None

    if ReductionFunction.AVG in reduction_functions:
        # average pool
        mu = _avg_reduction_function(x, w, dim=dim)
        pooled_features.append(mu)

    if ReductionFunction.STD in reduction_functions:
        # standard-dev pool
        std = _std_reduction_function(x, w, dim=dim, mu=mu)
        pooled_features.append(std)

    if ReductionFunction.STD_AVG in reduction_functions:
        # average-of-standard-dev pool
        stdavg = _std_avg_reduction_function(x, w, dim=dim, mu=mu, std=std)
        pooled_features.append(stdavg)

    if ReductionFunction.MAX in reduction_functions:
        max_ = _max_reduction_function(x, w, dim=dim)
        pooled_features.append(max_)

    # cat all results along the feature dimension (the last dim)
    x_aggr = torch.cat(pooled_features, dim=-1)

    # zero out features that were all masked out
    any_active = (w.max(dim=dim, keepdim=True).values > 1e-4).type_as(x_aggr)
    x_aggr = x_aggr * any_active[..., None]

    # some asserts to check that everything was done right
    assert torch.isfinite(x_aggr).all()
    assert x_aggr.shape[1] == 1

    return x_aggr


def _avg_reduction_function(
    x: torch.Tensor,
    w: torch.Tensor,
    dim: int = 1,
):
    mu = wmean(x, w, dim=dim, eps=1e-2)
    return mu


def _std_reduction_function(
    x: torch.Tensor,
    w: torch.Tensor,
    dim: int = 1,
    mu: Optional[torch.Tensor] = None,  # pre-computed mean
):
    if mu is None:
        mu = _avg_reduction_function(x, w, dim=dim)
    std = wmean((x - mu) ** 2, w, dim=dim, eps=1e-2).clamp(1e-4).sqrt()
    # FIXME: somehow this is extremely heavy in mem?
    return std


def _std_avg_reduction_function(
    x: torch.Tensor,
    w: torch.Tensor,
    dim: int = 1,
    mu: Optional[torch.Tensor] = None,  # pre-computed mean
    std: Optional[torch.Tensor] = None,  # pre-computed std
):
    if std is None:
        std = _std_reduction_function(x, w, dim=dim, mu=mu)
    stdmean = std.mean(dim=-1, keepdim=True)
    return stdmean


def _max_reduction_function(
    x: torch.Tensor,
    w: torch.Tensor,
    dim: int = 1,
    big_M_factor: float = 10.0,
):
    big_M = x.max(dim=dim, keepdim=True).values.abs() * big_M_factor
    max_ = (x * w - ((1 - w) * big_M)).max(dim=dim, keepdim=True).values
    return max_
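A small, hedged sketch (not part of the commit) of what the reduction helpers do: features from `n_views` source views are pooled over `dim=1` into a single view, with masked-out views carrying zero weight; the tensor sizes are illustrative.
```
import torch

batch, n_views, n_samples, dim = 2, 4, 16, 8
x = torch.randn(batch, n_views, n_samples, dim)
w = torch.ones(batch, n_views, n_samples)  # per-view sampling weights
w[:, 0] = 0.0                              # e.g. exclude the first source view

pooled = _avgmaxstd_reduction_function(
    x, w, reduction_functions=(ReductionFunction.AVG, ReductionFunction.STD)
)
print(pooled.shape)  # (2, 1, 16, 16): AVG and STD concatenated along the last dim
```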
pytorch3d/implicitron/models/view_pooling/view_sampling.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional, Tuple, Union

import torch
from pytorch3d.implicitron.tools.config import Configurable
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.utils import ndc_grid_sample


class ViewSampler(Configurable, torch.nn.Module):
    """
    Implements sampling of image-based features at the 2d projections of a set
    of 3D points.

    Args:
        masked_sampling: If `True`, the `sampled_masks` output of `self.forward`
            contains the input `masks` sampled at the 2d projections. Otherwise,
            all entries of `sampled_masks` are set to 1.
        sampling_mode: Controls the mode of the `torch.nn.functional.grid_sample`
            function used to interpolate the sampled feature tensors at the
            locations of the 2d projections.
    """

    masked_sampling: bool = False
    sampling_mode: str = "bilinear"

    def __post_init__(self):
        super().__init__()

    def forward(
        self,
        *,  # force kw args
        pts: torch.Tensor,
        seq_id_pts: Union[List[int], List[str], torch.LongTensor],
        camera: CamerasBase,
        seq_id_camera: Union[List[int], List[str], torch.LongTensor],
        feats: Dict[str, torch.Tensor],
        masks: Optional[torch.Tensor],
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        Project each point cloud from a batch of point clouds to corresponding
        input cameras and sample features at the 2D projection locations.

        Args:
            pts: A tensor of shape `[pts_batch x n_pts x 3]` in world coords.
            seq_id_pts: LongTensor of shape `[pts_batch]` denoting the ids of the scenes
                from which `pts` were extracted, or a list of string names.
            camera: 'n_cameras' cameras, each corresponding to a batch element of `feats`.
            seq_id_camera: LongTensor of shape `[n_cameras]` denoting the ids of the scenes
                corresponding to cameras in `camera`, or a list of string names.
            feats: a dict of tensors of per-image features `{feat_i: T_i}`.
                Each tensor `T_i` is of shape `[n_cameras x dim_i x H_i x W_i]`.
            masks: `[n_cameras x 1 x H x W]`, define valid image regions
                for sampling `feats`.
        Returns:
            sampled_feats: Dict of sampled features `{feat_i: sampled_T_i}`.
                Each `sampled_T_i` of shape `[pts_batch, n_cameras, n_pts, dim_i]`.
            sampled_masks: A tensor with mask of the sampled features
                of shape `(pts_batch, n_cameras, n_pts, 1)`.
        """

        # convert sequence ids to long tensors
        seq_id_pts, seq_id_camera = [
            handle_seq_id(seq_id, pts.device)
            for seq_id in [seq_id_pts, seq_id_camera]
        ]

        if self.masked_sampling and masks is None:
            raise ValueError(
                "Masks have to be provided for `self.masked_sampling==True`"
            )

        # project pts to all cameras and sample feats from the locations of
        # the 2D projections
        sampled_feats_all_cams, sampled_masks_all_cams = project_points_and_sample(
            pts,
            feats,
            camera,
            masks if self.masked_sampling else None,
            sampling_mode=self.sampling_mode,
        )

        # generate the mask that invalidates features sampled from
        # non-corresponding cameras
        camera_pts_mask = (seq_id_camera[None] == seq_id_pts[:, None])[
            ..., None, None
        ].to(pts)

        # mask the sampled features and masks
        sampled_feats = {
            k: f * camera_pts_mask for k, f in sampled_feats_all_cams.items()
        }
        sampled_masks = sampled_masks_all_cams * camera_pts_mask

        return sampled_feats, sampled_masks


def project_points_and_sample(
    pts: torch.Tensor,
    feats: Dict[str, torch.Tensor],
    camera: CamerasBase,
    masks: Optional[torch.Tensor],
    eps: float = 1e-2,
    sampling_mode: str = "bilinear",
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
    """
    Project each point cloud from a batch of point clouds to all input cameras
    and sample features at the 2D projection locations.

    Args:
        pts: `(pts_batch, n_pts, 3)` tensor containing a batch of 3D point clouds.
        feats: A dict `{feat_i: feat_T_i}` of features to sample,
            where each `feat_T_i` is a tensor of shape
            `(n_cameras, feat_i_dim, feat_i_H, feat_i_W)`
            of `feat_i_dim`-dimensional features extracted from `n_cameras`
            source views.
        camera: A batch of `n_cameras` cameras corresponding to their feature
            tensors `feat_T_i` from `feats`.
        masks: A tensor of shape `(n_cameras, 1, mask_H, mask_W)` denoting
            valid locations for sampling.
        eps: A small constant controlling the minimum depth of projections
            of `pts` to avoid divisions by zero in the projection operation.
        sampling_mode: Sampling mode of the grid sampler.

    Returns:
        sampled_feats: Dict of sampled features `{feat_i: sampled_T_i}`.
            Each `sampled_T_i` is of shape
            `(pts_batch, n_cameras, n_pts, feat_i_dim)`.
        sampled_masks: A tensor with the mask of the sampled features
            of shape `(pts_batch, n_cameras, n_pts, 1)`.
            If `masks` is `None`, the returned `sampled_masks` will be
            filled with 1s.
    """

    n_cameras = camera.R.shape[0]
    pts_batch = pts.shape[0]
    n_pts = pts.shape[1:-1]

    camera_rep, pts_rep = cameras_points_cartesian_product(camera, pts)

    # The eps here is super-important to avoid NaNs in backprop!
    proj_rep = camera_rep.transform_points(
        pts_rep.reshape(n_cameras * pts_batch, -1, 3), eps=eps
    )[..., :2]
    # [ pts1 in cam1, pts2 in cam1, pts3 in cam1,
    #   pts1 in cam2, pts2 in cam2, pts3 in cam2,
    #   pts1 in cam3, pts2 in cam3, pts3 in cam3 ]

    # reshape for the grid sampler
    sampling_grid_ndc = proj_rep.view(n_cameras, pts_batch, -1, 2)
    # [ [pts1 in cam1, pts2 in cam1, pts3 in cam1],
    #   [pts1 in cam2, pts2 in cam2, pts3 in cam2],
    #   [pts1 in cam3, pts2 in cam3, pts3 in cam3] ]
    # n_cameras x pts_batch x n_pts x 2

    # sample both feats
    feats_sampled = {
        k: ndc_grid_sample(
            f,
            sampling_grid_ndc,
            mode=sampling_mode,
            align_corners=False,
        )
        .permute(2, 0, 3, 1)
        .reshape(pts_batch, n_cameras, *n_pts, -1)
        for k, f in feats.items()
    }  # {k: pts_batch x n_cameras x *n_pts x dim} for each feat type "k"

    if masks is not None:
        # sample masks
        masks_sampled = (
            ndc_grid_sample(
                masks,
                sampling_grid_ndc,
                mode=sampling_mode,
                align_corners=False,
            )
            .permute(2, 0, 3, 1)
            .reshape(pts_batch, n_cameras, *n_pts, 1)
        )
    else:
        masks_sampled = sampling_grid_ndc.new_ones(pts_batch, n_cameras, *n_pts, 1)

    return feats_sampled, masks_sampled


def handle_seq_id(
    seq_id: Union[torch.LongTensor, List[str], List[int]],
    device,
) -> torch.LongTensor:
    """
    Converts the input sequence id to a LongTensor.

    Args:
        seq_id: A sequence of sequence ids.
        device: The target device of the output.
    Returns:
        long_seq_id: `seq_id` converted to a `LongTensor` and moved to `device`.
    """
    if not torch.is_tensor(seq_id):
        if isinstance(seq_id[0], str):
            seq_id = [hash(s) for s in seq_id]
        seq_id = torch.tensor(seq_id, dtype=torch.long, device=device)
    return seq_id.to(device)


def cameras_points_cartesian_product(
    camera: CamerasBase, pts: torch.Tensor
) -> Tuple[CamerasBase, torch.Tensor]:
    """
    Generates all pairs of pairs of elements from 'camera' and 'pts' and returns
    `camera_rep` and `pts_rep` such that:
    ```
        camera_rep = [                 pts_rep = [
            camera[0]                      pts[0],
            camera[0]                      pts[1],
            camera[0]                      ...,
            ...                            pts[batch_pts-1],
            camera[1]                      pts[0],
            camera[1]                      pts[1],
            camera[1]                      ...,
            ...                            pts[batch_pts-1],
            ...                            ...,
            camera[n_cameras-1]            pts[0],
            camera[n_cameras-1]            pts[1],
            camera[n_cameras-1]            ...,
            ...                            pts[batch_pts-1],
        ]                              ]
    ```

    Args:
        camera: A batch of `n_cameras` cameras.
        pts: A batch of `batch_pts` points of shape `(batch_pts, ..., dim)`

    Returns:
        camera_rep: A batch of batch_pts*n_cameras cameras such that:
            ```
            camera_rep = [
                camera[0]
                camera[0]
                camera[0]
                ...
                camera[1]
                camera[1]
                camera[1]
                ...
                ...
                camera[n_cameras-1]
                camera[n_cameras-1]
                camera[n_cameras-1]
            ]
            ```

        pts_rep: Repeated `pts` of shape `(batch_pts*n_cameras, ..., dim)`,
            such that:
            ```
            pts_rep = [
                pts[0],
                pts[1],
                ...,
                pts[batch_pts-1],
                pts[0],
                pts[1],
                ...,
                pts[batch_pts-1],
                ...,
                pts[0],
                pts[1],
                ...,
                pts[batch_pts-1],
            ]
            ```
    """
    n_cameras = camera.R.shape[0]
    batch_pts = pts.shape[0]
    pts_rep = pts.repeat(n_cameras, *[1 for _ in pts.shape[1:]])
    idx_cams = (
        torch.arange(n_cameras)[:, None]
        .expand(
            n_cameras,
            batch_pts,
        )
        .reshape(batch_pts * n_cameras)
    )
    camera_rep = camera[idx_cams]
    return camera_rep, pts_rep
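A hedged sketch (not part of the commit) of the cartesian-product helper: every one of the `n_cameras` cameras is paired with every point batch, which is the layout that `project_points_and_sample` relies on before reshaping back. The identity cameras below are an assumption for illustration only.
```
import torch
from pytorch3d.renderer.cameras import PerspectiveCameras

cameras = PerspectiveCameras(
    R=torch.eye(3)[None].expand(3, 3, 3), T=torch.zeros(3, 3)
)  # a batch of 3 cameras
pts = torch.randn(2, 5, 3)  # (batch_pts, n_pts, 3)

camera_rep, pts_rep = cameras_points_cartesian_product(cameras, pts)
print(len(camera_rep), pts_rep.shape)  # 6 cameras, torch.Size([6, 5, 3])
```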
pytorch3d/implicitron/third_party/hyperlayers.py
0 → 100644
# a copy-paste from https://github.com/vsitzmann/scene-representation-networks/blob/master/hyperlayers.py
# fmt: off
# flake8: noqa
'''Pytorch implementations of hyper-network modules.

isort:skip_file
'''
import functools

import torch
import torch.nn as nn

from . import pytorch_prototyping


def partialclass(cls, *args, **kwds):
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwds)

    return NewCls


class LookupLayer(nn.Module):
    def __init__(self, in_ch, out_ch, num_objects):
        super().__init__()

        self.out_ch = out_ch
        self.lookup_lin = LookupLinear(in_ch, out_ch, num_objects=num_objects)
        self.norm_nl = nn.Sequential(
            nn.LayerNorm([self.out_ch], elementwise_affine=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, obj_idx):
        net = nn.Sequential(self.lookup_lin(obj_idx), self.norm_nl)
        return net


class LookupFC(nn.Module):
    def __init__(
        self,
        hidden_ch,
        num_hidden_layers,
        num_objects,
        in_ch,
        out_ch,
        outermost_linear=False,
    ):
        super().__init__()

        self.layers = nn.ModuleList()
        self.layers.append(
            LookupLayer(in_ch=in_ch, out_ch=hidden_ch, num_objects=num_objects)
        )

        for i in range(num_hidden_layers):
            self.layers.append(
                LookupLayer(in_ch=hidden_ch, out_ch=hidden_ch, num_objects=num_objects)
            )

        if outermost_linear:
            self.layers.append(
                LookupLinear(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)
            )
        else:
            self.layers.append(
                LookupLayer(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)
            )

    def forward(self, obj_idx):
        net = []
        for i in range(len(self.layers)):
            net.append(self.layers[i](obj_idx))
        return nn.Sequential(*net)


class LookupLinear(nn.Module):
    def __init__(self, in_ch, out_ch, num_objects):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch

        self.hypo_params = nn.Embedding(num_objects, in_ch * out_ch + out_ch)

        for i in range(num_objects):
            nn.init.kaiming_normal_(
                self.hypo_params.weight.data[i, : self.in_ch * self.out_ch].view(
                    self.out_ch, self.in_ch
                ),
                a=0.0,
                nonlinearity="relu",
                mode="fan_in",
            )
            self.hypo_params.weight.data[i, self.in_ch * self.out_ch :].fill_(0.0)

    def forward(self, obj_idx):
        hypo_params = self.hypo_params(obj_idx)

        # Indices explicit to catch errors in shape of output layer
        weights = hypo_params[..., : self.in_ch * self.out_ch]
        biases = hypo_params[
            ..., self.in_ch * self.out_ch : (self.in_ch * self.out_ch) + self.out_ch
        ]

        biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch)
        weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch)

        return BatchLinear(weights=weights, biases=biases)


class HyperLayer(nn.Module):
    """A hypernetwork that predicts a single Dense Layer, including LayerNorm and a ReLU."""

    def __init__(
        self, in_ch, out_ch, hyper_in_ch, hyper_num_hidden_layers, hyper_hidden_ch
    ):
        super().__init__()

        self.hyper_linear = HyperLinear(
            in_ch=in_ch,
            out_ch=out_ch,
            hyper_in_ch=hyper_in_ch,
            hyper_num_hidden_layers=hyper_num_hidden_layers,
            hyper_hidden_ch=hyper_hidden_ch,
        )
        self.norm_nl = nn.Sequential(
            nn.LayerNorm([out_ch], elementwise_affine=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, hyper_input):
        """
        :param hyper_input: input to hypernetwork.
        :return: nn.Module; predicted fully connected network.
        """
        return nn.Sequential(self.hyper_linear(hyper_input), self.norm_nl)


class HyperFC(nn.Module):
    """Builds a hypernetwork that predicts a fully connected neural network."""

    def __init__(
        self,
        hyper_in_ch,
        hyper_num_hidden_layers,
        hyper_hidden_ch,
        hidden_ch,
        num_hidden_layers,
        in_ch,
        out_ch,
        outermost_linear=False,
    ):
        super().__init__()

        PreconfHyperLinear = partialclass(
            HyperLinear,
            hyper_in_ch=hyper_in_ch,
            hyper_num_hidden_layers=hyper_num_hidden_layers,
            hyper_hidden_ch=hyper_hidden_ch,
        )
        PreconfHyperLayer = partialclass(
            HyperLayer,
            hyper_in_ch=hyper_in_ch,
            hyper_num_hidden_layers=hyper_num_hidden_layers,
            hyper_hidden_ch=hyper_hidden_ch,
        )

        self.layers = nn.ModuleList()
        self.layers.append(PreconfHyperLayer(in_ch=in_ch, out_ch=hidden_ch))

        for i in range(num_hidden_layers):
            self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=hidden_ch))

        if outermost_linear:
            self.layers.append(PreconfHyperLinear(in_ch=hidden_ch, out_ch=out_ch))
        else:
            self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=out_ch))

    def forward(self, hyper_input):
        """
        :param hyper_input: Input to hypernetwork.
        :return: nn.Module; Predicted fully connected neural network.
        """
        net = []
        for i in range(len(self.layers)):
            net.append(self.layers[i](hyper_input))
        return nn.Sequential(*net)


class BatchLinear(nn.Module):
    def __init__(self, weights, biases):
        """Implements a batch linear layer.

        :param weights: Shape: (batch, out_ch, in_ch)
        :param biases: Shape: (batch, 1, out_ch)
        """
        super().__init__()

        self.weights = weights
        self.biases = biases

    def __repr__(self):
        return "BatchLinear(in_ch=%d, out_ch=%d)" % (
            self.weights.shape[-1],
            self.weights.shape[-2],
        )

    def forward(self, input):
        output = input.matmul(
            self.weights.permute(
                *[i for i in range(len(self.weights.shape) - 2)], -1, -2
            )
        )
        output += self.biases
        return output


def last_hyper_layer_init(m) -> None:
    if type(m) == nn.Linear:
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity="relu", mode="fan_in")
        # pyre-fixme[41]: `data` cannot be reassigned. It is a read-only property.
        m.weight.data *= 1e-1


class HyperLinear(nn.Module):
    """A hypernetwork that predicts a single linear layer (weights & biases)."""

    def __init__(
        self, in_ch, out_ch, hyper_in_ch, hyper_num_hidden_layers, hyper_hidden_ch
    ):
        super().__init__()

        self.in_ch = in_ch
        self.out_ch = out_ch

        self.hypo_params = pytorch_prototyping.FCBlock(
            in_features=hyper_in_ch,
            hidden_ch=hyper_hidden_ch,
            num_hidden_layers=hyper_num_hidden_layers,
            out_features=(in_ch * out_ch) + out_ch,
            outermost_linear=True,
        )
        self.hypo_params[-1].apply(last_hyper_layer_init)

    def forward(self, hyper_input):
        hypo_params = self.hypo_params(hyper_input)

        # Indices explicit to catch errors in shape of output layer
        weights = hypo_params[..., : self.in_ch * self.out_ch]
        biases = hypo_params[
            ..., self.in_ch * self.out_ch : (self.in_ch * self.out_ch) + self.out_ch
        ]

        biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch)
        weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch)

        return BatchLinear(weights=weights, biases=biases)
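A hedged sketch (not part of the commit) of how `HyperFC` is typically used: a per-scene embedding is mapped to the weights of a small MLP, which is then evaluated on a batch of 3D points. All sizes below are illustrative assumptions.
```
import torch

hyper_net = HyperFC(
    hyper_in_ch=64,
    hyper_num_hidden_layers=1,
    hyper_hidden_ch=64,
    hidden_ch=32,
    num_hidden_layers=2,
    in_ch=3,
    out_ch=4,
    outermost_linear=True,
)

scene_embedding = torch.randn(2, 64)        # one embedding per scene
predicted_mlp = hyper_net(scene_embedding)  # nn.Sequential of predicted BatchLinear layers

pts = torch.randn(2, 100, 3)                # (batch, n_points, in_ch)
out = predicted_mlp(pts)
print(out.shape)                            # torch.Size([2, 100, 4])
```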
pytorch3d/implicitron/third_party/pytorch_prototyping.py
0 → 100644
# a copy-paste from https://raw.githubusercontent.com/vsitzmann/pytorch_prototyping/10f49b1e7df38a58fd78451eac91d7ac1a21df64/pytorch_prototyping.py
# fmt: off
# flake8: noqa
'''A number of custom pytorch modules with sane defaults that I find useful for model prototyping.
isort:skip_file
'''
import
torch
import
torch.nn
as
nn
import
torchvision.utils
from
torch.nn
import
functional
as
F
class
FCLayer
(
nn
.
Module
):
def
__init__
(
self
,
in_features
,
out_features
):
super
().
__init__
()
self
.
net
=
nn
.
Sequential
(
nn
.
Linear
(
in_features
,
out_features
),
nn
.
LayerNorm
([
out_features
]),
nn
.
ReLU
(
inplace
=
True
),
)
def
forward
(
self
,
input
):
return
self
.
net
(
input
)
# From https://gist.github.com/wassname/ecd2dac6fc8f9918149853d17e3abf02
class
LayerNormConv2d
(
nn
.
Module
):
def
__init__
(
self
,
num_features
,
eps
=
1e-5
,
affine
=
True
):
super
().
__init__
()
self
.
num_features
=
num_features
self
.
affine
=
affine
self
.
eps
=
eps
if
self
.
affine
:
self
.
gamma
=
nn
.
Parameter
(
torch
.
Tensor
(
num_features
).
uniform_
())
self
.
beta
=
nn
.
Parameter
(
torch
.
zeros
(
num_features
))
def
forward
(
self
,
x
):
shape
=
[
-
1
]
+
[
1
]
*
(
x
.
dim
()
-
1
)
mean
=
x
.
view
(
x
.
size
(
0
),
-
1
).
mean
(
1
).
view
(
*
shape
)
std
=
x
.
view
(
x
.
size
(
0
),
-
1
).
std
(
1
).
view
(
*
shape
)
y
=
(
x
-
mean
)
/
(
std
+
self
.
eps
)
if
self
.
affine
:
shape
=
[
1
,
-
1
]
+
[
1
]
*
(
x
.
dim
()
-
2
)
y
=
self
.
gamma
.
view
(
*
shape
)
*
y
+
self
.
beta
.
view
(
*
shape
)
return
y
class
FCBlock
(
nn
.
Module
):
def
__init__
(
self
,
hidden_ch
,
num_hidden_layers
,
in_features
,
out_features
,
outermost_linear
=
False
,
):
super
().
__init__
()
self
.
net
=
[]
self
.
net
.
append
(
FCLayer
(
in_features
=
in_features
,
out_features
=
hidden_ch
))
for
i
in
range
(
num_hidden_layers
):
self
.
net
.
append
(
FCLayer
(
in_features
=
hidden_ch
,
out_features
=
hidden_ch
))
if
outermost_linear
:
self
.
net
.
append
(
nn
.
Linear
(
in_features
=
hidden_ch
,
out_features
=
out_features
))
else
:
self
.
net
.
append
(
FCLayer
(
in_features
=
hidden_ch
,
out_features
=
out_features
))
self
.
net
=
nn
.
Sequential
(
*
self
.
net
)
self
.
net
.
apply
(
self
.
init_weights
)
def
__getitem__
(
self
,
item
):
return
self
.
net
[
item
]
def
init_weights
(
self
,
m
):
if
type
(
m
)
==
nn
.
Linear
:
nn
.
init
.
kaiming_normal_
(
m
.
weight
,
a
=
0.0
,
nonlinearity
=
"relu"
,
mode
=
"fan_in"
)
def
forward
(
self
,
input
):
return
self
.
net
(
input
)
class
DownBlock3D
(
nn
.
Module
):
"""A 3D convolutional downsampling block."""
def
__init__
(
self
,
in_channels
,
out_channels
,
norm
=
nn
.
BatchNorm3d
):
super
().
__init__
()
self
.
net
=
[
nn
.
ReplicationPad3d
(
1
),
nn
.
Conv3d
(
in_channels
,
out_channels
,
kernel_size
=
4
,
padding
=
0
,
stride
=
2
,
bias
=
False
if
norm
is
not
None
else
True
,
),
]
if
norm
is
not
None
:
self
.
net
+=
[
norm
(
out_channels
,
affine
=
True
)]
self
.
net
+=
[
nn
.
LeakyReLU
(
0.2
,
True
)]
self
.
net
=
nn
.
Sequential
(
*
self
.
net
)
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
class
UpBlock3D
(
nn
.
Module
):
"""A 3D convolutional upsampling block."""
def
__init__
(
self
,
in_channels
,
out_channels
,
norm
=
nn
.
BatchNorm3d
):
super
().
__init__
()
self
.
net
=
[
nn
.
ConvTranspose3d
(
in_channels
,
out_channels
,
kernel_size
=
4
,
stride
=
2
,
padding
=
1
,
bias
=
False
if
norm
is
not
None
else
True
,
),
]
if
norm
is
not
None
:
self
.
net
+=
[
norm
(
out_channels
,
affine
=
True
)]
self
.
net
+=
[
nn
.
ReLU
(
True
)]
self
.
net
=
nn
.
Sequential
(
*
self
.
net
)
def
forward
(
self
,
x
,
skipped
=
None
):
if
skipped
is
not
None
:
input
=
torch
.
cat
([
skipped
,
x
],
dim
=
1
)
else
:
input
=
x
return
self
.
net
(
input
)
class
Conv3dSame
(
torch
.
nn
.
Module
):
"""3D convolution that pads to keep spatial dimensions equal.
Cannot deal with stride. Only quadratic kernels (=scalar kernel_size).
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
bias
=
True
,
padding_layer
=
nn
.
ReplicationPad3d
,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param kernel_size: Scalar. Spatial dimensions of kernel (only quadratic kernels supported).
:param bias: Whether or not to use bias.
:param padding_layer: Which padding to use. Default is reflection padding.
"""
super
().
__init__
()
ka
=
kernel_size
//
2
kb
=
ka
-
1
if
kernel_size
%
2
==
0
else
ka
self
.
net
=
nn
.
Sequential
(
padding_layer
((
ka
,
kb
,
ka
,
kb
,
ka
,
kb
)),
nn
.
Conv3d
(
in_channels
,
out_channels
,
kernel_size
,
bias
=
bias
,
stride
=
1
),
)
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
class
Conv2dSame
(
torch
.
nn
.
Module
):
"""2D convolution that pads to keep spatial dimensions equal.
Cannot deal with stride. Only quadratic kernels (=scalar kernel_size).
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
bias
=
True
,
padding_layer
=
nn
.
ReflectionPad2d
,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param kernel_size: Scalar. Spatial dimensions of kernel (only quadratic kernels supported).
:param bias: Whether or not to use bias.
:param padding_layer: Which padding to use. Default is reflection padding.
"""
super
().
__init__
()
ka
=
kernel_size
//
2
kb
=
ka
-
1
if
kernel_size
%
2
==
0
else
ka
self
.
net
=
nn
.
Sequential
(
padding_layer
((
ka
,
kb
,
ka
,
kb
)),
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
bias
=
bias
,
stride
=
1
),
)
self
.
weight
=
self
.
net
[
1
].
weight
self
.
bias
=
self
.
net
[
1
].
bias
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
class
UpBlock
(
nn
.
Module
):
"""A 2d-conv upsampling block with a variety of options for upsampling, and following best practices / with
reasonable defaults. (LeakyReLU, kernel size multiple of stride)
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
post_conv
=
True
,
use_dropout
=
False
,
dropout_prob
=
0.1
,
norm
=
nn
.
BatchNorm2d
,
upsampling_mode
=
"transpose"
,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param post_conv: Whether to have another convolutional layer after the upsampling layer.
:param use_dropout: bool. Whether to use dropout or not.
:param dropout_prob: Float. The dropout probability (if use_dropout is True)
:param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity.
:param upsampling_mode: Which upsampling mode:
transpose: Upsampling with stride-2, kernel size 4 transpose convolutions.
bilinear: Feature map is upsampled with bilinear upsampling, then a conv layer.
nearest: Feature map is upsampled with nearest neighbor upsampling, then a conv layer.
shuffle: Feature map is upsampled with pixel shuffling, then a conv layer.
"""
super
().
__init__
()
net
=
list
()
if
upsampling_mode
==
"transpose"
:
net
+=
[
nn
.
ConvTranspose2d
(
in_channels
,
out_channels
,
kernel_size
=
4
,
stride
=
2
,
padding
=
1
,
bias
=
True
if
norm
is
None
else
False
,
)
]
elif
upsampling_mode
==
"bilinear"
:
net
+=
[
nn
.
UpsamplingBilinear2d
(
scale_factor
=
2
)]
net
+=
[
Conv2dSame
(
in_channels
,
out_channels
,
kernel_size
=
3
,
bias
=
True
if
norm
is
None
else
False
,
)
]
elif
upsampling_mode
==
"nearest"
:
net
+=
[
nn
.
UpsamplingNearest2d
(
scale_factor
=
2
)]
net
+=
[
Conv2dSame
(
in_channels
,
out_channels
,
kernel_size
=
3
,
bias
=
True
if
norm
is
None
else
False
,
)
]
elif
upsampling_mode
==
"shuffle"
:
net
+=
[
nn
.
PixelShuffle
(
upscale_factor
=
2
)]
net
+=
[
Conv2dSame
(
in_channels
//
4
,
out_channels
,
kernel_size
=
3
,
bias
=
True
if
norm
is
None
else
False
,
)
]
else
:
raise
ValueError
(
"Unknown upsampling mode!"
)
if
norm
is
not
None
:
net
+=
[
norm
(
out_channels
,
affine
=
True
)]
net
+=
[
nn
.
ReLU
(
True
)]
if
use_dropout
:
net
+=
[
nn
.
Dropout2d
(
dropout_prob
,
False
)]
if
post_conv
:
net
+=
[
Conv2dSame
(
out_channels
,
out_channels
,
kernel_size
=
3
,
bias
=
True
if
norm
is
None
else
False
,
)
]
if
norm
is
not
None
:
net
+=
[
norm
(
out_channels
,
affine
=
True
)]
net
+=
[
nn
.
ReLU
(
True
)]
if
use_dropout
:
net
+=
[
nn
.
Dropout2d
(
0.1
,
False
)]
self
.
net
=
nn
.
Sequential
(
*
net
)
def
forward
(
self
,
x
,
skipped
=
None
):
if
skipped
is
not
None
:
input
=
torch
.
cat
([
skipped
,
x
],
dim
=
1
)
else
:
input
=
x
return
self
.
net
(
input
)
class
DownBlock
(
nn
.
Module
):
"""A 2D-conv downsampling block following best practices / with reasonable defaults
(LeakyReLU, kernel size multiple of stride)
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
prep_conv
=
True
,
middle_channels
=
None
,
use_dropout
=
False
,
dropout_prob
=
0.1
,
norm
=
nn
.
BatchNorm2d
,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param prep_conv: Whether to have another convolutional layer before the downsampling layer.
:param middle_channels: If prep_conv is true, this sets the number of channels between the prep and downsampling
convs.
:param use_dropout: bool. Whether to use dropout or not.
:param dropout_prob: Float. The dropout probability (if use_dropout is True)
:param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity.
"""
super
().
__init__
()
if
middle_channels
is
None
:
middle_channels
=
in_channels
net
=
list
()
if
prep_conv
:
net
+=
[
nn
.
ReflectionPad2d
(
1
),
nn
.
Conv2d
(
in_channels
,
middle_channels
,
kernel_size
=
3
,
padding
=
0
,
stride
=
1
,
bias
=
True
if
norm
is
None
else
False
,
),
]
if
norm
is
not
None
:
net
+=
[
norm
(
middle_channels
,
affine
=
True
)]
net
+=
[
nn
.
LeakyReLU
(
0.2
,
True
)]
if
use_dropout
:
net
+=
[
nn
.
Dropout2d
(
dropout_prob
,
False
)]
net
+=
[
nn
.
ReflectionPad2d
(
1
),
nn
.
Conv2d
(
middle_channels
,
out_channels
,
kernel_size
=
4
,
padding
=
0
,
stride
=
2
,
bias
=
True
if
norm
is
None
else
False
,
),
]
if
norm
is
not
None
:
net
+=
[
norm
(
out_channels
,
affine
=
True
)]
net
+=
[
nn
.
LeakyReLU
(
0.2
,
True
)]
if
use_dropout
:
net
+=
[
nn
.
Dropout2d
(
dropout_prob
,
False
)]
self
.
net
=
nn
.
Sequential
(
*
net
)
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
class Unet3d(nn.Module):
    """A 3d-Unet implementation with sane defaults."""

    def __init__(
        self,
        in_channels,
        out_channels,
        nf0,
        num_down,
        max_channels,
        norm=nn.BatchNorm3d,
        outermost_linear=False,
    ):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param nf0: Number of features at highest level of U-Net
        :param num_down: Number of downsampling stages.
        :param max_channels: Maximum number of channels (channels multiply by 2 with every downsampling stage)
        :param norm: Which norm to use. If None, no norm is used. Default is BatchNorm with affine parameters.
        :param outermost_linear: Whether the output layer should be a linear layer or a nonlinear one.
        """
        super().__init__()

        assert num_down > 0, "Need at least one downsampling layer in UNet3d."

        # Define the in block
        self.in_layer = [Conv3dSame(in_channels, nf0, kernel_size=3, bias=False)]
        if norm is not None:
            self.in_layer += [norm(nf0, affine=True)]
        self.in_layer += [nn.LeakyReLU(0.2, True)]
        self.in_layer = nn.Sequential(*self.in_layer)

        # Define the center UNet block. The feature map has height and width 1 --> no batchnorm.
        self.unet_block = UnetSkipConnectionBlock3d(
            int(min(2 ** (num_down - 1) * nf0, max_channels)),
            int(min(2 ** (num_down - 1) * nf0, max_channels)),
            norm=None,
        )
        for i in list(range(0, num_down - 1))[::-1]:
            self.unet_block = UnetSkipConnectionBlock3d(
                int(min(2 ** i * nf0, max_channels)),
                int(min(2 ** (i + 1) * nf0, max_channels)),
                submodule=self.unet_block,
                norm=norm,
            )

        # Define the out layer. Each unet block concatenates its inputs with its outputs - so the output layer
        # automatically receives the output of the in_layer and the output of the last unet layer.
        self.out_layer = [
            Conv3dSame(2 * nf0, out_channels, kernel_size=3, bias=outermost_linear)
        ]

        if not outermost_linear:
            if norm is not None:
                self.out_layer += [norm(out_channels, affine=True)]
            self.out_layer += [nn.ReLU(True)]
        self.out_layer = nn.Sequential(*self.out_layer)

    def forward(self, x):
        in_layer = self.in_layer(x)
        unet = self.unet_block(in_layer)
        out_layer = self.out_layer(unet)
        return out_layer
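A hedged usage sketch (not part of the committed file), assuming the DownBlock3D / UpBlock3D helpers defined earlier in this module halve and double the spatial size respectively; the input volume side should be divisible by 2 ** num_down:

import torch
from pytorch3d.implicitron.third_party.pytorch_prototyping import Unet3d

net = Unet3d(in_channels=4, out_channels=8, nf0=16, num_down=3, max_channels=128)  # hypothetical sizes
vol = net(torch.randn(1, 4, 32, 32, 32))  # expected shape (1, 8, 32, 32, 32)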
class UnetSkipConnectionBlock3d(nn.Module):
    """Helper class for building a 3D unet."""

    def __init__(self, outer_nc, inner_nc, norm=nn.BatchNorm3d, submodule=None):
        super().__init__()

        if submodule is None:
            model = [
                DownBlock3D(outer_nc, inner_nc, norm=norm),
                UpBlock3D(inner_nc, outer_nc, norm=norm),
            ]
        else:
            model = [
                DownBlock3D(outer_nc, inner_nc, norm=norm),
                submodule,
                UpBlock3D(2 * inner_nc, outer_nc, norm=norm),
            ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        forward_passed = self.model(x)
        return torch.cat([x, forward_passed], 1)
class UnetSkipConnectionBlock(nn.Module):
    """Helper class for building a 2D unet."""

    def __init__(
        self,
        outer_nc,
        inner_nc,
        upsampling_mode,
        norm=nn.BatchNorm2d,
        submodule=None,
        use_dropout=False,
        dropout_prob=0.1,
    ):
        super().__init__()

        if submodule is None:
            model = [
                DownBlock(
                    outer_nc,
                    inner_nc,
                    use_dropout=use_dropout,
                    dropout_prob=dropout_prob,
                    norm=norm,
                ),
                UpBlock(
                    inner_nc,
                    outer_nc,
                    use_dropout=use_dropout,
                    dropout_prob=dropout_prob,
                    norm=norm,
                    upsampling_mode=upsampling_mode,
                ),
            ]
        else:
            model = [
                DownBlock(
                    outer_nc,
                    inner_nc,
                    use_dropout=use_dropout,
                    dropout_prob=dropout_prob,
                    norm=norm,
                ),
                submodule,
                UpBlock(
                    2 * inner_nc,
                    outer_nc,
                    use_dropout=use_dropout,
                    dropout_prob=dropout_prob,
                    norm=norm,
                    upsampling_mode=upsampling_mode,
                ),
            ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        forward_passed = self.model(x)
        return torch.cat([x, forward_passed], 1)
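Note (added, not part of the committed file): because forward returns the input concatenated with the processed output, a block built for outer_nc input channels emits 2 * outer_nc channels; this is why the enclosing networks size their output layers with 2 * nf0 inputs. A small, hypothetical sketch:

import torch
from pytorch3d.implicitron.third_party.pytorch_prototyping import UnetSkipConnectionBlock

blk = UnetSkipConnectionBlock(outer_nc=32, inner_nc=64, upsampling_mode="transpose")
y = blk(torch.randn(1, 32, 16, 16))  # expected shape (1, 64, 16, 16): 32 input + 32 processed channels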
class Unet(nn.Module):
    """A 2d-Unet implementation with sane defaults."""

    def __init__(
        self,
        in_channels,
        out_channels,
        nf0,
        num_down,
        max_channels,
        use_dropout,
        upsampling_mode="transpose",
        dropout_prob=0.1,
        norm=nn.BatchNorm2d,
        outermost_linear=False,
    ):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param nf0: Number of features at highest level of U-Net
        :param num_down: Number of downsampling stages.
        :param max_channels: Maximum number of channels (channels multiply by 2 with every downsampling stage)
        :param use_dropout: Whether to use dropout or not.
        :param dropout_prob: Dropout probability if use_dropout=True.
        :param upsampling_mode: Which type of upsampling should be used. See "UpBlock" for documentation.
        :param norm: Which norm to use. If None, no norm is used. Default is BatchNorm with affine parameters.
        :param outermost_linear: Whether the output layer should be a linear layer or a nonlinear one.
        """
        super().__init__()

        assert num_down > 0, "Need at least one downsampling layer in UNet."

        # Define the in block
        self.in_layer = [
            Conv2dSame(
                in_channels, nf0, kernel_size=3, bias=True if norm is None else False
            )
        ]
        if norm is not None:
            self.in_layer += [norm(nf0, affine=True)]
        self.in_layer += [nn.LeakyReLU(0.2, True)]

        if use_dropout:
            self.in_layer += [nn.Dropout2d(dropout_prob)]
        self.in_layer = nn.Sequential(*self.in_layer)

        # Define the center UNet block
        self.unet_block = UnetSkipConnectionBlock(
            min(2 ** (num_down - 1) * nf0, max_channels),
            min(2 ** (num_down - 1) * nf0, max_channels),
            use_dropout=use_dropout,
            dropout_prob=dropout_prob,
            norm=None,  # Innermost has no norm (spatial dimension 1)
            upsampling_mode=upsampling_mode,
        )
        for i in list(range(0, num_down - 1))[::-1]:
            self.unet_block = UnetSkipConnectionBlock(
                min(2 ** i * nf0, max_channels),
                min(2 ** (i + 1) * nf0, max_channels),
                use_dropout=use_dropout,
                dropout_prob=dropout_prob,
                submodule=self.unet_block,
                norm=norm,
                upsampling_mode=upsampling_mode,
            )

        # Define the out layer. Each unet block concatenates its inputs with its outputs - so the output layer
        # automatically receives the output of the in_layer and the output of the last unet layer.
        self.out_layer = [
            Conv2dSame(
                2 * nf0,
                out_channels,
                kernel_size=3,
                bias=outermost_linear or (norm is None),
            )
        ]

        if not outermost_linear:
            if norm is not None:
                self.out_layer += [norm(out_channels, affine=True)]
            self.out_layer += [nn.ReLU(True)]

            if use_dropout:
                self.out_layer += [nn.Dropout2d(dropout_prob)]
        self.out_layer = nn.Sequential(*self.out_layer)

        self.out_layer_weight = self.out_layer[0].weight

    def forward(self, x):
        in_layer = self.in_layer(x)
        unet = self.unet_block(in_layer)
        out_layer = self.out_layer(unet)
        return out_layer
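A hedged usage sketch (not part of the committed file); input height and width should be divisible by 2 ** num_down, and the values below are hypothetical:

import torch
from pytorch3d.implicitron.third_party.pytorch_prototyping import Unet

unet = Unet(
    in_channels=3,
    out_channels=3,
    nf0=32,
    num_down=4,
    max_channels=256,
    use_dropout=False,
    outermost_linear=True,
)
out = unet(torch.randn(1, 3, 128, 128))  # expected shape (1, 3, 128, 128)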
class Identity(nn.Module):
    """Helper module to allow Downsampling and Upsampling nets to default to identity if they receive an empty list."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input
class DownsamplingNet(nn.Module):
    """A subnetwork that downsamples a 2D feature map with strided convolutions."""

    def __init__(
        self,
        per_layer_out_ch,
        in_channels,
        use_dropout,
        dropout_prob=0.1,
        last_layer_one=False,
        norm=nn.BatchNorm2d,
    ):
        """
        :param per_layer_out_ch: python list of integers. Defines the number of output channels per layer. Length of
                                 list defines number of downsampling steps (each step downsamples by factor of 2.)
        :param in_channels: Number of input channels.
        :param use_dropout: Whether or not to use dropout.
        :param dropout_prob: Dropout probability.
        :param last_layer_one: Whether the output of the last layer will have a spatial size of 1. In that case,
                               the last layer will not have batchnorm, else, it will.
        :param norm: Which norm to use. Defaults to BatchNorm.
        """
        super().__init__()

        if not len(per_layer_out_ch):
            self.downs = Identity()
        else:
            self.downs = list()
            self.downs.append(
                DownBlock(
                    in_channels,
                    per_layer_out_ch[0],
                    use_dropout=use_dropout,
                    dropout_prob=dropout_prob,
                    middle_channels=per_layer_out_ch[0],
                    norm=norm,
                )
            )
            for i in range(0, len(per_layer_out_ch) - 1):
                if last_layer_one and (i == len(per_layer_out_ch) - 2):
                    norm = None
                self.downs.append(
                    DownBlock(
                        per_layer_out_ch[i],
                        per_layer_out_ch[i + 1],
                        dropout_prob=dropout_prob,
                        use_dropout=use_dropout,
                        norm=norm,
                    )
                )
            self.downs = nn.Sequential(*self.downs)

    def forward(self, input):
        return self.downs(input)
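A small usage sketch (added, hypothetical values): each entry in per_layer_out_ch adds one DownBlock, so three entries reduce the spatial size by a factor of 8:

import torch
from pytorch3d.implicitron.third_party.pytorch_prototyping import DownsamplingNet

downs = DownsamplingNet([32, 64, 128], in_channels=3, use_dropout=False)
y = downs(torch.randn(1, 3, 64, 64))  # expected shape (1, 128, 8, 8)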
class UpsamplingNet(nn.Module):
    """A subnetwork that upsamples a 2D feature map with a variety of upsampling options."""

    def __init__(
        self,
        per_layer_out_ch,
        in_channels,
        upsampling_mode,
        use_dropout,
        dropout_prob=0.1,
        first_layer_one=False,
        norm=nn.BatchNorm2d,
    ):
        """
        :param per_layer_out_ch: python list of integers. Defines the number of output channels per layer. Length of
                                 list defines number of upsampling steps (each step upsamples by factor of 2.)
        :param in_channels: Number of input channels.
        :param upsampling_mode: Mode of upsampling. For documentation, see class "UpBlock"
        :param use_dropout: Whether or not to use dropout.
        :param dropout_prob: Dropout probability.
        :param first_layer_one: Whether the input to the first layer will have a spatial size of 1. In that case,
                                the first layer will not have a norm, else, it will.
        :param norm: Which norm to use. Defaults to BatchNorm.
        """
        super().__init__()

        if not len(per_layer_out_ch):
            self.ups = Identity()
        else:
            self.ups = list()
            self.ups.append(
                UpBlock(
                    in_channels,
                    per_layer_out_ch[0],
                    use_dropout=use_dropout,
                    dropout_prob=dropout_prob,
                    norm=None if first_layer_one else norm,
                    upsampling_mode=upsampling_mode,
                )
            )
            for i in range(0, len(per_layer_out_ch) - 1):
                self.ups.append(
                    UpBlock(
                        per_layer_out_ch[i],
                        per_layer_out_ch[i + 1],
                        use_dropout=use_dropout,
                        dropout_prob=dropout_prob,
                        norm=norm,
                        upsampling_mode=upsampling_mode,
                    )
                )
            self.ups = nn.Sequential(*self.ups)

    def forward(self, input):
        return self.ups(input)
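And the mirror-image sketch for UpsamplingNet (added, hypothetical values), assuming each UpBlock doubles the spatial size:

import torch
from pytorch3d.implicitron.third_party.pytorch_prototyping import UpsamplingNet

ups = UpsamplingNet([64, 32], in_channels=128, upsampling_mode="transpose", use_dropout=False)
y = ups(torch.randn(1, 128, 8, 8))  # expected shape (1, 32, 32, 32)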