Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
pytorch3d
Commits
cdd2142d
Unverified
Commit
cdd2142d
authored
Mar 21, 2022
by
Jeremy Reizenstein
Committed by
GitHub
Mar 21, 2022
Browse files
implicitron v0 (#1133)
Co-authored-by:
Jeremy Francis Reizenstein
<
bottler@users.noreply.github.com
>
parent
0e377c68
Changes
90
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3732 additions
and
0 deletions
+3732
-0
pytorch3d/implicitron/tools/__init__.py
pytorch3d/implicitron/tools/__init__.py
+0
-0
pytorch3d/implicitron/tools/camera_utils.py
pytorch3d/implicitron/tools/camera_utils.py
+142
-0
pytorch3d/implicitron/tools/circle_fitting.py
pytorch3d/implicitron/tools/circle_fitting.py
+231
-0
pytorch3d/implicitron/tools/config.py
pytorch3d/implicitron/tools/config.py
+714
-0
pytorch3d/implicitron/tools/depth_cleanup.py
pytorch3d/implicitron/tools/depth_cleanup.py
+113
-0
pytorch3d/implicitron/tools/eval_video_trajectory.py
pytorch3d/implicitron/tools/eval_video_trajectory.py
+226
-0
pytorch3d/implicitron/tools/image_utils.py
pytorch3d/implicitron/tools/image_utils.py
+53
-0
pytorch3d/implicitron/tools/metric_utils.py
pytorch3d/implicitron/tools/metric_utils.py
+231
-0
pytorch3d/implicitron/tools/model_io.py
pytorch3d/implicitron/tools/model_io.py
+163
-0
pytorch3d/implicitron/tools/point_cloud_utils.py
pytorch3d/implicitron/tools/point_cloud_utils.py
+168
-0
pytorch3d/implicitron/tools/rasterize_mc.py
pytorch3d/implicitron/tools/rasterize_mc.py
+63
-0
pytorch3d/implicitron/tools/stats.py
pytorch3d/implicitron/tools/stats.py
+491
-0
pytorch3d/implicitron/tools/utils.py
pytorch3d/implicitron/tools/utils.py
+183
-0
pytorch3d/implicitron/tools/video_writer.py
pytorch3d/implicitron/tools/video_writer.py
+149
-0
pytorch3d/implicitron/tools/vis_utils.py
pytorch3d/implicitron/tools/vis_utils.py
+172
-0
tests/implicitron/__init__.py
tests/implicitron/__init__.py
+5
-0
tests/implicitron/common_resources.py
tests/implicitron/common_resources.py
+114
-0
tests/implicitron/data/overrides.yaml
tests/implicitron/data/overrides.yaml
+122
-0
tests/implicitron/test_batch_sampler.py
tests/implicitron/test_batch_sampler.py
+215
-0
tests/implicitron/test_circle_fitting.py
tests/implicitron/test_circle_fitting.py
+177
-0
No files found.
pytorch3d/implicitron/tools/__init__.py
0 → 100644
View file @
cdd2142d
pytorch3d/implicitron/tools/camera_utils.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# TODO: all this potentially goes to PyTorch3D
import
math
from
typing
import
Tuple
import
pytorch3d
as
pt3d
import
torch
from
pytorch3d.renderer.cameras
import
CamerasBase
def jitter_extrinsics(
    R: torch.Tensor,
    T: torch.Tensor,
    max_angle: float = (math.pi * 2.0),
    translation_std: float = 1.0,
    scale_std: float = 0.3,
):
    """
    Jitter the extrinsic camera parameters `R` and `T` with a random
    similarity transformation: a random rotation (whose log-map is scaled
    by `max_angle`), a translation offset drawn from N(0, translation_std),
    and an isotropic rescaling by exp(N(0, scale_std)).

    Args:
        R: Camera rotation matrices of shape (N, 3, 3).
        T: Camera translations of shape (N, 3).
        max_angle: scale applied to the rotation's axis-angle representation.
        translation_std: std-dev of the random translation offset.
        scale_std: std-dev of the log of the random scale factor.

    Returns:
        The jittered (R, T) pair as produced by `apply_camera_alignment`.
    """
    assert all(p >= 0.0 for p in (max_angle, translation_std, scale_std))
    batch = R.shape[0]
    # Draw one random rotation shared across the whole batch.
    rand_rot = pt3d.transforms.random_rotations(1, device=R.device)
    # Rescale its axis-angle (log-map) magnitude by max_angle.
    rand_rot = pt3d.transforms.so3_exponential_map(
        pt3d.transforms.so3_log_map(rand_rot) * max_angle
    )
    rand_trans = torch.randn_like(rand_rot[:1, :, 0]) * translation_std
    # Assemble a single 4x4 rigid transform, broadcast to the batch.
    similarity = pt3d.ops.eyes(dim=4, N=batch, device=R.device)
    similarity[:, :3, :3] = rand_rot.expand(batch, 3, 3)
    similarity[:, 3, :3] = rand_trans.expand(batch, 3)
    rand_scale = torch.exp(torch.randn_like(rand_trans[:, 0]) * scale_std).expand(batch)
    return apply_camera_alignment(R, T, similarity, rand_scale)
def apply_camera_alignment(
    R: torch.Tensor,
    T: torch.Tensor,
    rigid_transform: torch.Tensor,
    scale: torch.Tensor,
):
    """
    Align camera extrinsics with a batch of rigid+scale transformations.

    Args:
        R: Camera rotation matrix of shape (N, 3, 3).
        T: Camera translation of shape (N, 3).
        rigid_transform: A tensor of shape (N, 4, 4) representing a batch of
            N 4x4 tensors that map the scene pointcloud from misaligned coords
            to the aligned space.
        scale: A tensor of shape (N,) of per-camera scaling factors.

    Returns:
        R_aligned: The aligned rotations R.
        T_aligned: The aligned translations T.
    """
    # Split each 4x4 transform into its rotation block and translation row.
    rot_part = rigid_transform[:, :3, :3]  # (N, 3, 3)
    trans_part = rigid_transform[:, 3:, :3]  # (N, 1, 3)
    # Rotate the cameras by the inverse (transpose) of the rigid rotation.
    R_aligned = torch.bmm(rot_part.transpose(1, 2), R)
    # Shift the translations by the rotated rigid offset and rescale.
    T_aligned = scale[:, None] * (T - torch.bmm(trans_part, R_aligned)[:, 0])
    return R_aligned, T_aligned
def get_min_max_depth_bounds(cameras, scene_center, scene_extent):
    """
    Estimate near/far depth plane as:
        near = dist(cam_center, scene_center) - scene_extent
        far  = dist(cam_center, scene_center) + scene_extent

    Args:
        cameras: camera batch providing `get_camera_center()` and `R`.
        scene_center: (3,) tensor, center of the scene.
        scene_extent: scalar half-extent of the scene.

    Returns:
        (min_depth, max_depth) tensors, one value per camera.
    """
    cam_center = cameras.get_camera_center()
    # Squared distance to the scene center, clamped away from zero
    # before the sqrt for numerical stability.
    offset = cam_center - scene_center.to(cameras.R)[None]
    center_dist = offset.pow(2).sum(dim=-1).clamp(0.001).sqrt()
    # Keep the near plane strictly positive.
    center_dist = center_dist.clamp(scene_extent + 1e-3)
    return center_dist - scene_extent, center_dist + scene_extent
def volumetric_camera_overlaps(
    cameras: CamerasBase,
    scene_extent: float = 8.0,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    resol: int = 16,
    weigh_by_ray_angle: bool = True,
):
    """
    Compute the overlaps between viewing frustrums of all pairs of cameras
    in `cameras`.

    Args:
        cameras: a batch of N cameras.
        scene_extent: half-side of the cubic probing volume.
        scene_center: world-space center of the probing volume.
        resol: number of voxels along each side of the probing grid.
        weigh_by_ray_angle: if True, weigh each voxel's contribution by the
            dot product of the two cameras' viewing directions at that voxel.

    Returns:
        iou: (N, N) tensor of pairwise frustum-overlap scores.
    """
    device = cameras.device
    ba = cameras.R.shape[0]
    n_vox = int(resol**3)
    # Regular grid of world-space points covering the scene cube.
    grid = pt3d.structures.Volumes(
        densities=torch.zeros([1, 1, resol, resol, resol], device=device),
        volume_translation=-torch.FloatTensor(scene_center)[None].to(device),
        voxel_size=2.0 * scene_extent / resol,
    ).get_coord_grid(world_coordinates=True)
    grid = grid.view(1, n_vox, 3).expand(ba, n_vox, 3)
    gridp = cameras.transform_points(grid, eps=1e-2)
    # A point counts as inside a frustum when its projection falls within
    # [-1, 1] x [-1, 1] (assumes NDC image coordinates -- TODO confirm)
    # and it lies in front of the camera (positive depth).
    proj_in_camera = (
        torch.prod((gridp[..., :2].abs() <= 1.0), dim=-1)
        * (gridp[..., 2] > 0.0).float()
    )  # ba x n_vox
    if weigh_by_ray_angle:
        # Unit rays from each camera center to every grid point,
        # zeroed for points outside the frustum.
        rays = torch.nn.functional.normalize(
            grid - cameras.get_camera_center()[:, None], dim=-1
        )
        rays_masked = rays * proj_in_camera[..., None]
        # - slow and readable:
        # inter = torch.zeros(ba, ba)
        # for i1 in range(ba):
        #     for i2 in range(ba):
        #         inter[i1, i2] = (
        #             1 + (rays_masked[i1] * rays_masked[i2]).sum(dim=-1)
        #         ).sum()
        # - fast: one matmul computes all pairwise sums of ray dot products.
        rays_masked = rays_masked.view(ba, n_vox * 3)
        inter = n_vox + (rays_masked @ rays_masked.t())
    else:
        # Unweighted: count voxels visible in both cameras.
        inter = proj_in_camera @ proj_in_camera.t()
    # Intersection-over-union; diag(inter) is each camera's own mass.
    # The clamp guards against division by ~zero for disjoint frustums.
    mass = torch.diag(inter)
    iou = inter / (mass[:, None] + mass[None, :] - inter).clamp(0.1)
    return iou
pytorch3d/implicitron/tools/circle_fitting.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
warnings
from
dataclasses
import
dataclass
from
math
import
pi
from
typing
import
Optional
import
torch
from
pytorch3d.common.compat
import
eigh
,
lstsq
def _get_rotation_to_best_fit_xy(
    points: torch.Tensor, centroid: torch.Tensor
) -> torch.Tensor:
    """
    Returns a rotation r such that points @ r has a best fit plane
    parallel to the xy plane

    Args:
        points: (N, 3) tensor of points in 3D
        centroid: (3,) their centroid

    Returns:
        (3,3) tensor rotation matrix
    """
    centered = points - centroid[None]
    # Eigenvectors of the 3x3 covariance, in ascending eigenvalue order.
    eigenvectors = eigh(centered.t() @ centered)[1]
    # Reorder columns so the least-variance direction (the plane normal)
    # becomes the z axis of the rotated frame.
    return eigenvectors[:, [1, 2, 0]]
def
_signed_area
(
path
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Calculates the signed area / Lévy area of a 2D path. If the path is closed,
i.e. ends where it starts, this is the integral of the winding number over
the whole plane. If not, consider a closed path made by adding a straight
line from the end to the start; the signed area is the integral of the
winding number (also over the plane) with respect to that closed path.
If this number is positive, it indicates in some sense that the path
turns anticlockwise more than clockwise, and vice versa.
Args:
path: N x 2 tensor of points.
Returns:
signed area, shape ()
"""
# This calculation is a sum of areas of triangles of the form
# (path[0], path[i], path[i+1]), where each triangle is half a
# parallelogram.
x
,
y
=
(
path
[
1
:]
-
path
[:
1
]).
unbind
(
1
)
return
(
y
[
1
:]
*
x
[:
-
1
]
-
x
[
1
:]
*
y
[:
-
1
]).
sum
()
*
0.5
@dataclass(frozen=True)
class Circle2D:
    """
    Contains details of a circle in a plane.

    Members
        center: tensor shape (2,)
        radius: tensor shape ()
        generated_points: points around the circle, shape (n_points, 2)
    """

    # (2,) circle center
    center: torch.Tensor
    # () scalar radius
    radius: torch.Tensor
    # (n_points, 2) sampled points on the circle (may be empty)
    generated_points: torch.Tensor
def fit_circle_in_2d(
    points2d, *, n_points: int = 0, angles: Optional[torch.Tensor] = None
) -> Circle2D:
    """
    Simple best fitting of a circle to 2D points. In particular, the circle which
    minimizes the sum of the squares of the squared-distances to the circle.

    Finds (a,b) and r to minimize the sum of squares (over the x,y pairs) of
        r**2 - [(x-a)**2+(y-b)**2]
    i.e.
        (2*a)*x + (2*b)*y + (r**2 - a**2 - b**2)*1 - (x**2 + y**2)

    In addition, generates points along the circle. If angles is None (default)
    then n_points around the circle equally spaced are given. These begin at the
    point closest to the first input point. They continue in the direction which
    seems to match the movement of points in points2d, as judged by its
    signed area. If `angles` are provided, then n_points is ignored, and points
    along the circle at the given angles are returned, with the starting point
    and direction as before.

    (Note that `generated_points` is affected by the order of the points in
    points2d, but the other outputs are not.)

    Args:
        points2d: N x 2 tensor of 2D points
        n_points: number of points to generate on the circle, if angles not given
        angles: optional angles in radians of points to generate.

    Returns:
        Circle2D object
    """
    # Linear least squares for (2a, 2b, r^2 - a^2 - b^2):
    # design matrix rows are [x, y, 1], targets are x^2 + y^2.
    design = torch.cat([points2d, torch.ones_like(points2d[:, :1])], dim=1)
    rhs = (points2d**2).sum(1)
    n_provided = points2d.shape[0]
    if n_provided < 3:
        # A circle has three degrees of freedom.
        raise ValueError(f"{n_provided} points are not enough to determine a circle")
    solution = lstsq(design, rhs)
    center = solution[:2] / 2
    radius = torch.sqrt(solution[2] + (center**2).sum())
    if n_points > 0:
        if angles is not None:
            warnings.warn("n_points ignored because angles provided")
        else:
            # NOTE(review): linspace includes both endpoints, so the first and
            # last generated points coincide -- confirm this is intended.
            angles = torch.linspace(0, 2 * pi, n_points, device=points2d.device)
    if angles is not None:
        # Anchor the generated points at the angle of the first input point.
        initial_direction_xy = (points2d[0] - center).unbind()
        initial_angle = torch.atan2(initial_direction_xy[1], initial_direction_xy[0])
        with torch.no_grad():
            # Traverse in the direction the input points appear to move.
            anticlockwise = _signed_area(points2d) > 0
        if anticlockwise:
            use_angles = initial_angle + angles
        else:
            use_angles = initial_angle - angles
        generated_points = center[None] + radius * torch.stack(
            [torch.cos(use_angles), torch.sin(use_angles)], dim=-1
        )
    else:
        generated_points = points2d.new_zeros(0, 2)
    return Circle2D(center=center, radius=radius, generated_points=generated_points)
@dataclass(frozen=True)
class Circle3D:
    """
    Contains details of a circle in 3D.

    Members
        center: tensor shape (3,)
        radius: tensor shape ()
        normal: tensor shape (3,)
        generated_points: points around the circle, shape (n_points, 3)
    """

    # (3,) circle center in world coordinates
    center: torch.Tensor
    # () scalar radius
    radius: torch.Tensor
    # (3,) unit normal of the plane containing the circle
    normal: torch.Tensor
    # (n_points, 3) sampled points on the circle (may be empty)
    generated_points: torch.Tensor
def fit_circle_in_3d(
    points,
    *,
    n_points: int = 0,
    angles: Optional[torch.Tensor] = None,
    offset: Optional[torch.Tensor] = None,
    up: Optional[torch.Tensor] = None,
) -> Circle3D:
    """
    Simple best fit circle to 3D points. Uses circle_2d in the
    least-squares best fit plane.

    In addition, generates points along the circle. If angles is None (default)
    then n_points around the circle equally spaced are given. These begin at the
    point closest to the first input point. They continue in the direction which
    seems to match the movement of points. If angles is provided, then n_points
    is ignored, and points along the circle at the given angles are returned,
    with the starting point and direction as before.

    Further, an offset can be given to add to the generated points; this is
    interpreted in a rotated coordinate system where (0, 0, 1) is normal to the
    circle, specifically the normal which is approximately in the direction of a
    given `up` vector. The remaining rotation is disambiguated in an unspecified
    but deterministic way.

    (Note that `generated_points` is affected by the order of the points in
    points, but the other outputs are not.)

    Args:
        points: N x 3 tensor of 3D points
        n_points: number of points to generate on the circle
        angles: optional angles in radians of points to generate.
        offset: optional tensor (3,), a displacement expressed in a "canonical"
            coordinate system to add to the generated points.
        up: optional tensor (3,), a vector which helps define the
            "canonical" coordinate system for interpreting `offset`.
            Required if offset is used.

    Returns:
        Circle3D object
    """
    centroid = points.mean(0)
    # Rotation taking the best-fit plane of the points to the xy plane.
    r = _get_rotation_to_best_fit_xy(points, centroid)
    normal = r[:, 2]
    rotated_points = (points - centroid) @ r
    # Fit in 2D within the best-fit plane.
    result_2d = fit_circle_in_2d(
        rotated_points[:, :2], n_points=n_points, angles=angles
    )
    # Map the 2D center back to world coordinates.
    center_3d = result_2d.center @ r[:, :2].t() + centroid
    n_generated_points = result_2d.generated_points.shape[0]
    if n_generated_points > 0:
        # Lift the generated 2D points into the plane (z = 0).
        generated_points_in_plane = torch.cat(
            [
                result_2d.generated_points,
                torch.zeros_like(result_2d.generated_points[:, :1]),
            ],
            dim=1,
        )
        if offset is not None:
            if up is None:
                raise ValueError("Missing `up` input for interpreting offset")
            with torch.no_grad():
                # Choose the normal direction closest to `up`.
                swap = torch.dot(up, normal) < 0
            if swap:
                # We need some rotation which takes +z to -z. Here's one.
                generated_points_in_plane += offset * offset.new_tensor([1, -1, -1])
            else:
                generated_points_in_plane += offset
        generated_points = generated_points_in_plane @ r.t() + centroid
    else:
        generated_points = points.new_zeros(0, 3)

    return Circle3D(
        radius=result_2d.radius,
        center=center_3d,
        normal=normal,
        generated_points=generated_points,
    )
pytorch3d/implicitron/tools/config.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
copy
import
dataclasses
import
inspect
import
warnings
from
collections
import
Counter
,
defaultdict
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
,
cast
from
omegaconf
import
DictConfig
,
OmegaConf
,
open_dict
"""
This functionality allows a configurable system to be determined in a dataclass-type
way. It is a generalization of omegaconf's "structured", in the dataclass case.
Core functionality:
- Configurable -- A base class used to label a class as being one which uses this
system. Uses class members and __post_init__ like a dataclass.
- expand_args_fields -- Expands a class like `dataclasses.dataclass`. Runs automatically.
- get_default_args -- gets an omegaconf.DictConfig for initializing
a given class or calling a given function.
- run_auto_creation -- Initialises nested members. To be called in __post_init__.
In addition, a Configurable may contain members whose type is decided at runtime.
- ReplaceableBase -- As a base instead of Configurable, labels a class to say that
any child class can be used instead.
- registry -- A global store of named child classes of ReplaceableBase classes.
Used as `@registry.register` decorator on class definition.
Additional utility functions:
- remove_unused_components -- used for simplifying a DictConfig instance.
- get_default_args_field -- default for DictConfig member of another configurable.
1. The simplest usage of this functionality is as follows. First a schema is defined
in dataclass style.
class A(Configurable):
n: int = 9
class B(Configurable):
a: A
def __post_init__(self):
run_auto_creation(self)
It can be used like
b_args = get_default_args(B)
b = B(**b_args)
In this case, get_default_args(B) returns an omegaconf.DictConfig with the right
members {"a_args": {"n": 9}}. It also modifies the definitions of the classes to
something like the following. (The modification itself is done by the function
`expand_args_fields`, which is called inside `get_default_args`.)
@dataclasses.dataclass
class A:
n: int = 9
@dataclasses.dataclass
class B:
a_args: DictConfig = dataclasses.field(default_factory=lambda: DictConfig({"n": 9}))
def __post_init__(self):
self.a = A(**self.a_args)
2. Pluggability. Instead of a dataclass-style member being given a concrete class,
you can give a base class and the implementation is looked up by name in the global
`registry` in this module. E.g.
class A(ReplaceableBase):
k: int = 1
@registry.register
class A1(A):
m: int = 3
@registry.register
class A2(A):
n: str = "2"
class B(Configurable):
a: A
a_class_type: str = "A2"
def __post_init__(self):
run_auto_creation(self)
will expand to
@dataclasses.dataclass
class A:
k: int = 1
@dataclasses.dataclass
class A1(A):
m: int = 3
@dataclasses.dataclass
class A2(A):
n: str = "2"
@dataclasses.dataclass
class B:
a_class_type: str = "A2"
a_A1_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3}
)
a_A2_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3}
)
def __post_init__(self):
if self.a_class_type == "A1":
self.a = A1(**self.a_A1_args)
elif self.a_class_type == "A2":
self.a = A2(**self.a_A2_args)
else:
raise ValueError(...)
3. Aside from these classes, the members of these classes should be things
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
can be built from them with DictConfigs and lists of them.
In addition, you can call get_default_args on a function or class to get
the DictConfig of its defaulted arguments, assuming those are all things
which DictConfig is happy with. If you want to use such a thing as a member
of another configured class, `get_default_args_field` is a helper.
"""
# Suffix of a warning message; the offending class's name is prepended when a
# Configurable/ReplaceableBase subclass is instantiated before being expanded.
_unprocessed_warning: str = (
    " must be processed before it can be used."
    + " This is done by calling expand_args_fields "
    + "or get_default_args on it."
)

# Suffix of the member which selects the concrete implementation of a
# pluggable (ReplaceableBase-typed) member, e.g. `x` -> `x_class_type`.
TYPE_SUFFIX: str = "_class_type"
# Suffix of the DictConfig member holding constructor args, e.g. `x` -> `x_args`.
ARGS_SUFFIX: str = "_args"
class ReplaceableBase:
    """
    Base class for dataclass-style classes which
    can be stored in the registry.
    """

    def __new__(cls, *args, **kwargs):
        """
        This function only exists to raise a
        warning if class construction is attempted
        without processing.
        """
        instance = super().__new__(cls)
        # Subclasses must be expanded into dataclasses before instantiation;
        # the base class itself is exempt.
        unprocessed = cls is not ReplaceableBase and not _is_actually_dataclass(cls)
        if unprocessed:
            warnings.warn(cls.__name__ + _unprocessed_warning)
        return instance
class Configurable:
    """
    This indicates a class which is not ReplaceableBase
    but still needs to be
    expanded into a dataclass with expand_args_fields.
    This expansion is delayed.
    """

    def __new__(cls, *args, **kwargs):
        """
        This function only exists to raise a
        warning if class construction is attempted
        without processing.
        """
        instance = super().__new__(cls)
        # Subclasses must be expanded into dataclasses before instantiation;
        # the base class itself is exempt.
        unprocessed = cls is not Configurable and not _is_actually_dataclass(cls)
        if unprocessed:
            warnings.warn(cls.__name__ + _unprocessed_warning)
        return instance
# TypeVar used so registry operations preserve the concrete subclass type.
# NOTE: the string passed to TypeVar must equal the variable's name
# (PEP 484); it was previously "X", which type checkers reject.
_X = TypeVar("_X", bound=ReplaceableBase)
class _Registry:
    """
    Register from names to classes. In particular, we say that direct subclasses of
    ReplaceableBase are "base classes" and we register subclasses of each base class
    in a separate namespace.
    """

    def __init__(self) -> None:
        # base class -> {implementation name -> implementation class}
        self._mapping: Dict[
            Type[ReplaceableBase], Dict[str, Type[ReplaceableBase]]
        ] = defaultdict(dict)

    def register(self, some_class: Type[_X]) -> Type[_X]:
        """
        A class decorator, to register a class in self.
        """
        name = some_class.__name__
        self._register(some_class, name=name)
        # Return the class unchanged so this works as a decorator.
        return some_class

    def _register(
        self,
        some_class: Type[ReplaceableBase],
        *,
        base_class: Optional[Type[ReplaceableBase]] = None,
        name: str,
    ) -> None:
        """
        Register a new member.

        Args:
            some_class: the new member
            base_class: (optional) what the new member is a type for;
                looked up from the class's mro if not given.
            name: name for the new member
        """
        if base_class is None:
            base_class = self._base_class_from_class(some_class)
            if base_class is None:
                raise ValueError(
                    f"Cannot register {some_class}. Cannot tell what it is."
                )
        if some_class is base_class:
            # Base classes are namespaces, not registrable implementations.
            raise ValueError(f"Attempted to register the base class {some_class}")
        self._mapping[base_class][name] = some_class

    def get(
        self, base_class_wanted: Type[ReplaceableBase], name: str
    ) -> Type[ReplaceableBase]:
        """
        Retrieve a class from the registry by name

        Args:
            base_class_wanted: parent type of type we are looking for.
                It determines the namespace.
                This will typically be a direct subclass of ReplaceableBase.
            name: what to look for

        Returns:
            class type
        """
        if self._is_base_class(base_class_wanted):
            base_class = base_class_wanted
        else:
            # Resolve the namespace from an indirect subclass.
            base_class = self._base_class_from_class(base_class_wanted)
            if base_class is None:
                raise ValueError(
                    f"Cannot look up {base_class_wanted}. Cannot tell what it is."
                )
        result = self._mapping[base_class].get(name)
        if result is None:
            raise ValueError(f"{name} has not been registered.")
        if not issubclass(result, base_class_wanted):
            # The name exists in the namespace but is not compatible with
            # the (possibly narrower) type the caller asked for.
            raise ValueError(
                f"{name} resolves to {result} which does not subclass {base_class_wanted}"
            )
        return result

    def get_all(
        self, base_class_wanted: Type[ReplaceableBase]
    ) -> List[Type[ReplaceableBase]]:
        """
        Retrieve all registered implementations from the registry

        Args:
            base_class_wanted: parent type of type we are looking for.
                It determines the namespace.
                This will typically be a direct subclass of ReplaceableBase.
        Returns:
            list of class types
        """
        if self._is_base_class(base_class_wanted):
            return list(self._mapping[base_class_wanted].values())
        base_class = self._base_class_from_class(base_class_wanted)
        if base_class is None:
            raise ValueError(
                f"Cannot look up {base_class_wanted}. Cannot tell what it is."
            )
        # Filter the namespace down to strict subclasses of the wanted type.
        return [
            class_
            for class_ in self._mapping[base_class].values()
            if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted
        ]

    @staticmethod
    def _is_base_class(some_class: Type[ReplaceableBase]) -> bool:
        """
        Return whether the given type is a direct subclass of ReplaceableBase
        and so gets used as a namespace.
        """
        return ReplaceableBase in some_class.__bases__

    @staticmethod
    def _base_class_from_class(
        some_class: Type[ReplaceableBase],
    ) -> Optional[Type[ReplaceableBase]]:
        """
        Find the parent class of some_class which inherits ReplaceableBase, or None
        """
        # mro() ends [..., ReplaceableBase, object]; the slice walks the
        # remaining entries from most general to most specific, so the first
        # match is the direct subclass of ReplaceableBase (the namespace).
        for base in some_class.mro()[-3::-1]:
            if base is not ReplaceableBase and issubclass(base, ReplaceableBase):
                return base
        return None


# Global instance of the registry
registry = _Registry()
def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any], None]:
    """
    Return the default creation function for a member. This is a function which
    could be called in __post_init__ to initialise the member, and will be called
    from run_auto_creation.

    Args:
        name: name of the member
        type_: declared type of the member
        pluggable: True if the member's declared type inherits ReplaceableBase,
                    in which case the actual type to be created is decided at
                    runtime.

    Returns:
        Function taking one argument, the object whose member should be
        initialized.
    """

    def inner(self):
        # Fixed-type member: construct `type_` from the stored `<name>_args`.
        expand_args_fields(type_)
        args = getattr(self, name + ARGS_SUFFIX)
        setattr(self, name, type_(**args))

    def inner_pluggable(self):
        # Pluggable member: the concrete class name is read from
        # `<name>_class_type` and looked up in the registry at creation time.
        type_name = getattr(self, name + TYPE_SUFFIX)
        chosen_class = registry.get(type_, type_name)
        if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
            # If this warning is raised, it means that a new definition of
            # the chosen class has been registered since our class was processed
            # (i.e. expanded). A DictConfig which comes from our get_default_args
            # (which might have triggered the processing) will contain the old default
            # values for the members of the chosen class. Changes to those defaults which
            # were made in the redefinition will not be reflected here.
            warnings.warn(f"New implementation of {type_name} is being chosen.")
        expand_args_fields(chosen_class)
        args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
        setattr(self, name, chosen_class(**args))

    return inner_pluggable if pluggable else inner
def run_auto_creation(self: Any) -> None:
    """
    Run all the functions named in self._creation_functions.

    Intended to be called from __post_init__ of an expanded class; each named
    function initialises one nested member of `self`.
    """
    for fn_name in self._creation_functions:
        creation_fn = getattr(self, fn_name)
        creation_fn()
def _is_configurable_class(C) -> bool:
    """
    Return whether C is a class which subclasses Configurable or
    ReplaceableBase (i.e. participates in this config system).
    """
    if not isinstance(C, type):
        return False
    return issubclass(C, (Configurable, ReplaceableBase))
def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
    """
    Get the DictConfig of args to call C - which might be a type or a function.

    If C is a subclass of Configurable or ReplaceableBase, we make sure
    it has been processed with expand_args_fields. If C is a dataclass,
    including a subclass of Configurable or ReplaceableBase, the output
    will be a typed DictConfig.

    Args:
        C: the class or function to be processed
        _do_not_process: (internal use) When this function is called from
            expand_args_fields, we specify any class currently being
            processed, to make sure we don't try to process a class
            while it is already being processed.

    Returns:
        new DictConfig object
    """
    if C is None:
        return DictConfig({})

    if _is_configurable_class(C):
        if C in _do_not_process:
            raise ValueError(
                f"Internal recursion error. Need processed {C},"
                f" but cannot get it. _do_not_process={_do_not_process}"
            )
        # This is safe to run multiple times. It will return
        # straight away if C has already been processed.
        expand_args_fields(C, _do_not_process=_do_not_process)

    if dataclasses.is_dataclass(C):
        # Note that if get_default_args_field is used somewhere in C,
        # this call is recursive. No special care is needed,
        # because in practice get_default_args_field is used for
        # separate types than the outer type.
        out = OmegaConf.structured(C)
        # Members which expand_args_fields already replaced with *_args /
        # *_class_type fields must not also appear directly.
        exclude = getattr(C, "_processed_members", ())
        with open_dict(out):
            for field in exclude:
                out.pop(field, None)
        return out

    if _is_configurable_class(C):
        raise ValueError(f"Failed to process {C}")

    # C is a plain callable: collect its defaulted keyword arguments.
    kwargs = {}
    sig = inspect.signature(C)
    for pname, defval in sig.parameters.items():
        # Compare the sentinel with `is`: `==` would invoke __eq__ of an
        # arbitrary default value (e.g. a numpy array), which can raise or
        # give a non-bool result.
        if defval.default is inspect.Parameter.empty:
            continue
        # Deep-copy so shared mutable defaults are never aliased in configs.
        kwargs[pname] = copy.deepcopy(defval.default)
    return DictConfig(kwargs)
def
_is_actually_dataclass
(
some_class
)
->
bool
:
# Return whether the class some_class has been processed with
# the dataclass annotation. This is more specific than
# dataclasses.is_dataclass which returns True on anything
# deriving from a dataclass.
# Checking for __init__ would also work for our purpose.
return
"__dataclass_fields__"
in
some_class
.
__dict__
def expand_args_fields(
    some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = ()
) -> Type[_X]:
    """
    This expands a class which inherits Configurable or ReplaceableBase classes,
    including dataclass processing. some_class is modified in place by this function.
    For classes of type ReplaceableBase, you can add some_class to the registry before
    or after calling this function. But potential inner classes need to be registered
    before this function is run on the outer class.

    The transformations this function makes, before the concluding
    dataclasses.dataclass, are as follows. if X is a base class with registered
    subclasses Y and Z, replace

        x: X

    and optionally

        x_class_type: str = "Y"
        def create_x(self):...

    with

        x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
        x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
        def create_x(self):
            self.x = registry.get(X, self.x_class_type)(
                **self.getattr(f"x_{self.x_class_type}_args)
            )
        x_class_type: str = "UNDEFAULTED"

    without adding the optional things if they are already there.

    Similarly, if X is a subclass of Configurable,

        x: X

    and optionally

        def create_x(self):...

    will be replaced with

        x_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
        def create_x(self):
            self.x = X(self.x_args)

    Also adds the following class members, unannotated so that dataclass
    ignores them.
        - _creation_functions: Tuple[str] of all the create_ functions,
            including those from base classes.
        - _known_implementations: Dict[str, Type] containing the classes which
            have been found from the registry.
            (used only to raise a warning if one has been overwritten)
        - _processed_members: a Set[str] of all the members which have been transformed.

    Args:
        some_class: the class to be processed
        _do_not_process: Internal use for get_default_args: Because get_default_args calls
            and is called by this function, we let it specify any class currently
            being processed, to make sure we don't try to process a class while
            it is already being processed.

    Returns:
        some_class itself, which has been modified in place. This
        allows this function to be used as a class decorator.
    """
    if _is_actually_dataclass(some_class):
        # Already processed: this function is idempotent.
        return some_class

    # The functions this class's run_auto_creation will run.
    creation_functions: List[str] = []
    # The classes which this type knows about from the registry
    # We could use a weakref.WeakValueDictionary here which would mean
    # that we don't warn if the class we should have expected is elsewhere
    # unused.
    known_implementations: Dict[str, Type] = {}
    # Names of members which have been processed.
    processed_members: Set[str] = set()

    # For all bases except ReplaceableBase and Configurable and object,
    # we need to process them before our own processing. This is
    # because dataclasses expect to inherit dataclasses and not unprocessed
    # dataclasses.
    # mro()[-3:0:-1]: skip object and the system base at the end, skip
    # some_class itself at index 0, and visit bases most-general first.
    for base in some_class.mro()[-3:0:-1]:
        if base is ReplaceableBase:
            continue
        if base is Configurable:
            continue
        if not issubclass(base, (Configurable, ReplaceableBase)):
            continue
        expand_args_fields(base, _do_not_process=_do_not_process)
        # Inherit bookkeeping from each processed base (own __dict__ only,
        # to avoid double-counting attributes the base itself inherited).
        if "_creation_functions" in base.__dict__:
            creation_functions.extend(base._creation_functions)
        if "_known_implementations" in base.__dict__:
            known_implementations.update(base._known_implementations)
        if "_processed_members" in base.__dict__:
            processed_members.update(base._processed_members)

    # Collect members whose annotated type participates in the config system:
    # (name, declared type, pluggable?).
    to_process: List[Tuple[str, Type, bool]] = []
    if "__annotations__" in some_class.__dict__:
        for name, type_ in some_class.__annotations__.items():
            if not isinstance(type_, type):
                # type_ could be something like typing.Tuple
                continue
            if (
                issubclass(type_, ReplaceableBase)
                and ReplaceableBase in type_.__bases__
            ):
                # Declared type is a registry namespace: pluggable member.
                to_process.append((name, type_, True))
            elif issubclass(type_, Configurable):
                to_process.append((name, type_, False))

    for name, type_, pluggable in to_process:
        _process_member(
            name=name,
            type_=type_,
            pluggable=pluggable,
            some_class=cast(type, some_class),
            creation_functions=creation_functions,
            _do_not_process=_do_not_process,
            known_implementations=known_implementations,
        )
        processed_members.add(name)

    # Warn if two bases (or a base and this class) define the same creator.
    for key, count in Counter(creation_functions).items():
        if count > 1:
            warnings.warn(f"Clash with {key} in a base class.")
    some_class._creation_functions = tuple(creation_functions)
    some_class._processed_members = processed_members
    some_class._known_implementations = known_implementations

    # Finally make some_class a real dataclass, in place.
    dataclasses.dataclass(eq=False)(some_class)
    return some_class
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
    """
    Get a dataclass field which defaults to get_default_args(...)

    Args:
        As for get_default_args.

    Returns:
        dataclass field whose default_factory returns a new DictConfig object
    """

    def _factory():
        # Deferred so the DictConfig is built fresh per instance.
        return get_default_args(C, _do_not_process=_do_not_process)

    return dataclasses.field(default_factory=_factory)
def _process_member(
    *,
    name: str,
    type_: Type,
    pluggable: bool,
    some_class: Type,
    creation_functions: List[str],
    _do_not_process: Tuple[type, ...],
    known_implementations: Dict[str, Type],
) -> None:
    """
    Make the modification (of expand_args_fields) to some_class for a single member.

    Args:
        name: member name
        type_: member declared type
        pluggable: whether member has dynamic type
        some_class: (MODIFIED IN PLACE) the class being processed
        creation_functions: (MODIFIED IN PLACE) the names of the create functions
        _do_not_process: as for expand_args_fields.
        known_implementations: (MODIFIED IN PLACE) known types from the registry
    """
    # Because we are adding defaultable members, make
    # sure they go at the end of __annotations__ in case
    # there are non-defaulted standard class members.
    del some_class.__annotations__[name]

    if pluggable:
        # Pluggable member: add a "<name>_class_type" selector string plus one
        # "<name>_<Impl>_args" DictConfig per registered implementation.
        type_name = name + TYPE_SUFFIX
        if type_name not in some_class.__annotations__:
            some_class.__annotations__[type_name] = str
            setattr(some_class, type_name, "UNDEFAULTED")

        for derived_type in registry.get_all(type_):
            if derived_type in _do_not_process:
                continue
            if issubclass(derived_type, some_class):
                # When derived_type is some_class we have a simple
                # recursion to avoid. When it's a strict subclass the
                # situation is even worse.
                continue
            known_implementations[derived_type.__name__] = derived_type
            args_name = f"{name}_{derived_type.__name__}{ARGS_SUFFIX}"
            if args_name in some_class.__annotations__:
                raise ValueError(
                    f"Cannot generate {args_name} because it is already present."
                )
            some_class.__annotations__[args_name] = DictConfig
            # some_class is appended to _do_not_process to break cycles where an
            # implementation refers back to the class being expanded.
            setattr(
                some_class,
                args_name,
                get_default_args_field(
                    derived_type, _do_not_process=_do_not_process + (some_class,)
                ),
            )
    else:
        # Plain Configurable member: a single "<name>_args" DictConfig.
        args_name = name + ARGS_SUFFIX
        if args_name in some_class.__annotations__:
            raise ValueError(
                f"Cannot generate {args_name} because it is already present."
            )
        if issubclass(type_, some_class) or type_ in _do_not_process:
            raise ValueError(f"Cannot process {type_} inside {some_class}")

        some_class.__annotations__[args_name] = DictConfig
        setattr(
            some_class,
            args_name,
            get_default_args_field(
                type_,
                _do_not_process=_do_not_process + (some_class,),
            ),
        )

    # Provide a default create_<name> factory method unless the class (or a
    # base) already defines one.
    creation_function_name = f"create_{name}"
    if not hasattr(some_class, creation_function_name):
        setattr(
            some_class,
            creation_function_name,
            _default_create(name, type_, pluggable),
        )
    creation_functions.append(creation_function_name)
def remove_unused_components(dict_: DictConfig) -> None:
    """
    Assuming dict_ represents the state of a configurable,
    modify it to remove all the portions corresponding to
    pluggable parts which are not in use.
    For example, if renderer_class_type is SignedDistanceFunctionRenderer,
    the renderer_MultiPassEmissionAbsorptionRenderer_args will be
    removed.

    Args:
        dict_: (MODIFIED IN PLACE) a DictConfig instance
    """
    # Only string keys can carry the TYPE_SUFFIX / ARGS_SUFFIX conventions.
    keys = [key for key in dict_ if isinstance(key, str)]
    suffix_length = len(TYPE_SUFFIX)
    # Names of pluggable members: "<name>_class_type" keys with suffix removed.
    replaceables = [key[:-suffix_length] for key in keys if key.endswith(TYPE_SUFFIX)]
    args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)]
    for replaceable in replaceables:
        selected_type = dict_[replaceable + TYPE_SUFFIX]
        # The only args key to keep for this member:
        expect = replaceable + "_" + selected_type + ARGS_SUFFIX
        # open_dict lifts omegaconf's struct flag so keys can be deleted.
        with open_dict(dict_):
            for key in args_keys:
                if key.startswith(replaceable + "_") and key != expect:
                    del dict_[key]

    # Recurse into nested configs.
    for key in dict_:
        if isinstance(dict_.get(key), DictConfig):
            remove_unused_components(dict_[key])
pytorch3d/implicitron/tools/depth_cleanup.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
torch
import
torch.nn.functional
as
Fu
from
pytorch3d.ops
import
wmean
from
pytorch3d.renderer.cameras
import
CamerasBase
from
pytorch3d.structures
import
Pointclouds
def cleanup_eval_depth(
    point_cloud: Pointclouds,
    camera: CamerasBase,
    depth: torch.Tensor,
    mask: torch.Tensor,
    sigma: float = 0.01,
    image=None,
):
    """
    Produce a mask of depth-map pixels that agree with the depths of
    `point_cloud` when projected into `camera`.

    Args:
        point_cloud: batched reference point clouds.
        camera: cameras used to project the clouds onto the depth maps.
        depth: (ba, 1, H, W) rendered depth maps to be cleaned.
        mask: (ba, 1, H, W) validity mask used for the depth statistics.
        sigma: agreement threshold, as a multiple of the per-batch masked
            depth standard deviation.
        image: unused here; presumably kept for the commented-out
            visualization below — TODO confirm before removing.

    Returns:
        good_depth_mask: (ba, 1, H, W) float mask, 1 where at least one
        projected point's depth matches the rendered depth.
    """

    ba, _, H, W = depth.shape

    pcl = point_cloud.points_padded()
    n_pts = point_cloud.num_points_per_cloud()
    # mask out the padding entries of the padded point tensor
    pcl_mask = (
        torch.arange(pcl.shape[1], dtype=torch.int64, device=pcl.device)[None]
        < n_pts[:, None]
    ).type_as(pcl)

    # 2D projections (NDC xy) and view-space depths of the cloud points
    pcl_proj = camera.transform_points(pcl, eps=1e-2)[..., :-1]
    pcl_depth = camera.get_world_to_view_transform().transform_points(pcl)[..., -1]

    # pack the depth map together with the flat pixel index so both can be
    # sampled at the projected point locations in one grid_sample call
    depth_and_idx = torch.cat(
        (
            depth,
            torch.arange(H * W).view(1, 1, H, W).expand(ba, 1, H, W).type_as(depth),
        ),
        dim=1,
    )

    # NOTE(review): the projection is negated before sampling — looks like a
    # PyTorch3D-NDC vs grid_sample axis-convention flip; confirm if changed.
    depth_and_idx_sampled = Fu.grid_sample(
        depth_and_idx, -pcl_proj[:, None], mode="nearest"
    )[:, :, 0].view(ba, 2, -1)

    depth_sampled, idx_sampled = depth_and_idx_sampled.split([1, 1], dim=1)
    # absolute difference between rendered depth and point depth
    df = (depth_sampled[:, 0] - pcl_depth).abs()

    # the threshold is a sigma-multiple of the standard deviation of the depth
    mu = wmean(depth.view(ba, -1, 1), mask.view(ba, -1)).view(ba, 1)
    std = (
        wmean((depth.view(ba, -1) - mu).view(ba, -1, 1) ** 2, mask.view(ba, -1))
        .clamp(1e-4)
        .sqrt()
        .view(ba, -1)
    )
    good_df_thr = std * sigma
    # per-point agreement indicator, zeroed on padded points
    good_depth = (df <= good_df_thr).float() * pcl_mask
    perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
    # print(f'Kept {100.0 * perc_kept.mean():1.3f} % points')

    # scatter the per-point agreements back onto the pixel grid
    good_depth_raster = torch.zeros_like(depth).view(ba, -1)
    # pyre-ignore[16]: scatter_add_
    good_depth_raster.scatter_add_(1, torch.round(idx_sampled[:, 0]).long(), good_depth)

    good_depth_mask = (good_depth_raster.view(ba, 1, H, W) > 0).float()

    # if float(torch.rand(1)) > 0.95:
    #     depth_ok = depth * good_depth_mask
    #     # visualize
    #     visdom_env = 'depth_cleanup_dbg'
    #     from visdom import Visdom
    #     # from tools.vis_utils import make_depth_image
    #     from pytorch3d.vis.plotly_vis import plot_scene
    #     viz = Visdom()
    #     show_pcls = {
    #         'pointclouds': point_cloud,
    #     }
    #     for d, nm in zip(
    #         (depth, depth_ok),
    #         ('pointclouds_unproj', 'pointclouds_unproj_ok'),
    #     ):
    #         pointclouds_unproj = get_rgbd_point_cloud(
    #             camera, image, d,
    #         )
    #         if int(pointclouds_unproj.num_points_per_cloud()) > 0:
    #             show_pcls[nm] = pointclouds_unproj
    #     scene_dict = {'1': {
    #         **show_pcls,
    #         'cameras': camera,
    #     }}
    #     scene = plot_scene(
    #         scene_dict,
    #         pointcloud_max_points=5000,
    #         pointcloud_marker_size=1.5,
    #         camera_scale=1.0,
    #     )
    #     viz.plotlyplot(scene, env=visdom_env, win='scene')
    #     import pdb; pdb.set_trace()

    return good_depth_mask
pytorch3d/implicitron/tools/eval_video_trajectory.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
math
from
typing
import
Optional
,
Tuple
import
torch
from
pytorch3d.common.compat
import
eigh
from
pytorch3d.implicitron.tools.circle_fitting
import
fit_circle_in_3d
from
pytorch3d.renderer
import
PerspectiveCameras
,
look_at_view_transform
from
pytorch3d.transforms
import
Scale
def generate_eval_video_cameras(
    train_cameras,
    n_eval_cams: int = 100,
    trajectory_type: str = "figure_eight",
    trajectory_scale: float = 0.2,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
    focal_length: Optional[torch.FloatTensor] = None,
    principal_point: Optional[torch.FloatTensor] = None,
    time: Optional[torch.FloatTensor] = None,
    infer_up_as_plane_normal: bool = True,
    traj_offset: Optional[Tuple[float, float, float]] = None,
    traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
) -> PerspectiveCameras:
    """
    Generate a camera trajectory rendering a scene from multiple viewpoints.

    Args:
        train_cameras: The cameras of the training views.
        n_eval_cams: Number of cameras in the trajectory.
        trajectory_type: The type of the camera trajectory. Can be one of:
            circular_lsq_fit: Camera centers follow a trajectory obtained
                by fitting a 3D circle to train_cameras centers.
                All cameras are looking towards scene_center.
            figure_eight: Figure-of-8 trajectory around the center of the
                central camera of the training dataset.
            trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
                of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
            figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
                of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
        trajectory_scale: The extent of the trajectory.
        scene_center: The center of the scene in world coordinates which all
            the cameras from the generated trajectory look at.
        up: The "up" vector of the scene (=the normal of the scene floor).
            Active for the `trajectory_type="circular"`.
        focal_length: Focal lengths for the output cameras; defaults to the
            mean of the train cameras' focal lengths.
        principal_point: Principal points for the output cameras; defaults to
            the mean of the train cameras' principal points.
        time: Optional (n_eval_cams,) tensor of trajectory parameters/angles;
            defaults to a uniform sweep over [0, 2*pi).
        infer_up_as_plane_normal: If True, replace `up` with the normal of the
            plane fitted to the train camera centers (sign-aligned with `up`).
        traj_offset: Constant world-space offset added to the trajectory.
        traj_offset_canonical: Offset applied in the canonical (pre-transform)
            trajectory frame.

    Returns:
        Dictionary of camera instances which can be used as the test dataset
    """
    if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
        cam_centers = train_cameras.get_camera_center()
        # get the nearest camera center to the mean of centers
        mean_camera_idx = (
            ((cam_centers - cam_centers.mean(dim=0)[None]) ** 2)
            .sum(dim=1)
            .min(dim=0)
            .indices
        )
        # generate the knot trajectory in canonical coords
        if time is None:
            # drop the duplicate end point so the loop does not repeat t=0
            time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1)[:n_eval_cams]
        else:
            assert time.numel() == n_eval_cams
        if trajectory_type == "trefoil_knot":
            traj = _trefoil_knot(time)
        elif trajectory_type == "figure_eight_knot":
            traj = _figure_eight_knot(time)
        elif trajectory_type == "figure_eight":
            traj = _figure_eight(time)
        else:
            raise ValueError(f"bad trajectory type: {trajectory_type}")
        # shift the curve so its highest z touches 0
        traj[:, 2] -= traj[:, 2].max()

        # transform the canonical knot to the coord frame of the mean camera
        mean_camera = PerspectiveCameras(
            **{
                k: getattr(train_cameras, k)[[int(mean_camera_idx)]]
                for k in ("focal_length", "principal_point", "R", "T")
            }
        )
        # scale by the spread of the training centers, then move to world coords
        traj_trans = Scale(
            cam_centers.std(dim=0).mean() * trajectory_scale
        ).compose(mean_camera.get_world_to_view_transform().inverse())

        if traj_offset_canonical is not None:
            traj_trans = traj_trans.translate(
                torch.FloatTensor(traj_offset_canonical)[None].to(traj)
            )

        traj = traj_trans.transform_points(traj)

        plane_normal = _fit_plane(cam_centers)[:, 0]
        if infer_up_as_plane_normal:
            up = _disambiguate_normal(plane_normal, up)

    elif trajectory_type == "circular_lsq_fit":
        ### fit plane to the camera centers
        # get the center of the plane as the median of the camera centers
        cam_centers = train_cameras.get_camera_center()

        if time is not None:
            angle = time
        else:
            angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams).to(cam_centers)
        fit = fit_circle_in_3d(
            cam_centers,
            angles=angle,
            offset=angle.new_tensor(traj_offset_canonical)
            if traj_offset_canonical is not None
            else None,
            up=angle.new_tensor(up),
        )
        traj = fit.generated_points

        # scale the trajectory about its own mean
        _t_mu = traj.mean(dim=0, keepdim=True)
        traj = (traj - _t_mu) * trajectory_scale + _t_mu

        plane_normal = fit.normal

        if infer_up_as_plane_normal:
            up = _disambiguate_normal(plane_normal, up)

    else:
        raise ValueError(f"Uknown trajectory_type {trajectory_type}.")

    if traj_offset is not None:
        traj = traj + torch.FloatTensor(traj_offset)[None].to(traj)

    # point all cameras towards the center of the scene
    R, T = look_at_view_transform(
        eye=traj,
        at=(scene_center,),  # (1, 3)
        up=(up,),  # (1, 3)
        device=traj.device,
    )

    # get the average focal length and principal point
    if focal_length is None:
        focal_length = train_cameras.focal_length.mean(dim=0).repeat(n_eval_cams, 1)
    if principal_point is None:
        principal_point = train_cameras.principal_point.mean(dim=0).repeat(
            n_eval_cams, 1
        )

    test_cameras = PerspectiveCameras(
        focal_length=focal_length,
        principal_point=principal_point,
        R=R,
        T=T,
        device=focal_length.device,
    )

    # _visdom_plot_scene(
    #     train_cameras,
    #     test_cameras,
    # )

    return test_cameras
def
_disambiguate_normal
(
normal
,
up
):
up_t
=
torch
.
tensor
(
up
).
to
(
normal
)
flip
=
(
up_t
*
normal
).
sum
().
sign
()
up
=
normal
*
flip
up
=
up
.
tolist
()
return
up
def _fit_plane(x):
    """Eigendecompose the covariance of points `x` (N, 3); returns the eigenvectors."""
    centered = x - x.mean(dim=0)[None]
    covariance = (centered.t() @ centered) / centered.shape[0]
    _, eigenvectors = eigh(covariance)
    return eigenvectors
def _visdom_plot_scene(
    train_cameras,
    test_cameras,
) -> None:
    """
    Debug helper: plot train/test cameras into a Visdom window and drop
    into pdb. Requires a running Visdom server; not used in normal runs.
    """
    # imports are local so visdom/plotly are only needed when debugging
    from pytorch3d.vis.plotly_vis import plot_scene

    p = plot_scene(
        {
            "scene": {
                "train_cams": train_cameras,
                "test_cams": test_cameras,
            }
        }
    )

    from visdom import Visdom

    viz = Visdom()
    viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs")
    # intentional breakpoint for interactive inspection
    import pdb

    pdb.set_trace()
def
_figure_eight_knot
(
t
:
torch
.
Tensor
,
z_scale
:
float
=
0.5
):
x
=
(
2
+
(
2
*
t
).
cos
())
*
(
3
*
t
).
cos
()
y
=
(
2
+
(
2
*
t
).
cos
())
*
(
3
*
t
).
sin
()
z
=
(
4
*
t
).
sin
()
*
z_scale
return
torch
.
stack
((
x
,
y
,
z
),
dim
=-
1
)
def
_trefoil_knot
(
t
:
torch
.
Tensor
,
z_scale
:
float
=
0.5
):
x
=
t
.
sin
()
+
2
*
(
2
*
t
).
sin
()
y
=
t
.
cos
()
-
2
*
(
2
*
t
).
cos
()
z
=
-
(
3
*
t
).
sin
()
*
z_scale
return
torch
.
stack
((
x
,
y
,
z
),
dim
=-
1
)
def
_figure_eight
(
t
:
torch
.
Tensor
,
z_scale
:
float
=
0.5
):
x
=
t
.
cos
()
y
=
(
2
*
t
).
sin
()
/
2
z
=
t
.
sin
()
*
z_scale
return
torch
.
stack
((
x
,
y
,
z
),
dim
=-
1
)
pytorch3d/implicitron/tools/image_utils.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from
typing
import
Union
import
torch
def mask_background(
    image_rgb: torch.Tensor,
    mask_fg: torch.Tensor,
    dim_color: int = 1,
    bg_color: Union[torch.Tensor, str, float] = 0.0,
) -> torch.Tensor:
    """
    Mask the background input image tensor `image_rgb` with `bg_color`.
    The background regions are obtained from the binary foreground segmentation
    mask `mask_fg`.

    Args:
        image_rgb: 4D image tensor with a 3-channel color dimension at
            `dim_color` (e.g. NCHW for dim_color=1, NHWC for dim_color=3).
        mask_fg: binary foreground mask, broadcastable to `image_rgb`.
        dim_color: index of the color dimension.
        bg_color: background fill: a 3-element tensor, "white"/"black",
            or a single float used for all three channels.

    Returns:
        `image_rgb` with background pixels replaced by `bg_color`.

    Raises:
        ValueError: if `bg_color` is an unsupported string or type.
    """
    # view shape that broadcasts the 3 color values along dim_color
    tgt_view = [1, 1, 1, 1]
    tgt_view[dim_color] = 3
    # obtain the background color tensor
    if isinstance(bg_color, torch.Tensor):
        # fix: respect dim_color instead of assuming channels-first (1, 3, 1, 1)
        bg_color_t = bg_color.view(*tgt_view).clone().to(image_rgb)
    elif isinstance(bg_color, float):
        bg_color_t = torch.tensor(
            [bg_color] * 3, device=image_rgb.device, dtype=image_rgb.dtype
        ).view(*tgt_view)
    elif isinstance(bg_color, str):
        if bg_color == "white":
            bg_color_t = image_rgb.new_ones(tgt_view)
        elif bg_color == "black":
            bg_color_t = image_rgb.new_zeros(tgt_view)
        else:
            raise ValueError(_invalid_color_error_msg(bg_color))
    else:
        raise ValueError(_invalid_color_error_msg(bg_color))
    # cast to the image_rgb's type
    mask_fg = mask_fg.type_as(image_rgb)
    # convex combination: foreground keeps the image, background gets bg_color
    image_masked = mask_fg * image_rgb + (1 - mask_fg) * bg_color_t
    return image_masked
def
_invalid_color_error_msg
(
bg_color
)
->
str
:
return
(
f
"Invalid bg_color=
{
bg_color
}
. Plese set bg_color to a 3-element"
+
" tensor. or a string (white | black), or a float."
)
pytorch3d/implicitron/tools/metric_utils.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
math
from
typing
import
Optional
,
Tuple
import
torch
from
torch.nn
import
functional
as
F
def eval_depth(
    pred: torch.Tensor,
    gt: torch.Tensor,
    crop: int = 1,
    mask: Optional[torch.Tensor] = None,
    get_best_scale: bool = True,
    mask_thr: float = 0.5,
    best_scale_clamp_thr: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Evaluate the depth error between the prediction `pred` and the ground
    truth `gt`.

    Args:
        pred: A tensor of shape (N, 1, H, W) denoting the predicted depth maps.
        gt: A tensor of shape (N, 1, H, W) denoting the ground truth depth maps.
        crop: The number of pixels to crop from the border.
        mask: A mask denoting the valid regions of the gt depth.
        get_best_scale: If `True`, estimates a scaling factor of the predicted depth
            that yields the best mean squared error between `pred` and `gt`.
            This is typically enabled for cases where predicted reconstructions
            are inherently defined up to an arbitrary scaling factor.
        mask_thr: A constant used to threshold the `mask` to specify the valid
            regions.
        best_scale_clamp_thr: The threshold for clamping the divisor in best
            scale estimation.

    Returns:
        mse_depth: Mean squared error between `pred` and `gt`.
        abs_depth: Mean absolute difference between `pred` and `gt`.
    """
    if crop > 0:
        # drop `crop` pixels from every image border
        gt = gt[:, :, crop:-crop, crop:-crop]
        pred = pred[:, :, crop:-crop, crop:-crop]
        if mask is not None:
            mask = mask[:, :, crop:-crop, crop:-crop]

    if mask is not None:
        # zero-out gt outside the thresholded validity mask
        gt = gt * (mask > mask_thr).float()

    # only pixels with positive gt depth contribute to the error
    dmask = (gt > 0.0).float()
    dmask_mass = dmask.sum((1, 2, 3)).clamp(1e-4)

    if get_best_scale:
        # rescale predictions by the per-batch scalar minimizing the mse
        scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
        pred = pred * scale_best[:, None, None, None]

    diff = gt - pred
    mse_depth = (dmask * diff ** 2).sum((1, 2, 3)) / dmask_mass
    abs_depth = (dmask * diff.abs()).sum((1, 2, 3)) / dmask_mass
    return mse_depth, abs_depth
def estimate_depth_scale_factor(pred, gt, mask, clamp_thr):
    """Least-squares scalar s minimizing ||s * pred - gt||^2 over masked pixels."""
    cross_term = (pred * gt * mask).mean((1, 2, 3))
    # clamp the denominator to avoid division by ~0 for empty masks
    self_term = (pred * pred * mask).mean((1, 2, 3)).clamp(clamp_thr)
    return cross_term / self_term
def calc_psnr(
    x: torch.Tensor,
    y: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute the peak signal-to-noise ratio between `x` and `y`."""
    # clamp avoids log10(0) when the inputs are identical
    mean_sq_err = calc_mse(x, y, mask=mask).clamp(1e-10)
    return -10.0 * torch.log10(mean_sq_err)
def calc_mse(
    x: torch.Tensor,
    y: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute the (optionally masked) mean squared error between `x` and `y`."""
    sq_err = (x - y) ** 2
    if mask is None:
        return sq_err.mean()
    # normalize by the mask mass over x's full shape; clamp guards against /0
    return (sq_err * mask).sum() / mask.expand_as(x).sum().clamp(1e-5)
def calc_bce(
    pred: torch.Tensor,
    gt: torch.Tensor,
    equal_w: bool = True,
    pred_eps: float = 0.01,
    mask: Optional[torch.Tensor] = None,
    lerp_bound: Optional[float] = None,
) -> torch.Tensor:
    """
    Compute the binary cross entropy between `pred` and `gt`, optionally
    re-weighted so foreground and background contribute equally.
    """
    if pred_eps > 0.0:
        # keep predictions away from {0, 1} so log() stays finite
        pred = pred.clamp(pred_eps, 1.0 - pred_eps)

    if mask is None:
        mask = torch.ones_like(gt)

    if equal_w:
        # balance foreground/background contributions equally
        fg = (gt > 0.5).float() * mask
        bg = (1 - fg) * mask
        weight = fg / fg.sum().clamp(1.0) + bg / bg.sum().clamp(1.0)
        # renormalize towards unit mean weight (the sum above is ~2)
        weight = weight * (weight.numel() / weight.sum().clamp(1.0))
    else:
        weight = torch.ones_like(gt) * mask

    if lerp_bound is not None:
        return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound)
    return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight)
def binary_cross_entropy_lerp(
    pred: torch.Tensor,
    gt: torch.Tensor,
    weight: torch.Tensor,
    lerp_bound: float,
):
    """
    Binary cross entropy whose log() is linearly extrapolated whenever pred
    or 1-pred drops below `lerp_bound`, which keeps gradients from exploding.
    """
    neg_term = log_lerp(1 - pred, lerp_bound) * (1 - gt)
    pos_term = log_lerp(pred, lerp_bound) * gt
    per_element = neg_term + pos_term
    # weighted negative log-likelihood, normalized by the weight mass
    return -(per_element * weight).sum() / weight.sum().clamp(1e-4)
def log_lerp(x: torch.Tensor, b: float):
    """Natural log of `x`, linearly extrapolated (slope 1/b) for x < b."""
    assert b > 0
    # tangent line to log at x=b: log(b) + (x - b) / b
    tangent = math.log(b) + (x - b) / b
    return torch.where(x >= b, x.log(), tangent)
def rgb_l1(
    pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Per-batch mean absolute color error between `pred` and `target`."""
    if mask is None:
        # default to a full single-channel mask
        mask = torch.ones_like(pred[:, :1])
    abs_err = (pred - target).abs() * mask
    # normalize by the mask area; clamp guards against an empty mask
    return abs_err.sum(dim=(1, 2, 3)) / mask.sum(dim=(1, 2, 3)).clamp(1)
def huber(dfsq: torch.Tensor, scaling: float = 0.03) -> torch.Tensor:
    """
    Huber-style robust loss of the squared error `dfsq`.
    The function smoothly transitions from a region with unit gradient
    to a hyperbolic function at `dfsq=scaling`.
    """
    normalized = 1 + dfsq / (scaling * scaling)
    return scaling * (safe_sqrt(normalized, eps=1e-4) - 1)
def neg_iou_loss(
    predict: torch.Tensor,
    target: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Complement of the soft IoU; emphasizes the active regions of the inputs."""
    overlap = iou(predict, target, mask=mask)
    return 1.0 - overlap
def safe_sqrt(A: torch.Tensor, eps: float = float(1e-4)) -> torch.Tensor:
    """Differentiable sqrt: negatives are clamped to 0, `eps` keeps gradients finite."""
    non_negative = A.clamp(min=float(0))
    return torch.sqrt(non_negative + eps)
def iou(
    predict: torch.Tensor,
    target: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Soft intersection-over-union of `predict` and `target`, averaged over the batch."""
    # reduce over everything except the batch dimension
    reduce_dims = tuple(range(1, predict.dim()))
    if mask is not None:
        predict = predict * mask
        target = target * mask
    overlap = (predict * target).sum(reduce_dims)
    # soft union; the epsilon avoids division by zero
    joined = (predict + target - predict * target).sum(reduce_dims) + 1e-4
    return (overlap / joined).sum() / overlap.numel()
def beta_prior(pred: torch.Tensor, cap: float = 0.1) -> torch.Tensor:
    """Beta-distribution-like prior pushing `pred` towards {0, 1}; zero at the extremes."""
    if cap <= 0.0:
        raise ValueError("capping should be positive to avoid unbound loss")

    # offset chosen so the loss is exactly zero when pred is 0 or 1
    floor = math.log(cap) + math.log(cap + 1.0)
    log_terms = torch.log(pred + cap) + torch.log(1.0 - pred + cap)
    return log_terms.mean() - floor
pytorch3d/implicitron/tools/model_io.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
glob
import
os
import
shutil
import
tempfile
import
torch
def load_stats(flstats):
    """
    Load a Stats object from the file `flstats`.

    Returns:
        The loaded Stats object, or None if loading failed.
    """
    from pytorch3d.implicitron.tools.stats import Stats

    try:
        stats = Stats.load(flstats)
    # fix: narrow bare `except:` so KeyboardInterrupt/SystemExit propagate
    except Exception:
        # best-effort: a missing or corrupt stats file should not abort loading
        print("Cant load stats! %s" % flstats)
        stats = None
    return stats
def get_model_path(fl) -> str:
    """Return the model checkpoint path for base path `fl` (extension replaced by .pth)."""
    base = os.path.splitext(fl)[0]
    return "%s.pth" % base
def get_optimizer_path(fl) -> str:
    """Return the optimizer checkpoint path for base path `fl` (`<base>_opt.pth`)."""
    base = os.path.splitext(fl)[0]
    return "%s_opt.pth" % base
def get_stats_path(fl, eval_results: bool = False):
    """
    Return the stats-file path for checkpoint base `fl`.

    With eval_results=True, prefer an evaluation stats file next to `fl`
    (`stats_test_2.jgz` over `stats_test.jgz`); otherwise `<base>_stats.jgz`.
    """
    base = os.path.splitext(fl)[0]
    if not eval_results:
        return "%s_stats.jgz" % base
    directory = os.path.dirname(base)
    for postfix in ("_2", ""):
        flstats = os.path.join(directory, f"stats_test{postfix}.jgz")
        if os.path.isfile(flstats):
            break
    # if neither candidate exists, the un-postfixed name is returned
    return flstats
def safe_save_model(model, stats, fl, optimizer=None, cfg=None) -> None:
    """
    This functions stores model files safely so that no model files exist on the
    file system in case the saving procedure gets interrupted.

    This is done first by saving the model files to a temporary directory followed
    by (atomic) moves to the target location. Note, that this can still result
    in a corrupt set of model files in case interruption happens while performing
    the moves. It is however quite improbable that a crash would occur right at
    this time.

    Args:
        model: the torch module whose state_dict is saved.
        stats: the Stats object saved alongside the model.
        fl: target checkpoint base path; actual file names are derived from it.
        optimizer: if given, its state_dict is saved too.
        cfg: passed through to save_model.
    """
    print(f"saving model files safely to {fl}")
    # first store everything to a tmpdir
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpfl = os.path.join(tmpdir, os.path.split(fl)[-1])
        # save_model returns (stats_path, model_path, optimizer_path-or-None)
        stored_tmp_fls = save_model(model, stats, tmpfl, optimizer=optimizer, cfg=cfg)
        # destination path for each stored temporary file (None entries skipped)
        tgt_fls = [
            (
                os.path.join(os.path.split(fl)[0], os.path.split(tmpfl)[-1])
                if (tmpfl is not None)
                else None
            )
            for tmpfl in stored_tmp_fls
        ]
        # then move from the tmpdir to the right location
        for tmpfl, tgt_fl in zip(stored_tmp_fls, tgt_fls):
            if tgt_fl is None:
                continue
            # print(f'Moving {tmpfl} --> {tgt_fl}\n')
            shutil.move(tmpfl, tgt_fl)
def save_model(model, stats, fl, optimizer=None, cfg=None):
    """
    Save model/optimizer state_dicts and stats next to base path `fl`.

    Args:
        model: torch module; its state_dict is written to `<base>.pth`.
        stats: Stats object; written to `<base>_stats.jgz`.
        fl: checkpoint base path.
        optimizer: if given, its state_dict goes to `<base>_opt.pth`.
        cfg: unused here — presumably kept for interface compatibility;
            TODO confirm with callers.

    Returns:
        (flstats, flmodel, flopt) — the written paths; flopt is None when
        no optimizer was given.
    """
    flstats = get_stats_path(fl)
    flmodel = get_model_path(fl)
    print("saving model to %s" % flmodel)
    torch.save(model.state_dict(), flmodel)
    flopt = None
    if optimizer is not None:
        flopt = get_optimizer_path(fl)
        print("saving optimizer to %s" % flopt)
        torch.save(optimizer.state_dict(), flopt)
    print("saving model stats to %s" % flstats)
    stats.save(flstats)

    return flstats, flmodel, flopt
def load_model(fl):
    """
    Load the model state_dict, stats, and (if present) optimizer state
    stored under checkpoint base path `fl`.

    Returns:
        (model_state_dict, stats, optimizer_state_or_None)
    """
    flstats = get_stats_path(fl)
    flmodel = get_model_path(fl)
    flopt = get_optimizer_path(fl)
    model_state_dict = torch.load(flmodel)
    stats = load_stats(flstats)
    # the optimizer checkpoint is optional
    if os.path.isfile(flopt):
        optimizer = torch.load(flopt)
    else:
        optimizer = None

    return model_state_dict, stats, optimizer
def parse_epoch_from_model_path(model_path) -> int:
    """Extract the integer epoch from a `model_epoch_XXXXXXXX.pth` path."""
    filename = os.path.split(model_path)[-1]
    digits = filename.replace(".pth", "").replace("model_epoch_", "")
    return int(digits)
def get_checkpoint(exp_dir, epoch):
    """Return the checkpoint path for `epoch` inside `exp_dir`."""
    filename = "model_epoch_%08d.pth" % epoch
    return os.path.join(exp_dir, filename)
def find_last_checkpoint(
    exp_dir, any_path: bool = False, all_checkpoints: bool = False
):
    """
    Find the latest `model_epoch_XXXXXXXX.pth` checkpoint in `exp_dir`.

    Args:
        exp_dir: directory to search.
        any_path: also accept runs where only stats or optimizer files remain.
        all_checkpoints: return the full sorted list instead of the last entry.

    Returns:
        A .pth path (or a list of them), or None when nothing is found.
    """
    exts = [".pth", "_stats.jgz", "_opt.pth"] if any_path else [".pth"]
    stem = "model_epoch_" + "[0-9]" * 8

    fls = []
    for ext in exts:
        pattern = os.path.join(glob.escape(exp_dir), stem + ext)
        fls = sorted(glob.glob(pattern))
        if fls:
            break

    if not fls:
        return None
    # map each matched file back to the corresponding model .pth path
    if all_checkpoints:
        return [f[: -len(ext)] + ".pth" for f in fls]
    return fls[-1][: -len(ext)] + ".pth"
def purge_epoch(exp_dir, epoch) -> None:
    """
    Delete the model, optimizer, and stats files of `epoch` from `exp_dir`,
    ignoring any of the three that does not exist.
    """
    model_path = get_checkpoint(exp_dir, epoch)

    for file_path in [
        model_path,
        get_optimizer_path(model_path),
        get_stats_path(model_path),
    ]:
        if os.path.isfile(file_path):
            print("deleting %s" % file_path)
            os.remove(file_path)
pytorch3d/implicitron/tools/point_cloud_utils.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from
typing
import
Optional
,
Tuple
,
cast
import
torch
import
torch.nn.functional
as
Fu
from
pytorch3d.renderer
import
(
AlphaCompositor
,
NDCMultinomialRaysampler
,
PointsRasterizationSettings
,
PointsRasterizer
,
ray_bundle_to_ray_points
,
)
from
pytorch3d.renderer.cameras
import
CamerasBase
from
pytorch3d.structures
import
Pointclouds
def get_rgbd_point_cloud(
    camera: CamerasBase,
    image_rgb: torch.Tensor,
    depth_map: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    mask_thr: float = 0.5,
    mask_points: bool = True,
) -> Pointclouds:
    """
    Given a batch of images, depths, masks and cameras, generate a colored
    point cloud by unprojecting depth maps and coloring each point with the
    source pixel color.

    Args:
        camera: cameras used to unproject the depth maps.
        image_rgb: (ba, 3, H, W) source images providing the point colors.
        depth_map: (ba, 1, H', W') depth maps; only pixels with depth > 0
            generate points.
        mask: optional mask; points whose mask value is <= mask_thr are
            dropped.
        mask_thr: threshold applied to `mask`.
        mask_points: unused in this function body — presumably kept for
            interface compatibility; TODO confirm with callers.

    Returns:
        A single-batch Pointclouds with per-point RGB features.
    """
    imh, imw = image_rgb.shape[2:]

    # convert the depth maps to point clouds using the grid ray sampler;
    # unit-depth rays are generated and then their lengths are replaced by
    # the actual depth values
    pts_3d = ray_bundle_to_ray_points(
        NDCMultinomialRaysampler(
            image_width=imw,
            image_height=imh,
            n_pts_per_ray=1,
            min_depth=1.0,
            max_depth=1.0,
        )(camera)._replace(lengths=depth_map[:, 0, ..., None])
    )

    pts_mask = depth_map > 0.0
    if mask is not None:
        pts_mask *= mask > mask_thr
    pts_mask = pts_mask.reshape(-1)

    pts_3d = pts_3d.reshape(-1, 3)[pts_mask]

    # resample the colors to the depth-map resolution before gathering
    pts_colors = torch.nn.functional.interpolate(
        image_rgb,
        # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
        #  `List[typing.Any]`.
        size=[imh, imw],
        mode="bilinear",
        align_corners=False,
    )
    pts_colors = pts_colors.permute(0, 2, 3, 1).reshape(-1, 3)[pts_mask]

    return Pointclouds(points=pts_3d[None], features=pts_colors[None])
def render_point_cloud_pytorch3d(
    camera,
    point_cloud,
    render_size: Tuple[int, int],
    point_radius: float = 0.03,
    topk: int = 10,
    eps: float = 1e-2,
    bg_color=None,
    **kwargs,
):
    """
    Render a point cloud with the PyTorch3D point rasterizer and an alpha
    compositor.

    Args:
        camera: cameras defining the viewpoint; points are first moved into
            their view coordinates, then rendered with trivial cameras.
        point_cloud: Pointclouds object with per-point features.
        render_size: (height, width) of the output renders.
        point_radius: rasterization radius of each point.
        topk: points per pixel kept in the rasterizer's z-buffer.
        eps: depth clamp used when transforming points to view coordinates.
        bg_color: background feature vector; defaults to zeros.
        kwargs: forwarded to the transform, rasterizer and compositor calls.

    Returns:
        (data_rendered, render_mask, depth_rendered): composited features,
        soft coverage mask, and composited depth, each (ba, C, H, W)-like.
    """
    # feature dimension
    featdim = point_cloud.features_packed().shape[-1]

    # move to the camera coordinates; using identity cameras in the renderer
    point_cloud = _transform_points(camera, point_cloud, eps, **kwargs)
    camera_trivial = camera.clone()
    camera_trivial.R[:] = torch.eye(3)
    camera_trivial.T *= 0.0

    rasterizer = PointsRasterizer(
        cameras=camera_trivial,
        raster_settings=PointsRasterizationSettings(
            image_size=render_size,
            radius=point_radius,
            points_per_pixel=topk,
            # coarse binning only pays off for large renders
            bin_size=64 if int(max(render_size)) > 1024 else None,
        ),
    )

    fragments = rasterizer(point_cloud, **kwargs)

    # Construct weights based on the distance of a point to the true point.
    # However, this could be done differently: e.g. predicted as opposed
    # to a function of the weights.
    r = rasterizer.raster_settings.radius

    # set up the blending weights
    dists2 = fragments.dists
    weights = 1 - dists2 / (r * r)
    # zero the weights of empty z-buffer slots (idx == -1)
    ok = cast(torch.BoolTensor, (fragments.idx >= 0)).float()

    weights = weights * ok

    fragments_prm = fragments.idx.long().permute(0, 3, 1, 2)
    weights_prm = weights.permute(0, 3, 1, 2)
    images = AlphaCompositor()(
        fragments_prm,
        weights_prm,
        point_cloud.features_packed().permute(1, 0),
        background_color=bg_color if bg_color is not None else [0.0] * featdim,
        **kwargs,
    )

    # get the depths ...
    # weighted_fs[b,c,i,j] = sum_k cum_alpha_k * features[c,pointsidx[b,k,i,j]]
    # cum_alpha_k = alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])
    cumprod = torch.cumprod(1 - weights, dim=-1)
    cumprod = torch.cat((torch.ones_like(cumprod[..., :1]), cumprod[..., :-1]), dim=-1)
    depths = (weights * cumprod * fragments.zbuf).sum(dim=-1)
    # add the rendering mask: 1 - prod(1 - w) is the soft coverage
    render_mask = -torch.prod(1.0 - weights, dim=-1) + 1.0

    # cat depths and render mask
    rendered_blob = torch.cat((images, depths[:, None], render_mask[:, None]), dim=1)

    # reshape back to the requested render size
    rendered_blob = Fu.interpolate(
        rendered_blob,
        # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got `Tuple[int,
        #  ...]`.
        size=tuple(render_size),
        mode="bilinear",
    )

    data_rendered, depth_rendered, render_mask = rendered_blob.split(
        [rendered_blob.shape[1] - 2, 1, 1],
        dim=1,
    )

    return data_rendered, render_mask, depth_rendered
def
_signed_clamp
(
x
,
eps
):
sign
=
x
.
sign
()
+
(
x
==
0.0
).
type_as
(
x
)
x_clamp
=
sign
*
torch
.
clamp
(
x
.
abs
(),
eps
)
return
x_clamp
def _transform_points(cameras, point_clouds, eps, **kwargs):
    """
    Move `point_clouds` into the view coordinates of `cameras`, clamping
    the view-space depth away from zero by `eps`.
    """
    pts_world = point_clouds.points_padded()
    pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
        pts_world, eps=eps
    )
    # it is crucial to actually clamp the points as well: transform_points'
    # eps only stabilizes the projective division, not the stored depths
    pts_view = torch.cat(
        (pts_view[..., :-1], _signed_clamp(pts_view[..., -1:], eps)), dim=-1
    )
    point_clouds = point_clouds.update_padded(pts_view)
    return point_clouds
pytorch3d/implicitron/tools/rasterize_mc.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from
typing
import
Optional
,
Tuple
import
torch
from
pytorch3d.renderer
import
PerspectiveCameras
from
pytorch3d.structures
import
Pointclouds
from
.point_cloud_utils
import
render_point_cloud_pytorch3d
def rasterize_mc_samples(
    xys: torch.Tensor,
    feats: torch.Tensor,
    image_size_hw: Tuple[int, int],
    radius: float = 0.03,
    topk: int = 5,
    masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rasterizes Monte-Carlo sampled features back onto the image.

    Specifically, the code uses the PyTorch3D point rasterizer to render
    a z-flat point cloud composed of the xy MC locations and their features.

    Args:
        xys: B x N x 2 2D point locations in PyTorch3D NDC convention
        feats: B x N x dim tensor containing per-point rendered features.
        image_size_hw: Tuple[image_height, image_width] containing
            the size of rasterized image.
        radius: Rasterization point radius.
        topk: The maximum z-buffer size for the PyTorch3D point cloud rasterizer.
        masks: B x N x 1 tensor containing the alpha mask of the
            rendered features.

    Returns:
        (data_rendered, render_mask): rasterized features and the combined
        alpha/coverage mask.
    """
    if masks is None:
        masks = torch.ones_like(xys[..., :1])

    # append the alpha mask as an extra feature channel so it is composited
    # together with the features
    feats = torch.cat((feats, masks), dim=-1)
    pointclouds = Pointclouds(
        # lift the 2D samples to a z=1 plane ("z-flat" cloud)
        points=torch.cat([xys, torch.ones_like(xys[..., :1])], dim=-1),
        features=feats,
    )

    data_rendered, render_mask, _ = render_point_cloud_pytorch3d(
        PerspectiveCameras(device=feats.device),
        pointclouds,
        render_size=image_size_hw,
        point_radius=radius,
        topk=topk,
    )

    # split the composited alpha channel back off the features
    data_rendered, masks_pt = data_rendered.split(
        [data_rendered.shape[1] - 1, 1], dim=1
    )
    render_mask = masks_pt * render_mask

    return data_rendered, render_mask
pytorch3d/implicitron/tools/stats.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
gzip
import
json
import
time
import
warnings
from
collections.abc
import
Iterable
from
itertools
import
cycle
import
matplotlib
import
matplotlib.pyplot
as
plt
import
numpy
as
np
from
matplotlib
import
colors
as
mcolors
from
pytorch3d.implicitron.tools.vis_utils
import
get_visdom_connection
class AverageMeter(object):
    """
    Tracks the running average of a scalar quantity together with a
    per-epoch history of the individual (per-sample) values.
    """

    def __init__(self):
        # history[e] holds the per-sample values recorded during epoch e
        self.history = []
        self.reset()

    def reset(self):
        """Reset the running statistics (the per-epoch history is kept)."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1, epoch=0):
        """Record `val` (an aggregate over `n` samples) under epoch `epoch`."""
        # grow the history so that index `epoch` exists
        missing = epoch + 1 - len(self.history)
        for _ in range(missing):
            self.history.append([])
        self.history[epoch].append(val / n)
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def get_epoch_averages(self, epoch=-1):
        """Mean of epoch `epoch`, or (for epoch == -1) the list of all epoch means."""
        if not self.history:
            # nothing was ever recorded
            return None
        if epoch == -1:
            return [
                float(np.array(values).mean()) if values else float("NaN")
                for values in self.history
            ]
        return float(np.array(self.history[epoch]).mean())

    def get_all_values(self):
        """All recorded per-sample values, concatenated over epochs."""
        return np.concatenate([np.array(values) for values in self.history])

    def get_epoch(self):
        """Number of epochs for which any value has been recorded."""
        return len(self.history)

    @staticmethod
    def from_json_str(json_str):
        """Reconstruct an AverageMeter from a JSON dump of its __dict__."""
        meter = AverageMeter()
        meter.__dict__.update(json.loads(json_str))
        return meter
class Stats(object):
    # TODO: update this with context manager
    """
    stats logging object useful for gathering statistics of training a deep net in pytorch
    Example:
        # init stats structure that logs statistics 'objective' and 'top1e'
        stats = Stats( ('objective','top1e') )
        network = init_net()  # init a pytorch module (=nueral network)
        dataloader = init_dataloader()  # init a dataloader
        for epoch in range(10):
            # start of epoch -> call new_epoch
            stats.new_epoch()
            # iterate over batches
            for batch in dataloader:
                output = network(batch)  # run and save into a dict of output variables "output"
                # stats.update() automatically parses the 'objective' and 'top1e' from
                # the "output" dict and stores this into the db
                stats.update(output)
                stats.print()  # prints the averages over given epoch
        # stores the training plots into '/tmp/epoch_stats.pdf'
        # and plots into a visdom server running at localhost (if running)
        stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
    """

    def __init__(
        self,
        log_vars,
        verbose=False,
        epoch=-1,
        visdom_env="main",
        do_plot=True,
        plot_file=None,
        visdom_server="http://localhost",
        visdom_port=8097,
    ):
        # log_vars: names of the statistics that update() should extract
        # from the prediction dict; "sec/it" is treated specially (speed).
        self.verbose = verbose
        self.log_vars = log_vars
        self.visdom_env = visdom_env
        self.visdom_server = visdom_server
        self.visdom_port = visdom_port
        self.plot_file = plot_file
        self.do_plot = do_plot
        # hard_reset initializes self.epoch, self.stats and self.it.
        self.hard_reset(epoch=epoch)

    @staticmethod
    def from_json_str(json_str):
        """Rebuild a Stats object from a JSON string produced by save()."""
        self = Stats([])
        # load the global state
        self.__dict__.update(json.loads(json_str))
        # recover the AverageMeters
        # (each meter was serialized as a nested JSON string by
        # StatsJSONEncoder, so it has to be decoded individually)
        for stat_set in self.stats:
            self.stats[stat_set] = {
                log_var: AverageMeter.from_json_str(log_vals_json_str)
                for log_var, log_vals_json_str in self.stats[stat_set].items()
            }
        return self

    @staticmethod
    def load(flpath, postfix=".jgz"):
        """Load a Stats object from a gzipped-json file written by save()."""
        flpath = _get_postfixed_filename(flpath, postfix)
        with gzip.open(flpath, "r") as fin:
            data = json.loads(fin.read().decode("utf-8"))
        return Stats.from_json_str(data)

    def save(self, flpath, postfix=".jgz"):
        """Serialize this object to `flpath` as gzipped json."""
        flpath = _get_postfixed_filename(flpath, postfix)
        # store into a gzipped-json
        with gzip.open(flpath, "w") as fout:
            fout.write(json.dumps(self, cls=StatsJSONEncoder).encode("utf-8"))

    # some sugar to be used with "with stats:" at the beginning of the epoch
    def __enter__(self):
        # NOTE(review): returns None, so `with stats as s:` binds s = None.
        if self.do_plot and self.epoch >= 0:
            self.plot_stats(self.visdom_env)
        self.new_epoch()

    def __exit__(self, type, value, traceback):
        # Only plot on clean exit; `type` (shadowing the builtin) is the
        # exception class raised inside the `with` block, if any.
        iserr = type is not None and issubclass(type, Exception)
        iserr = iserr or (type is KeyboardInterrupt)
        if iserr:
            print("error inside 'with' block")
            return
        if self.do_plot:
            self.plot_stats(self.visdom_env)

    def reset(self):
        """Reset all AverageMeters and iteration counters.

        To be called after each epoch; the per-epoch history inside the
        meters is preserved (AverageMeter.reset keeps it).
        """
        # to be called after each epoch
        stat_sets = list(self.stats.keys())
        if self.verbose:
            print("stats: epoch %d - reset" % self.epoch)
        self.it = {k: -1 for k in stat_sets}
        for stat_set in stat_sets:
            for stat in self.stats[stat_set]:
                self.stats[stat_set][stat].reset()

    def hard_reset(self, epoch=-1):
        """Discard all gathered statistics and set the epoch counter."""
        # to be called during object __init__
        self.epoch = epoch
        if self.verbose:
            print("stats: epoch %d - hard reset" % self.epoch)
        self.stats = {}
        # reset
        self.reset()

    def new_epoch(self):
        """Advance the epoch counter and reset the per-epoch accumulators."""
        if self.verbose:
            print("stats: new epoch %d" % (self.epoch + 1))
        self.epoch += 1
        self.reset()  # zero the stats + increase epoch counter

    def gather_value(self, val):
        """Convert a logged value (number or tensor) to a python float.

        Tensors are moved to cpu and summed over all elements.
        """
        if isinstance(val, (float, int)):
            val = float(val)
        else:
            val = val.data.cpu().numpy()
            val = float(val.sum())
        return val

    def add_log_vars(self, added_log_vars, verbose=True):
        """Append new statistic names to the set of logged variables."""
        # NOTE(review): membership is tested against self.stats (the dict of
        # stat *sets*, e.g. "train"/"val"), not self.log_vars — this looks
        # like it was meant to be `add_log_var not in self.log_vars`; confirm.
        for add_log_var in added_log_vars:
            if add_log_var not in self.stats:
                if verbose:
                    print(f"Adding {add_log_var}")
                self.log_vars.append(add_log_var)
        # self.synchronize_logged_vars(self.log_vars, verbose=verbose)

    def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
        """Extract each logged variable from `preds` and accumulate it.

        Args:
            preds: Dict of predictions; entries named in self.log_vars
                are gathered (missing entries are silently skipped).
            time_start: When given, the special "sec/it" stat is computed
                as elapsed-time / iterations.
            freeze_iter: If True, do not advance the iteration counter.
            stat_set: Which stat set ("train", "val", ...) to update.
        """
        if self.epoch == -1:
            # uninitialized
            print(
                "warning: epoch==-1 means uninitialized stats structure -> new_epoch() called"
            )
            self.new_epoch()

        if stat_set not in self.stats:
            self.stats[stat_set] = {}
            self.it[stat_set] = -1

        if not freeze_iter:
            self.it[stat_set] += 1

        epoch = self.epoch
        it = self.it[stat_set]

        for stat in self.log_vars:
            if stat not in self.stats[stat_set]:
                self.stats[stat_set][stat] = AverageMeter()

            if stat == "sec/it":
                # compute speed
                if time_start is None:
                    elapsed = 0.0
                else:
                    elapsed = time.time() - time_start
                time_per_it = float(elapsed) / float(it + 1)
                val = time_per_it
                # self.stats[stat_set]['sec/it'].update(time_per_it,epoch=epoch,n=1)
            else:
                if stat in preds:
                    try:
                        val = self.gather_value(preds[stat])
                    except KeyError:
                        raise ValueError(
                            "could not extract prediction %s\
                                from the prediction dictionary" % stat
                        )
                else:
                    val = None

            if val is not None:
                self.stats[stat_set][stat].update(val, epoch=epoch, n=1)

    def get_epoch_averages(self, epoch=None):
        """Return per-stat-set dicts of per-epoch average values.

        Args:
            epoch: Epoch to average over; None means the current epoch,
                -1 means all epochs so far.
        """
        stat_sets = list(self.stats.keys())

        if epoch is None:
            epoch = self.epoch
        if epoch == -1:
            epoch = list(range(self.epoch))

        outvals = {}
        for stat_set in stat_sets:
            outvals[stat_set] = {
                "epoch": epoch,
                "it": self.it[stat_set],
                "epoch_max": self.epoch,
            }

            for stat in self.stats[stat_set].keys():
                if self.stats[stat_set][stat].count == 0:
                    continue
                if isinstance(epoch, Iterable):
                    avgs = self.stats[stat_set][stat].get_epoch_averages()
                    avgs = [avgs[e] for e in epoch]
                else:
                    avgs = self.stats[stat_set][stat].get_epoch_averages(epoch=epoch)
                outvals[stat_set][stat] = avgs

        return outvals

    def print(
        self,
        max_it=None,
        stat_set="train",
        vars_print=None,
        get_str=False,
        skip_nan=False,
        stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
    ):
        """Print (or return) a one-line summary of current averages.

        Args:
            max_it: Total iterations per epoch, appended to the header.
            stat_set: Which stat set to print.
            vars_print: NOTE(review): currently unused — confirm whether
                filtering by variable name was intended.
            get_str: If True, return the summary string instead of printing.
            skip_nan: Skip statistics with a non-finite average.
            stat_format: Callable shortening stat names for display.
        """
        epoch = self.epoch
        stats = self.stats

        str_out = ""

        it = self.it[stat_set]
        stat_str = ""
        stats_print = sorted(stats[stat_set].keys())
        for stat in stats_print:
            if stats[stat_set][stat].count == 0:
                continue
            if skip_nan and not np.isfinite(stats[stat_set][stat].avg):
                continue
            stat_str += " {0:.12}: {1:1.3f} |".format(
                stat_format(stat), stats[stat_set][stat].avg
            )

        head_str = "[%s] | epoch %3d | it %5d" % (stat_set, epoch, it)
        if max_it:
            head_str += "/ %d" % max_it

        str_out = "%s | %s" % (head_str, stat_str)

        if get_str:
            return str_out
        else:
            print(str_out)

    def plot_stats(
        self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
    ):
        """Plot all gathered statistics to visdom and/or a matplotlib file.

        Any argument left as None falls back to the value cached at
        construction time.
        """
        # use the cached visdom env if none supplied
        if visdom_env is None:
            visdom_env = self.visdom_env
        if visdom_server is None:
            visdom_server = self.visdom_server
        if visdom_port is None:
            visdom_port = self.visdom_port
        if plot_file is None:
            plot_file = self.plot_file

        stat_sets = list(self.stats.keys())

        print(
            "printing charts to visdom env '%s' (%s:%d)"
            % (visdom_env, visdom_server, visdom_port)
        )

        novisdom = False

        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
        if not viz.check_connection():
            print("no visdom server! -> skipping visdom plots")
            novisdom = True

        lines = []

        # plot metrics
        if not novisdom:
            viz.close(env=visdom_env, win=None)

        for stat in self.log_vars:
            vals = []
            stat_sets_now = []
            for stat_set in stat_sets:
                val = self.stats[stat_set][stat].get_epoch_averages()
                if val is None:
                    continue
                else:
                    val = np.array(val).reshape(-1)
                    stat_sets_now.append(stat_set)
                vals.append(val)

            if len(vals) == 0:
                continue

            lines.append((stat_sets_now, stat, vals))

        if not novisdom:
            for tmodes, stat, vals in lines:
                title = "%s" % stat
                opts = {"title": title, "legend": list(tmodes)}
                for i, (tmode, val) in enumerate(zip(tmodes, vals)):
                    # The first trace creates the window; later traces append.
                    update = "append" if i > 0 else None
                    valid = np.where(np.isfinite(val))[0]
                    if len(valid) == 0:
                        continue
                    x = np.arange(len(val))
                    viz.line(
                        Y=val[valid],
                        X=x[valid],
                        env=visdom_env,
                        opts=opts,
                        win=f"stat_plot_{title}",
                        name=tmode,
                        update=update,
                    )

        if plot_file:
            print("exporting stats to %s" % plot_file)
            ncol = 3
            nrow = int(np.ceil(float(len(lines)) / ncol))
            matplotlib.rcParams.update({"font.size": 5})
            color = cycle(plt.cm.tab10(np.linspace(0, 1, 10)))
            fig = plt.figure(1)
            plt.clf()
            for idx, (tmodes, stat, vals) in enumerate(lines):
                c = next(color)
                plt.subplot(nrow, ncol, idx + 1)
                plt.gca()
                for vali, vals_ in enumerate(vals):
                    # Darken the base color for each successive stat set.
                    c_ = c * (1.0 - float(vali) * 0.3)
                    valid = np.where(np.isfinite(vals_))[0]
                    if len(valid) == 0:
                        continue
                    x = np.arange(len(vals_))
                    plt.plot(x[valid], vals_[valid], c=c_, linewidth=1)
                plt.ylabel(stat)
                plt.xlabel("epoch")
                plt.gca().yaxis.label.set_color(c[0:3] * 0.75)
                plt.legend(tmodes)
                gcolor = np.array(mcolors.to_rgba("lightgray"))
                plt.grid(
                    b=True, which="major", color=gcolor, linestyle="-", linewidth=0.4
                )
                plt.grid(
                    b=True, which="minor", color=gcolor, linestyle="--", linewidth=0.2
                )
                plt.minorticks_on()

            plt.tight_layout()
            plt.show()
            try:
                fig.savefig(plot_file)
            except PermissionError:
                warnings.warn("Cant dump stats due to insufficient permissions!")

    def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True):
        """Make every stat set contain exactly `log_vars`.

        Extra statistics are dropped; missing or incomplete ones are
        re-created and filled with `default_val` for every past epoch so
        that all meters agree on the epoch count.
        """
        stat_sets = list(self.stats.keys())

        # remove the additional log_vars
        for stat_set in stat_sets:
            for stat in self.stats[stat_set].keys():
                if stat not in log_vars:
                    print("additional stat %s:%s -> removing" % (stat_set, stat))
            self.stats[stat_set] = {
                stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars
            }

        self.log_vars = log_vars  # !!!

        for stat_set in stat_sets:
            reference_stat = list(self.stats[stat_set].keys())[0]
            for stat in log_vars:
                if stat not in self.stats[stat_set]:
                    if verbose:
                        print(
                            "missing stat %s:%s -> filling with default values (%1.2f)"
                            % (stat_set, stat, default_val)
                        )
                elif len(self.stats[stat_set][stat].history) != self.epoch + 1:
                    h = self.stats[stat_set][stat].history
                    if len(h) == 0:
                        # just never updated stat ... skip
                        continue
                    else:
                        if verbose:
                            print(
                                "incomplete stat %s:%s -> reseting with default values (%1.2f)"
                                % (stat_set, stat, default_val)
                            )
                else:
                    continue

                # Re-create the meter and fill each past epoch with the
                # default value so epoch counts line up with reference_stat.
                self.stats[stat_set][stat] = AverageMeter()
                self.stats[stat_set][stat].reset()

                lastep = self.epoch + 1
                for ep in range(lastep):
                    self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
                epoch_self = self.stats[stat_set][reference_stat].get_epoch()
                epoch_generated = self.stats[stat_set][stat].get_epoch()
                assert (
                    epoch_self == epoch_generated
                ), "bad epoch of synchronized log_var! %d vs %d" % (
                    epoch_self,
                    epoch_generated,
                )
class StatsJSONEncoder(json.JSONEncoder):
    """JSON encoder used by Stats.save().

    Stats and AverageMeter instances are serialized as a JSON *string*
    of their `__dict__` (not as a nested object); Stats.from_json_str
    decodes those nested strings back when loading.
    """

    def default(self, o):
        if not isinstance(o, (AverageMeter, Stats)):
            raise TypeError(
                f"Object of type {o.__class__.__name__} "
                f"is not JSON serializable"
            )
        return self.encode(o.__dict__)
def
_get_postfixed_filename
(
fl
,
postfix
):
return
fl
if
fl
.
endswith
(
postfix
)
else
fl
+
postfix
pytorch3d/implicitron/tools/utils.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
collections
import
dataclasses
import
time
from
contextlib
import
contextmanager
from
typing
import
Any
,
Callable
,
Dict
import
torch
@contextmanager
def evaluating(net: torch.nn.Module):
    """Temporarily switch to evaluation mode.

    On exit, training mode is restored if and only if the module was in
    training mode when the context was entered.
    """
    was_training = net.training
    try:
        net.eval()
        yield net
    finally:
        if was_training:
            net.train()
def try_to_cuda(t: Any) -> Any:
    """
    Try to move the input variable `t` to a cuda device.

    Objects without a `.cuda()` method are returned unchanged.

    Args:
        t: Input.

    Returns:
        t_cuda: `t` moved to a cuda device, if supported.
    """
    try:
        return t.cuda()
    except AttributeError:
        return t
def try_to_cpu(t: Any) -> Any:
    """
    Try to move the input variable `t` to a cpu device.

    Objects without a `.cpu()` method are returned unchanged.

    Args:
        t: Input.

    Returns:
        t_cpu: `t` moved to a cpu device, if supported.
    """
    try:
        return t.cpu()
    except AttributeError:
        return t
def dict_to_cuda(batch: Dict[Any, Any]) -> Dict[Any, Any]:
    """
    Move all values in a dictionary to cuda if supported.

    Args:
        batch: Input dict.

    Returns:
        batch_cuda: `batch` moved to a cuda device, if supported.
    """
    moved = {}
    for key, value in batch.items():
        moved[key] = try_to_cuda(value)
    return moved
def dict_to_cpu(batch):
    """
    Move all values in a dictionary to cpu if supported.

    Args:
        batch: Input dict.

    Returns:
        batch_cpu: `batch` moved to a cpu device, if supported.
    """
    moved = {}
    for key, value in batch.items():
        moved[key] = try_to_cpu(value)
    return moved
def dataclass_to_cuda_(obj):
    """
    Move all contents of a dataclass to cuda inplace if supported.

    Args:
        batch: Input dataclass.

    Returns:
        batch_cuda: `batch` moved to a cuda device, if supported.
    """
    for fld in dataclasses.fields(obj):
        moved = try_to_cuda(getattr(obj, fld.name))
        setattr(obj, fld.name, moved)
    return obj
def dataclass_to_cpu_(obj):
    """
    Move all contents of a dataclass to cpu inplace if supported.

    Args:
        batch: Input dataclass.

    Returns:
        batch_cuda: `batch` moved to a cpu device, if supported.
    """
    for fld in dataclasses.fields(obj):
        moved = try_to_cpu(getattr(obj, fld.name))
        setattr(obj, fld.name, moved)
    return obj
def cat_dataclass(batch, tensor_collator: Callable):
    """
    Concatenate all fields of a list of dataclasses `batch` to a single
    dataclass object using `tensor_collator`.

    Args:
        batch: Input list of dataclasses; all elements must be instances
            of the same dataclass type.
        tensor_collator: The function used to concatenate tensor fields
            (e.g. `torch.cat` or `torch.stack`).

    Returns:
        concatenated_batch: All elements of `batch` concatenated to a single
            dataclass object.

    Raises:
        ValueError: If a field is neither None, a tensor, a dataclass,
            nor a mapping.
    """
    first = batch[0]
    field_values = {}
    for fld in dataclasses.fields(first):
        value = getattr(first, fld.name)
        if value is None:
            # A None field in the first element stays None.
            field_values[fld.name] = None
        elif torch.is_tensor(value):
            field_values[fld.name] = tensor_collator(
                [getattr(item, fld.name) for item in batch]
            )
        elif dataclasses.is_dataclass(value):
            # Recurse into nested dataclasses.
            field_values[fld.name] = cat_dataclass(
                [getattr(item, fld.name) for item in batch], tensor_collator
            )
        elif isinstance(value, collections.abc.Mapping):
            # Collate each mapping entry separately; None entries stay None.
            field_values[fld.name] = {
                key: (
                    tensor_collator([getattr(item, fld.name)[key] for item in batch])
                    if value[key] is not None
                    else None
                )
                for key in value
            }
        else:
            raise ValueError("Unsupported field type for concatenation")
    return type(first)(**field_values)
class Timer:
    """
    A simple class for timing execution.

    Example:
        ```
        with Timer():
            print("This print statement is timed.")
        ```

    After exit, `self.interval` holds the elapsed wall-clock seconds.
    """

    def __init__(self, name="timer", quiet=False):
        self.name = name
        self.quiet = quiet

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *exc_info):
        self.end = time.time()
        self.interval = self.end - self.start
        if self.quiet:
            return
        print("%20s: %1.6f sec" % (self.name, self.interval))
pytorch3d/implicitron/tools/video_writer.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
os
import
shutil
import
tempfile
import
warnings
from
typing
import
Optional
,
Tuple
,
Union
import
matplotlib
import
matplotlib.pyplot
as
plt
import
numpy
as
np
from
PIL
import
Image
matplotlib
.
use
(
"Agg"
)
class VideoWriter:
    """
    A class for exporting videos.

    Frames are written one by one as PNG files into `cache_dir` and then
    assembled into a single mp4 by invoking an external `ffmpeg` binary.
    """

    def __init__(
        self,
        cache_dir: Optional[str] = None,
        ffmpeg_bin: str = "ffmpeg",
        out_path: str = "/tmp/video.mp4",
        fps: int = 20,
        output_format: str = "visdom",
        rmdir_allowed: bool = False,
        **kwargs,
    ):
        """
        Args:
            cache_dir: A directory for storing the video frames. If `None`,
                a temporary directory will be used.
            ffmpeg_bin: The path to an `ffmpeg` executable.
            out_path: The path to the output video.
            fps: The speed of the generated video in frames-per-second.
            output_format: Format of the output video. Currently only `"visdom"`
                is supported.
            rmdir_allowed: If `True` delete and create `cache_dir` in case
                it is not empty.
        """
        self.rmdir_allowed = rmdir_allowed
        self.output_format = output_format
        self.fps = fps
        self.out_path = out_path
        self.cache_dir = cache_dir
        self.ffmpeg_bin = ffmpeg_bin
        self.frames = []
        # Filename pattern for the frame PNGs; also passed to ffmpeg as
        # its input pattern in get_video().
        self.regexp = "frame_%08d.png"
        self.frame_num = 0

        if self.cache_dir is not None:
            # User-supplied cache dir: optionally wipe it, never auto-delete
            # it in __del__.
            self.tmp_dir = None
            if os.path.isdir(self.cache_dir):
                if rmdir_allowed:
                    shutil.rmtree(self.cache_dir)
                else:
                    warnings.warn(
                        f"Warning: cache directory not empty ({self.cache_dir})."
                    )
            os.makedirs(self.cache_dir, exist_ok=True)
        else:
            # Own the cache dir; cleaned up automatically in __del__.
            self.tmp_dir = tempfile.TemporaryDirectory()
            self.cache_dir = self.tmp_dir.name

    def write_frame(
        self,
        frame: Union[matplotlib.figure.Figure, np.ndarray, Image.Image, str],
        resize: Optional[Union[float, Tuple[int, int]]] = None,
    ):
        """
        Write a frame to the video.

        Args:
            frame: An object containing the frame image: a matplotlib
                figure, an ndarray (float arrays are assumed CHW in [0, 1]
                and converted to HWC uint8 — TODO confirm input convention
                with callers), a PIL image, or a path to an image file.
            resize: Either a floating defining the image rescaling factor
                or a 2-tuple defining the size of the output image.
        """
        outfile = os.path.join(self.cache_dir, self.regexp % self.frame_num)

        if isinstance(frame, matplotlib.figure.Figure):
            # Saves the *current* pyplot figure to outfile, then re-opens it.
            plt.savefig(outfile)
            im = Image.open(outfile)
        elif isinstance(frame, np.ndarray):
            if frame.dtype in (np.float64, np.float32, float):
                # CHW float -> HWC uint8.
                frame = (np.transpose(frame, (1, 2, 0)) * 255.0).astype(np.uint8)
            im = Image.fromarray(frame)
        elif isinstance(frame, Image.Image):
            im = frame
        elif isinstance(frame, str):
            im = Image.open(frame).convert("RGB")
        else:
            raise ValueError("Cant convert type %s" % str(type(frame)))

        if im is not None:
            if resize is not None:
                if isinstance(resize, float):
                    # Scale factor -> absolute size; a tuple is used as-is.
                    resize = [int(resize * s) for s in im.size]
            else:
                resize = im.size
            # make sure size is divisible by 2
            # (yuv420p output requires even dimensions)
            resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
            # NOTE(review): Image.ANTIALIAS is deprecated/removed in newer
            # Pillow (use Image.LANCZOS) — confirm the pinned Pillow version.
            im = im.resize(resize, Image.ANTIALIAS)
            im.save(outfile)

        self.frames.append(outfile)
        self.frame_num += 1

    def get_video(self, quiet: bool = True):
        """
        Generate the video from the written frames.

        Args:
            quiet: If `True`, suppresses logging messages.

        Returns:
            video_path: The path to the generated video.
        """
        regexp = os.path.join(self.cache_dir, self.regexp)

        if self.output_format == "visdom":
            # works for ppt too
            # NOTE(review): the command is assembled by string interpolation
            # and run via os.system; out_path containing quotes or shell
            # metacharacters would break/inject — confirm inputs are trusted.
            ffmcmd_ = (
                "%s -r %d -i %s -vcodec h264 -f mp4 \
                -y -crf 18 -b 2000k -pix_fmt yuv420p '%s'"
                % (self.ffmpeg_bin, self.fps, regexp, self.out_path)
            )
        else:
            raise ValueError("no such output type %s" % str(self.output_format))

        if quiet:
            ffmcmd_ += " > /dev/null 2>&1"
        else:
            print(ffmcmd_)
        os.system(ffmcmd_)

        return self.out_path

    def __del__(self):
        # Only clean up a cache dir this object created itself.
        if self.tmp_dir is not None:
            self.tmp_dir.cleanup()
pytorch3d/implicitron/tools/vis_utils.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Optional

import torch
from visdom import Visdom
def get_visdom_env(cfg):
    """
    Parse out visdom environment name from the input config.

    Falls back to the last component of `cfg.exp_dir` when
    `cfg.visdom_env` is empty.

    Args:
        cfg: The global config file.

    Returns:
        visdom_env: The name of the visdom environment.
    """
    if len(cfg.visdom_env) == 0:
        return cfg.exp_dir.split("/")[-1]
    return cfg.visdom_env
# TODO: a proper singleton
# Module-level cache of the Visdom connection; populated lazily by
# get_visdom_connection() on first use and reused afterwards (the
# server/port of later calls are ignored once the connection exists).
_viz_singleton = None
def get_visdom_connection(
    server: str = "http://localhost",
    port: int = 8097,
) -> Visdom:
    """
    Obtain a connection to a visdom server.

    The connection is created once and cached module-wide; `server` and
    `port` are only used on the first call.

    Args:
        server: Server address.
        port: Server port.

    Returns:
        connection: The connection object.
    """
    global _viz_singleton
    if _viz_singleton is not None:
        return _viz_singleton
    _viz_singleton = Visdom(server=server, port=port)
    return _viz_singleton
def visualize_basics(
    viz: Visdom,
    preds: Dict[str, Any],
    visdom_env_imgs: str,
    title: str = "",
    visualize_preds_keys: Optional[List[str]] = None,
    store_history: bool = False,
) -> None:
    """
    Visualize basic outputs of a `GenericModel` to visdom.

    Args:
        viz: The visdom object.
        preds: A dictionary containing `GenericModel` outputs.
        visdom_env_imgs: Target visdom environment name.
        title: The title of produced visdom window.
        visualize_preds_keys: The list of keys of `preds` for visualization.
            Defaults to the standard image/render/mask/depth outputs.
        store_history: Store the history buffer in visdom windows.
    """
    if visualize_preds_keys is None:
        # The default lives here rather than in the signature so that the
        # argument default is not a shared mutable list.
        visualize_preds_keys = [
            "image_rgb",
            "images_render",
            "fg_probability",
            "masks_render",
            "depths_render",
            "depth_map",
        ]

    imout = {}
    for k in visualize_preds_keys:
        if k not in preds or preds[k] is None:
            print(f"cant show {k}")
            continue
        v = preds[k].cpu().detach().clone()
        if k.startswith("depth"):
            # divide by 95th percentile
            normfac = (
                v.view(v.shape[0], -1)
                .topk(k=int(0.05 * (v.numel() // v.shape[0])), dim=-1)
                .values[:, -1]
            )
            v = v / normfac[:, None, None, None].clamp(1e-4)
        if v.shape[1] == 1:
            # Grayscale -> 3-channel so all images can be concatenated.
            v = v.repeat(1, 3, 1, 1)
        v = torch.nn.functional.interpolate(
            v,
            # pyre-fixme[6]: Expected `Optional[typing.List[float]]` for 2nd param
            #  but got `float`.
            scale_factor=(
                600.0
                if (
                    "_eval" in visdom_env_imgs
                    and k in ("images_render", "depths_render")
                )
                else 200.0
            )
            / v.shape[2],
            mode="bilinear",
        )
        imout[k] = v

    # TODO: handle errors on the outside
    try:
        imout = {"all": torch.cat(list(imout.values()), dim=2)}
    except Exception:
        # Concatenation fails when the per-key images differ in size; fall
        # back to one visdom window per key. (Was a bare `except:`, which
        # also swallowed KeyboardInterrupt/SystemExit.)
        print("cant cat!")

    for k, v in imout.items():
        viz.images(
            v.clamp(0.0, 1.0),
            win=k,
            env=visdom_env_imgs,
            opts={"title": title + "_" + k, "store_history": store_history},
        )
def make_depth_image(
    depths: torch.Tensor,
    masks: torch.Tensor,
    max_quantile: float = 0.98,
    min_quantile: float = 0.02,
    min_out_depth: float = 0.1,
    max_out_depth: float = 0.9,
) -> torch.Tensor:
    """
    Convert a batch of depth maps to a grayscale image.

    Each map is normalized independently using robust (quantile-based)
    minimum/maximum estimates computed over its valid foreground pixels,
    then linearly mapped to [min_out_depth, max_out_depth] and masked.

    Args:
        depths: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
        masks: A tensor of shape `(B, 1, H, W)` containing a batch of foreground masks.
        max_quantile: The quantile of the input depth values which will
            be mapped to `max_out_depth`.
        min_quantile: The quantile of the input depth values which will
            be mapped to `min_out_depth`.
        min_out_depth: The minimal value in each depth map will be assigned this color.
        max_out_depth: The maximal value in each depth map will be assigned this color.

    Returns:
        depth_image: A tensor of shape `(B, 1, H, W)` a batch of grayscale
            depth images.
    """
    normfacs = []
    for d, m in zip(depths, masks):
        # Valid pixels: strictly positive depth inside the foreground mask.
        ok = (d.view(-1) > 1e-6) * (m.view(-1) > 0.5)
        if ok.sum() <= 1:
            # Degenerate map: record zero min/max so (max-min).clamp(1e-4)
            # below keeps the division well-defined.
            print("empty depth!")
            normfacs.append(torch.zeros(2).type_as(depths))
            continue
        dok = d.view(-1)[ok].view(-1)
        # Number of elements in the top/bottom tails; at least 1 each.
        _maxk = max(int(round((1 - max_quantile) * (dok.numel()))), 1)
        _mink = max(int(round(min_quantile * (dok.numel()))), 1)
        # Last element of the top-k / bottom-k == the requested quantile.
        normfac_max = dok.topk(k=_maxk, dim=-1).values[-1]
        normfac_min = dok.topk(k=_mink, dim=-1, largest=False).values[-1]
        normfacs.append(torch.stack([normfac_min, normfac_max]))
    normfacs = torch.stack(normfacs)
    _min, _max = (
        normfacs[:, 0].view(-1, 1, 1, 1),
        normfacs[:, 1].view(-1, 1, 1, 1),
    )
    # Normalize per-map to [0, 1] (robust range), then rescale to the
    # output band and zero out background pixels.
    depths = (depths - _min) / (_max - _min).clamp(1e-4)
    depths = (
        (depths * (max_out_depth - min_out_depth) + min_out_depth) * masks.float()
    ).clamp(0.0, 1.0)
    return depths
tests/implicitron/__init__.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
tests/implicitron/common_resources.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
contextlib
import
logging
import
os
import
tempfile
import
unittest
from
pathlib
import
Path
from
typing
import
Generator
,
Tuple
from
zipfile
import
ZipFile
from
iopath.common.file_io
import
PathManager
@contextlib.contextmanager
def get_skateboard_data(
    avoid_manifold: bool = False, silence_logs: bool = False
) -> Generator[Tuple[str, PathManager], None, None]:
    """
    Context manager for accessing Co3D dataset by tests, at least for
    the first 5 skateboards. Internally, we want this to exercise the
    normal way to access the data directly manifold, but on an RE
    worker this is impossible so we use a workaround.

    Args:
        avoid_manifold: Use the method used by RE workers even locally.
        silence_logs: Whether to reduce log output from iopath library.

    Yields:
        dataset_root: (str) path to dataset root.
        path_manager: path_manager to access it with.
    """
    path_manager = PathManager()
    if silence_logs:
        logging.getLogger("iopath.fb.manifold").setLevel(logging.CRITICAL)
        logging.getLogger("iopath.common.file_io").setLevel(logging.CRITICAL)

    if not os.environ.get("FB_TEST", False):
        # Not running under the internal test harness: require a known
        # cluster environment and use the shared checkpoint copy.
        if os.getenv("FAIR_ENV_CLUSTER", "") == "":
            raise unittest.SkipTest("Unknown environment. Data not available.")
        yield "/checkpoint/dnovotny/datasets/co3d/download_aws_22_02_18", path_manager
    elif avoid_manifold or os.environ.get("INSIDE_RE_WORKER", False):
        # RE-worker path: unzip the bundled sample data into a temp dir.
        from libfb.py.parutil import get_file_path

        par_path = "skateboard_first_5"
        source = get_file_path(par_path)
        assert Path(source).is_file()
        with tempfile.TemporaryDirectory() as dest:
            with ZipFile(source) as f:
                f.extractall(dest)
            yield os.path.join(dest, "extracted"), path_manager
    else:
        # Normal internal path: read the data directly from manifold.
        from iopath.fb.manifold import ManifoldPathHandler

        path_manager.register_handler(ManifoldPathHandler())

        yield "manifold://co3d/tree/extracted", path_manager
def provide_lpips_vgg():
    """
    Ensure the weights files are available for lpips.LPIPS(net="vgg")
    to be called. Specifically, torchvision's vgg16.

    Only does anything when running on a remote-execution worker; it is
    a no-op elsewhere.
    """
    # In OSS, torchvision looks for vgg16 weights in
    # https://download.pytorch.org/models/vgg16-397923af.pth
    # Inside fbcode, this is replaced by asking iopath for
    #  manifold://torchvision/tree/models/vgg16-397923af.pth
    # (the code for this replacement is in
    #  fbcode/pytorch/vision/fb/_internally_replaced_utils.py )
    #
    # iopath does this by looking for the file at the cache location
    # and if it is not there getting it from manifold.
    # (the code for this is in
    #  fbcode/fair_infra/data/iopath/iopath/fb/manifold.py )
    #
    # On the remote execution worker, manifold is inaccessible.
    # We solve this by making the cached file available before iopath
    # looks.
    #
    # By default the cache location is
    #  ~/.torch/iopath_cache/manifold_cache/tree/models/vgg16-397923af.pth
    # But we can't write to the home directory on the RE worker.
    # We define FVCORE_CACHE to change the cache location to
    #  iopath_cache/manifold_cache/tree/models/vgg16-397923af.pth
    # (Without it, manifold caches in unstable temporary locations on RE.)
    #
    # The file we want has been copied from
    #  tree/models/vgg16-397923af.pth in the torchvision bucket
    # to
    #  tree/testing/vgg16-397923af.pth in the co3d bucket
    # and the TARGETS file copies it somewhere in the PAR which we
    # recover with get_file_path.
    # (It can't copy straight to a nested location, see
    # https://fb.workplace.com/groups/askbuck/posts/2644615728920359/)
    # Here we symlink it to the new cache location.
    if os.environ.get("INSIDE_RE_WORKER") is not None:
        from libfb.py.parutil import get_file_path

        os.environ["FVCORE_CACHE"] = "iopath_cache"

        par_path = "vgg_weights_for_lpips"
        source = Path(get_file_path(par_path))
        assert source.is_file()

        # Symlink the bundled weights into the location iopath will probe.
        dest = Path("iopath_cache/manifold_cache/tree/models")
        if not dest.exists():
            dest.mkdir(parents=True)
        (dest / "vgg16-397923af.pth").symlink_to(source)
tests/implicitron/data/overrides.yaml
0 → 100644
View file @
cdd2142d
mask_images
:
true
mask_depths
:
true
render_image_width
:
400
render_image_height
:
400
mask_threshold
:
0.5
output_rasterized_mc
:
false
bg_color
:
-
0.0
-
0.0
-
0.0
view_pool
:
false
num_passes
:
1
chunk_size_grid
:
4096
render_features_dimensions
:
3
tqdm_trigger_threshold
:
16
n_train_target_views
:
1
sampling_mode_training
:
mask_sample
sampling_mode_evaluation
:
full_grid
renderer_class_type
:
LSTMRenderer
feature_aggregator_class_type
:
AngleWeightedIdentityFeatureAggregator
implicit_function_class_type
:
IdrFeatureField
loss_weights
:
loss_rgb_mse
:
1.0
loss_prev_stage_rgb_mse
:
1.0
loss_mask_bce
:
0.0
loss_prev_stage_mask_bce
:
0.0
log_vars
:
-
loss_rgb_psnr_fg
-
loss_rgb_psnr
-
loss_rgb_mse
-
loss_rgb_huber
-
loss_depth_abs
-
loss_depth_abs_fg
-
loss_mask_neg_iou
-
loss_mask_bce
-
loss_mask_beta_prior
-
loss_eikonal
-
loss_density_tv
-
loss_depth_neg_penalty
-
loss_autodecoder_norm
-
loss_prev_stage_rgb_mse
-
loss_prev_stage_rgb_psnr_fg
-
loss_prev_stage_rgb_psnr
-
loss_prev_stage_mask_bce
-
objective
-
epoch
-
sec/it
sequence_autodecoder_args
:
encoding_dim
:
0
n_instances
:
0
init_scale
:
1.0
ignore_input
:
false
raysampler_args
:
image_width
:
400
image_height
:
400
scene_center
:
-
0.0
-
0.0
-
0.0
scene_extent
:
0.0
sampling_mode_training
:
mask_sample
sampling_mode_evaluation
:
full_grid
n_pts_per_ray_training
:
64
n_pts_per_ray_evaluation
:
64
n_rays_per_image_sampled_from_mask
:
1024
min_depth
:
0.1
max_depth
:
8.0
stratified_point_sampling_training
:
true
stratified_point_sampling_evaluation
:
false
renderer_LSTMRenderer_args
:
num_raymarch_steps
:
10
init_depth
:
17.0
init_depth_noise_std
:
0.0005
hidden_size
:
16
n_feature_channels
:
256
verbose
:
false
image_feature_extractor_args
:
name
:
resnet34
pretrained
:
true
stages
:
-
1
-
2
-
3
-
4
normalize_image
:
true
image_rescale
:
0.16
first_max_pool
:
true
proj_dim
:
32
l2_norm
:
true
add_masks
:
true
add_images
:
true
global_average_pool
:
false
feature_rescale
:
1.0
view_sampler_args
:
masked_sampling
:
false
sampling_mode
:
bilinear
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args
:
exclude_target_view
:
true
exclude_target_view_mask_features
:
true
concatenate_output
:
true
weight_by_ray_angle_gamma
:
1.0
min_ray_angle_weight
:
0.1
implicit_function_IdrFeatureField_args
:
feature_vector_size
:
3
d_in
:
3
d_out
:
1
dims
:
-
512
-
512
-
512
-
512
-
512
-
512
-
512
-
512
geometric_init
:
true
bias
:
1.0
skip_in
:
[]
weight_norm
:
true
n_harmonic_functions_xyz
:
0
pooled_feature_dim
:
0
encoding_dim
:
0
tests/implicitron/test_batch_sampler.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
unittest
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
pytorch3d.implicitron.dataset.scene_batch_sampler
import
SceneBatchSampler
@dataclass
class MockFrameAnnotation:
    """Minimal stand-in for a frame annotation, as consumed by MockDataset."""

    # position of the frame within its sequence
    frame_number: int
    # timestamp of the frame; the tests here only need the constant default
    frame_timestamp: float = 0.0
class MockDataset:
    """
    Fake dataset of `num_seq` sequences with 10 frames each, exposing the
    attributes and the lookup method that SceneBatchSampler relies on.
    """

    def __init__(self, num_seq, max_frame_gap=1):
        """
        Makes a gap of max_frame_gap frame numbers in the middle of each sequence
        """
        seq_names = [f"seq_{i}" for i in range(num_seq)]
        self.seq_annots = {name: None for name in seq_names}
        self.seq_to_idx = {
            name: list(range(i * 10, (i + 1) * 10))
            for i, name in enumerate(seq_names)
        }
        # frame numbers within sequence: [0, ..., 4, n, ..., n+4]
        # where n - 4 == max_frame_gap
        lower_run = list(range(5))
        upper_run = list(range(4 + max_frame_gap, 9 + max_frame_gap))
        per_seq_numbers = lower_run + upper_run
        self.frame_annots = [
            {"frame_annotation": MockFrameAnnotation(number)}
            for number in per_seq_numbers * num_seq
        ]

    def get_frame_numbers_and_timestamps(self, idxs):
        """Return [(frame_number, frame_timestamp), ...] for the given raw indices."""
        annotations = (self.frame_annots[idx]["frame_annotation"] for idx in idxs)
        return [(ann.frame_number, ann.frame_timestamp) for ann in annotations]
class TestSceneBatchSampler(unittest.TestCase):
    """
    Tests for SceneBatchSampler, driven by MockDataset instances whose raw
    indices encode the sequence as idx // 10.
    """

    def setUp(self):
        # single-sequence dataset for the overfitting scenario
        self.dataset_overfit = MockDataset(1)

    def test_overfit(self):
        """
        With one sequence, every batch must be full-sized and drawn entirely
        from that sequence (all indices satisfy idx // 10 == 0), and the
        sampler must stop after exactly num_batches batches.
        """
        num_batches = 3
        batch_size = 10
        sampler = SceneBatchSampler(
            self.dataset_overfit,
            batch_size=batch_size,
            num_batches=num_batches,
            images_per_seq_options=[10],
            # will try to sample batch_size anyway
        )
        self.assertEqual(len(sampler), num_batches)
        it = iter(sampler)
        for _ in range(num_batches):
            batch = next(it)
            self.assertIsNotNone(batch)
            self.assertEqual(len(batch), batch_size)
            # true for our examples
            self.assertTrue(all(idx // 10 == 0 for idx in batch))
        with self.assertRaises(StopIteration):
            batch = next(it)

    def test_multiseq(self):
        """Sweep the multi-sequence scenario over all parameter combinations."""
        for ips_options in [[10], [2], [3], [2, 3, 4]]:
            for sample_consecutive_frames in [True, False]:
                for consecutive_frames_max_gap in [0, 1, 3]:
                    self._test_multiseq_flavour(
                        ips_options,
                        sample_consecutive_frames,
                        consecutive_frames_max_gap,
                    )

    def test_multiseq_gaps(self):
        """
        When the dataset's in-sequence frame gap (3) exceeds the sampler's
        consecutive_frames_max_gap (1), sequences split into 5-frame segments,
        so batches cap at 5 whenever more than 5 images per sequence are asked for.
        """
        num_batches = 16
        batch_size = 10
        dataset_multiseq = MockDataset(5, max_frame_gap=3)
        for ips_options in [[10], [2], [3], [2, 3, 4]]:
            debug_info = f" Images per sequence: {ips_options}."
            sampler = SceneBatchSampler(
                dataset_multiseq,
                batch_size=batch_size,
                num_batches=num_batches,
                images_per_seq_options=ips_options,
                sample_consecutive_frames=True,
                consecutive_frames_max_gap=1,
            )
            self.assertEqual(len(sampler), num_batches, msg=debug_info)
            it = iter(sampler)
            for _ in range(num_batches):
                batch = next(it)
                self.assertIsNotNone(batch, "batch is None in" + debug_info)
                if max(ips_options) > 5:
                    # true for our examples
                    self.assertEqual(len(batch), 5, msg=debug_info)
                else:
                    # true for our examples
                    self.assertEqual(len(batch), batch_size, msg=debug_info)
                self._check_frames_are_consecutive(
                    batch, dataset_multiseq.frame_annots, debug_info
                )

    def _test_multiseq_flavour(
        self,
        ips_options,
        sample_consecutive_frames,
        consecutive_frames_max_gap,
        num_batches=16,
        batch_size=10,
    ):
        """
        Check one parameter combination on a 5-sequence dataset: batch length,
        the per-sequence frequency distribution inside each batch (at most two
        distinct frequencies, the smaller occurring once), that every option in
        ips_options eventually shows up, and - when requested - consecutiveness
        of sampled frames.
        """
        debug_info = (
            f" Images per sequence: {ips_options}, "
            f"sample_consecutive_frames: {sample_consecutive_frames}, "
            f"consecutive_frames_max_gap: {consecutive_frames_max_gap}, "
        )
        # in this test, either consecutive_frames_max_gap == max_frame_gap,
        # or consecutive_frames_max_gap == 0, so segments consist of full sequences
        frame_gap = consecutive_frames_max_gap if consecutive_frames_max_gap > 0 else 3
        dataset_multiseq = MockDataset(5, max_frame_gap=frame_gap)
        sampler = SceneBatchSampler(
            dataset_multiseq,
            batch_size=batch_size,
            num_batches=num_batches,
            images_per_seq_options=ips_options,
            sample_consecutive_frames=sample_consecutive_frames,
            consecutive_frames_max_gap=consecutive_frames_max_gap,
        )
        self.assertEqual(len(sampler), num_batches, msg=debug_info)
        it = iter(sampler)
        typical_counts = set()
        for _ in range(num_batches):
            batch = next(it)
            self.assertIsNotNone(batch, "batch is None in" + debug_info)
            # true for our examples
            self.assertEqual(len(batch), batch_size, msg=debug_info)
            # find distribution over sequences
            counts = _count_by_quotient(batch, 10)
            freqs = _count_by_quotient(counts.values(), 1)
            self.assertLessEqual(
                len(freqs),
                2,
                msg="We should have maximum of 2 different "
                "frequences of sequences in the batch." + debug_info,
            )
            if len(freqs) == 2:
                # the less frequent sequence is the remainder filling the batch
                most_seq_count = max(*freqs.keys())
                last_seq = min(*freqs.keys())
                self.assertEqual(
                    freqs[last_seq],
                    1,
                    msg="Only one odd sequence allowed." + debug_info,
                )
            else:
                self.assertEqual(len(freqs), 1)
                most_seq_count = next(iter(freqs))
            self.assertIn(most_seq_count, ips_options)
            typical_counts.add(most_seq_count)
            if sample_consecutive_frames:
                self._check_frames_are_consecutive(
                    batch,
                    dataset_multiseq.frame_annots,
                    debug_info,
                    max_gap=consecutive_frames_max_gap,
                )
        self.assertTrue(
            all(i in typical_counts for i in ips_options),
            "Some of the frequency options did not occur among "
            f"the {num_batches} batches (could be just bad luck)." + debug_info,
        )
        with self.assertRaises(StopIteration):
            batch = next(it)

    def _check_frames_are_consecutive(self, batch, annots, debug_info, max_gap=1):
        """
        Assert that adjacent sampled indices from the same sequence are at most
        `max_gap` apart in frame numbers (or 1 apart in raw indices when
        max_gap == 0).
        """
        # make sure that sampled frames are consecutive
        for i in range(len(batch) - 1):
            curr_idx, next_idx = batch[i : i + 2]
            if curr_idx // 10 == next_idx // 10:
                # same sequence
                if max_gap > 0:
                    # compare frame numbers, not raw indices
                    curr_idx, next_idx = [
                        annots[idx]["frame_annotation"].frame_number
                        for idx in (curr_idx, next_idx)
                    ]
                    gap = max_gap
                else:
                    gap = 1
                # we'll check that raw dataset indices are consecutive
                self.assertLessEqual(next_idx - curr_idx, gap, msg=debug_info)
def
_count_by_quotient
(
indices
,
divisor
):
counter
=
defaultdict
(
int
)
for
i
in
indices
:
counter
[
i
//
divisor
]
+=
1
return
counter
tests/implicitron/test_circle_fitting.py
0 → 100644
View file @
cdd2142d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import
os
import
unittest
from
math
import
pi
import
torch
from
pytorch3d.implicitron.tools.circle_fitting
import
(
_signed_area
,
fit_circle_in_2d
,
fit_circle_in_3d
,
)
from
pytorch3d.transforms
import
random_rotation
if
os
.
environ
.
get
(
"FB_TEST"
,
False
):
from
common_testing
import
TestCaseMixin
else
:
from
tests.common_testing
import
TestCaseMixin
class TestCircleFitting(TestCaseMixin, unittest.TestCase):
    """Tests for fit_circle_in_2d / fit_circle_in_3d and the _signed_area helper."""

    def setUp(self):
        # fixed seed so the random circles/rotations are reproducible
        torch.manual_seed(42)

    def _assertParallel(self, a, b, **kwargs):
        """
        Given a and b of shape (..., 3) each containing 3D vectors,
        assert that corresponding vectors are parallel. Changed sign is ok.
        """
        self.assertClose(torch.cross(a, b, dim=-1), torch.zeros_like(a), **kwargs)

    def test_simple_3d(self):
        """Run the 3D fitting test on several random circles."""
        # NOTE(review): hard-codes cuda:0, so this test requires a GPU - confirm
        # whether a CPU fallback/skip is intended.
        device = torch.device("cuda:0")
        for _ in range(7):
            radius = 10 * torch.rand(1, device=device)[0]
            center = 10 * torch.rand(3, device=device)
            rot = random_rotation(device=device)
            offset = torch.rand(3, device=device)
            up = torch.rand(3, device=device)
            self._simple_3d_test(radius, center, rot, offset, up)

    def _simple_3d_test(self, radius, center, rot, offset, up):
        """
        Build a circle of known radius/center in the plane spanned by the first
        two rows of `rot`, then check fit_circle_in_3d recovers radius, center
        and normal, and that its point generation / offset options behave.
        """
        # angles are increasing so the points move in a well defined direction.
        angles = torch.cumsum(torch.rand(17, device=rot.device), dim=0)
        many = torch.stack(
            [torch.cos(angles), torch.sin(angles), torch.zeros_like(angles)], dim=1
        )
        source_points = (many * radius) @ rot + center[None]
        # case with no generation
        result = fit_circle_in_3d(source_points)
        self.assertClose(result.radius, radius)
        self.assertClose(result.center, center)
        self._assertParallel(result.normal, rot[2], atol=1e-5)
        self.assertEqual(result.generated_points.shape, (0, 3))
        # Generate 5 points around the circle
        n_new_points = 5
        result2 = fit_circle_in_3d(source_points, n_points=n_new_points)
        self.assertClose(result2.radius, radius)
        self.assertClose(result2.center, center)
        self.assertClose(result2.normal, result.normal)
        self.assertEqual(result2.generated_points.shape, (5, 3))
        observed_points = result2.generated_points
        # 5 generated points traverse the full circle: first == last == input start
        self.assertClose(observed_points[0], observed_points[4], atol=1e-4)
        self.assertClose(observed_points[0], source_points[0], atol=1e-5)
        observed_normal = torch.cross(
            observed_points[0] - observed_points[2],
            observed_points[1] - observed_points[3],
            dim=-1,
        )
        self._assertParallel(observed_normal, result.normal, atol=1e-4)
        # points 0/2 and 1/3 are diametrically opposite
        diameters = observed_points[:2] - observed_points[2:4]
        self.assertClose(
            torch.norm(diameters, dim=1), diameters.new_full((2,), 2 * radius)
        )
        # Regenerate the input points
        result3 = fit_circle_in_3d(source_points, angles=angles - angles[0])
        self.assertClose(result3.radius, radius)
        self.assertClose(result3.center, center)
        self.assertClose(result3.normal, result.normal)
        self.assertClose(result3.generated_points, source_points, atol=1e-5)
        # Test with offset
        result4 = fit_circle_in_3d(
            source_points, angles=angles - angles[0], offset=offset, up=up
        )
        self.assertClose(result4.radius, radius)
        self.assertClose(result4.center, center)
        self.assertClose(result4.normal, result.normal)
        observed_offsets = result4.generated_points - source_points
        # observed_offset is constant
        self.assertClose(
            observed_offsets.min(0).values, observed_offsets.max(0).values, atol=1e-5
        )
        # observed_offset has the right length
        self.assertClose(observed_offsets[0].norm(), offset.norm())
        self.assertClose(result.normal.norm(), torch.ones(()))
        # component of observed_offset along normal
        component = torch.dot(observed_offsets[0], result.normal)
        self.assertClose(component.abs(), offset[2].abs(), atol=1e-5)
        # the normal's sign is chosen to agree with `up`, and the offset's
        # out-of-plane component follows that sign
        agree_normal = torch.dot(result.normal, up) > 0
        agree_signs = component * offset[2] > 0
        self.assertEqual(agree_normal, agree_signs)

    def test_simple_2d(self):
        """2D analogue of the 3D test: recovery, generation, regeneration."""
        radius = 7.0
        center = torch.tensor([9, 2.5])
        angles = torch.cumsum(torch.rand(17), dim=0)
        many = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
        source_points = (many * radius) + center[None]
        result = fit_circle_in_2d(source_points)
        self.assertClose(result.radius, torch.tensor(radius))
        self.assertClose(result.center, center)
        self.assertEqual(result.generated_points.shape, (0, 2))
        # Generate 5 points around the circle
        n_new_points = 5
        result2 = fit_circle_in_2d(source_points, n_points=n_new_points)
        self.assertClose(result2.radius, torch.tensor(radius))
        self.assertClose(result2.center, center)
        self.assertEqual(result2.generated_points.shape, (5, 2))
        observed_points = result2.generated_points
        self.assertClose(observed_points[0], observed_points[4])
        self.assertClose(observed_points[0], source_points[0], atol=1e-5)
        # points 0/2 and 1/3 are diametrically opposite
        diameters = observed_points[:2] - observed_points[2:4]
        self.assertClose(torch.norm(diameters, dim=1), torch.full((2,), 2 * radius))
        # Regenerate the input points
        result3 = fit_circle_in_2d(source_points, angles=angles - angles[0])
        self.assertClose(result3.radius, torch.tensor(radius))
        self.assertClose(result3.center, center)
        self.assertClose(result3.generated_points, source_points, atol=1e-5)

    def test_minimum_inputs(self):
        """3 points define a circle; 2 points must raise ValueError."""
        fit_circle_in_3d(torch.rand(3, 3), n_points=10)
        with self.assertRaisesRegex(
            ValueError, "2 points are not enough to determine a circle"
        ):
            fit_circle_in_3d(torch.rand(2, 3))

    def test_signed_area(self):
        """_signed_area: sign follows orientation, magnitude matches known shapes."""
        n_points = 1001
        angles = torch.linspace(0, 2 * pi, n_points)
        radius = 0.85
        center = torch.rand(2)
        circle = center + radius * torch.stack(
            [torch.cos(angles), torch.sin(angles)], dim=1
        )
        circle_area = torch.tensor(pi * radius * radius)
        self.assertClose(_signed_area(circle), circle_area)
        # clockwise is negative
        self.assertClose(_signed_area(circle.flip(0)), -circle_area)
        # Semicircles
        self.assertClose(_signed_area(circle[: (n_points + 1) // 2]), circle_area / 2)
        self.assertClose(_signed_area(circle[n_points // 2 :]), circle_area / 2)
        # A straight line bounds no area
        self.assertClose(_signed_area(torch.rand(2, 2)), torch.tensor(0.0))
        # Letter 'L' written anticlockwise.
        L_shape = [[0, 1], [0, 0], [1, 0]]
        # Triangle area is 0.5 * b * h.
        self.assertClose(_signed_area(torch.tensor(L_shape)), torch.tensor(0.5))
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment