implicitron v0 (#1133)

Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# TODO: all this potentially goes to PyTorch3D
import math
from typing import Tuple
import pytorch3d as pt3d
import torch
from pytorch3d.renderer.cameras import CamerasBase
def jitter_extrinsics(
R: torch.Tensor,
T: torch.Tensor,
max_angle: float = (math.pi * 2.0),
translation_std: float = 1.0,
scale_std: float = 0.3,
):
"""
Jitter the extrinsic camera parameters `R` and `T` with a random similarity
transformation. The transformation rotates by a random angle between [0, max_angle];
scales by a random factor exp(N(0, scale_std)), where N(0, scale_std) is
    a random sample from a normal distribution with zero mean and standard deviation scale_std;
and translates by a 3D offset sampled from N(0, translation_std).
"""
assert all(x >= 0.0 for x in (max_angle, translation_std, scale_std))
N = R.shape[0]
R_jit = pt3d.transforms.random_rotations(1, device=R.device)
R_jit = pt3d.transforms.so3_exponential_map(
pt3d.transforms.so3_log_map(R_jit) * max_angle
)
T_jit = torch.randn_like(R_jit[:1, :, 0]) * translation_std
rigid_transform = pt3d.ops.eyes(dim=4, N=N, device=R.device)
rigid_transform[:, :3, :3] = R_jit.expand(N, 3, 3)
rigid_transform[:, 3, :3] = T_jit.expand(N, 3)
scale_jit = torch.exp(torch.randn_like(T_jit[:, 0]) * scale_std).expand(N)
return apply_camera_alignment(R, T, rigid_transform, scale_jit)
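# A minimal usage sketch (values are illustrative, not from the original code):
#   R = pt3d.transforms.random_rotations(4)
#   T = torch.randn(4, 3)
#   R_jit, T_jit = jitter_extrinsics(R, T, max_angle=0.1, translation_std=0.1)
#   # R_jit and T_jit have the same shapes as R and T: (4, 3, 3) and (4, 3).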
def apply_camera_alignment(
R: torch.Tensor,
T: torch.Tensor,
rigid_transform: torch.Tensor,
scale: torch.Tensor,
):
"""
Args:
R: Camera rotation matrix of shape (N, 3, 3).
T: Camera translation of shape (N, 3).
rigid_transform: A tensor of shape (N, 4, 4) representing a batch of
N 4x4 tensors that map the scene pointcloud from misaligned coords
to the aligned space.
        scale: A tensor of shape (N,) containing N scaling factors.
Returns:
R_aligned: The aligned rotations R.
T_aligned: The aligned translations T.
"""
R_rigid = rigid_transform[:, :3, :3]
T_rigid = rigid_transform[:, 3:, :3]
R_aligned = R_rigid.permute(0, 2, 1).bmm(R)
T_aligned = scale[:, None] * (T - (T_rigid @ R_aligned)[:, 0])
return R_aligned, T_aligned
def get_min_max_depth_bounds(cameras, scene_center, scene_extent):
"""
Estimate near/far depth plane as:
        near = dist(cam_center, scene_center) - scene_extent
        far = dist(cam_center, scene_center) + scene_extent
"""
cam_center = cameras.get_camera_center()
center_dist = (
((cam_center - scene_center.to(cameras.R)[None]) ** 2)
.sum(dim=-1)
.clamp(0.001)
.sqrt()
)
center_dist = center_dist.clamp(scene_extent + 1e-3)
min_depth = center_dist - scene_extent
max_depth = center_dist + scene_extent
return min_depth, max_depth
def volumetric_camera_overlaps(
cameras: CamerasBase,
scene_extent: float = 8.0,
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
resol: int = 16,
weigh_by_ray_angle: bool = True,
):
"""
    Compute the overlaps between the viewing frustums of all pairs of cameras
in `cameras`.
"""
device = cameras.device
ba = cameras.R.shape[0]
n_vox = int(resol ** 3)
grid = pt3d.structures.Volumes(
densities=torch.zeros([1, 1, resol, resol, resol], device=device),
volume_translation=-torch.FloatTensor(scene_center)[None].to(device),
voxel_size=2.0 * scene_extent / resol,
).get_coord_grid(world_coordinates=True)
grid = grid.view(1, n_vox, 3).expand(ba, n_vox, 3)
gridp = cameras.transform_points(grid, eps=1e-2)
proj_in_camera = (
torch.prod((gridp[..., :2].abs() <= 1.0), dim=-1)
* (gridp[..., 2] > 0.0).float()
) # ba x n_vox
if weigh_by_ray_angle:
rays = torch.nn.functional.normalize(
grid - cameras.get_camera_center()[:, None], dim=-1
)
rays_masked = rays * proj_in_camera[..., None]
# - slow and readable:
# inter = torch.zeros(ba, ba)
# for i1 in range(ba):
# for i2 in range(ba):
# inter[i1, i2] = (
# 1 + (rays_masked[i1] * rays_masked[i2]
# ).sum(dim=-1)).sum()
# - fast:
rays_masked = rays_masked.view(ba, n_vox * 3)
inter = n_vox + (rays_masked @ rays_masked.t())
else:
inter = proj_in_camera @ proj_in_camera.t()
mass = torch.diag(inter)
iou = inter / (mass[:, None] + mass[None, :] - inter).clamp(0.1)
return iou
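# A minimal usage sketch (hypothetical cameras; not part of the original code):
#   from pytorch3d.renderer import PerspectiveCameras
#   cams = PerspectiveCameras(
#       R=pt3d.transforms.random_rotations(3), T=torch.randn(3, 3)
#   )
#   iou = volumetric_camera_overlaps(cams, resol=8)
#   # iou is a symmetric (3, 3) matrix, typically with ones on the diagonal.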
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from dataclasses import dataclass
from math import pi
from typing import Optional
import torch
from pytorch3d.common.compat import eigh, lstsq
def _get_rotation_to_best_fit_xy(
points: torch.Tensor, centroid: torch.Tensor
) -> torch.Tensor:
"""
Returns a rotation r such that points @ r has a best fit plane
parallel to the xy plane
Args:
points: (N, 3) tensor of points in 3D
centroid: (3,) their centroid
Returns:
(3,3) tensor rotation matrix
"""
points_centered = points - centroid[None]
return eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
def _signed_area(path: torch.Tensor) -> torch.Tensor:
"""
Calculates the signed area / Lévy area of a 2D path. If the path is closed,
i.e. ends where it starts, this is the integral of the winding number over
the whole plane. If not, consider a closed path made by adding a straight
line from the end to the start; the signed area is the integral of the
winding number (also over the plane) with respect to that closed path.
If this number is positive, it indicates in some sense that the path
turns anticlockwise more than clockwise, and vice versa.
Args:
path: N x 2 tensor of points.
Returns:
signed area, shape ()
"""
# This calculation is a sum of areas of triangles of the form
# (path[0], path[i], path[i+1]), where each triangle is half a
# parallelogram.
x, y = (path[1:] - path[:1]).unbind(1)
return (y[1:] * x[:-1] - x[1:] * y[:-1]).sum() * 0.5
@dataclass(frozen=True)
class Circle2D:
"""
Contains details of a circle in a plane.
Members
center: tensor shape (2,)
radius: tensor shape ()
generated_points: points around the circle, shape (n_points, 2)
"""
center: torch.Tensor
radius: torch.Tensor
generated_points: torch.Tensor
def fit_circle_in_2d(
points2d, *, n_points: int = 0, angles: Optional[torch.Tensor] = None
) -> Circle2D:
"""
    Simple best fitting of a circle to 2D points. In particular, finds the
    circle which minimizes the sum of squares of the residuals between its
    squared radius and the squared distances of the points from its center.
Finds (a,b) and r to minimize the sum of squares (over the x,y pairs) of
r**2 - [(x-a)**2+(y-b)**2]
i.e.
(2*a)*x + (2*b)*y + (r**2 - a**2 - b**2)*1 - (x**2 + y**2)
In addition, generates points along the circle. If angles is None (default)
then n_points around the circle equally spaced are given. These begin at the
point closest to the first input point. They continue in the direction which
seems to match the movement of points in points2d, as judged by its
signed area. If `angles` are provided, then n_points is ignored, and points
along the circle at the given angles are returned, with the starting point
and direction as before.
(Note that `generated_points` is affected by the order of the points in
points2d, but the other outputs are not.)
Args:
points2d: N x 2 tensor of 2D points
n_points: number of points to generate on the circle, if angles not given
angles: optional angles in radians of points to generate.
Returns:
Circle2D object
"""
design = torch.cat([points2d, torch.ones_like(points2d[:, :1])], dim=1)
rhs = (points2d ** 2).sum(1)
n_provided = points2d.shape[0]
if n_provided < 3:
raise ValueError(f"{n_provided} points are not enough to determine a circle")
solution = lstsq(design, rhs)
center = solution[:2] / 2
radius = torch.sqrt(solution[2] + (center ** 2).sum())
if n_points > 0:
if angles is not None:
warnings.warn("n_points ignored because angles provided")
else:
angles = torch.linspace(0, 2 * pi, n_points, device=points2d.device)
if angles is not None:
initial_direction_xy = (points2d[0] - center).unbind()
initial_angle = torch.atan2(initial_direction_xy[1], initial_direction_xy[0])
with torch.no_grad():
anticlockwise = _signed_area(points2d) > 0
if anticlockwise:
use_angles = initial_angle + angles
else:
use_angles = initial_angle - angles
generated_points = center[None] + radius * torch.stack(
[torch.cos(use_angles), torch.sin(use_angles)], dim=-1
)
else:
generated_points = points2d.new_zeros(0, 2)
return Circle2D(center=center, radius=radius, generated_points=generated_points)
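# A minimal usage sketch (points sampled from the unit circle; illustrative):
#   theta = torch.linspace(0, 3.0, 10)
#   pts = torch.stack([theta.cos(), theta.sin()], dim=1)
#   circle = fit_circle_in_2d(pts, n_points=4)
#   # circle.center is approximately (0, 0), circle.radius approximately 1,
#   # and circle.generated_points holds 4 points around the fitted circle.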
@dataclass(frozen=True)
class Circle3D:
"""
Contains details of a circle in 3D.
Members
center: tensor shape (3,)
radius: tensor shape ()
normal: tensor shape (3,)
generated_points: points around the circle, shape (n_points, 3)
"""
center: torch.Tensor
radius: torch.Tensor
normal: torch.Tensor
generated_points: torch.Tensor
def fit_circle_in_3d(
points,
*,
n_points: int = 0,
angles: Optional[torch.Tensor] = None,
offset: Optional[torch.Tensor] = None,
up: Optional[torch.Tensor] = None,
) -> Circle3D:
"""
Simple best fit circle to 3D points. Uses circle_2d in the
least-squares best fit plane.
In addition, generates points along the circle. If angles is None (default)
then n_points around the circle equally spaced are given. These begin at the
point closest to the first input point. They continue in the direction which
    seems to match the movement of points. If angles is provided, then n_points
is ignored, and points along the circle at the given angles are returned,
with the starting point and direction as before.
Further, an offset can be given to add to the generated points; this is
interpreted in a rotated coordinate system where (0, 0, 1) is normal to the
circle, specifically the normal which is approximately in the direction of a
given `up` vector. The remaining rotation is disambiguated in an unspecified
but deterministic way.
(Note that `generated_points` is affected by the order of the points in
points, but the other outputs are not.)
Args:
        points: N x 3 tensor of 3D points
n_points: number of points to generate on the circle
angles: optional angles in radians of points to generate.
offset: optional tensor (3,), a displacement expressed in a "canonical"
coordinate system to add to the generated points.
up: optional tensor (3,), a vector which helps define the
"canonical" coordinate system for interpretting `offset`.
Required if offset is used.
Returns:
Circle3D object
"""
centroid = points.mean(0)
r = _get_rotation_to_best_fit_xy(points, centroid)
normal = r[:, 2]
rotated_points = (points - centroid) @ r
result_2d = fit_circle_in_2d(
rotated_points[:, :2], n_points=n_points, angles=angles
)
center_3d = result_2d.center @ r[:, :2].t() + centroid
n_generated_points = result_2d.generated_points.shape[0]
if n_generated_points > 0:
generated_points_in_plane = torch.cat(
[
result_2d.generated_points,
torch.zeros_like(result_2d.generated_points[:, :1]),
],
dim=1,
)
if offset is not None:
if up is None:
raise ValueError("Missing `up` input for interpreting offset")
with torch.no_grad():
swap = torch.dot(up, normal) < 0
if swap:
# We need some rotation which takes +z to -z. Here's one.
generated_points_in_plane += offset * offset.new_tensor([1, -1, -1])
else:
generated_points_in_plane += offset
generated_points = generated_points_in_plane @ r.t() + centroid
else:
generated_points = points.new_zeros(0, 3)
return Circle3D(
radius=result_2d.radius,
center=center_3d,
normal=normal,
generated_points=generated_points,
)
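# A minimal usage sketch (a planar circle in 3D; illustrative values):
#   theta = torch.linspace(0, 3.0, 10)
#   pts = torch.stack([theta.cos(), theta.sin(), torch.zeros(10)], dim=1)
#   circle = fit_circle_in_3d(pts, n_points=4)
#   # circle.normal is approximately (0, 0, +-1);
#   # circle.generated_points has shape (4, 3).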
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import dataclasses
import inspect
import warnings
from collections import Counter, defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast
from omegaconf import DictConfig, OmegaConf, open_dict
"""
This functionality allows a configurable system to be determined in a dataclass-type
way. It is a generalization of omegaconf's "structured", in the dataclass case.
Core functionality:
- Configurable -- A base class used to label a class as being one which uses this
system. Uses class members and __post_init__ like a dataclass.
- expand_args_fields -- Expands a class like `dataclasses.dataclass`. Runs automatically.
- get_default_args -- gets an omegaconf.DictConfig for initializing
a given class or calling a given function.
- run_auto_creation -- Initialises nested members. To be called in __post_init__.
In addition, a Configurable may contain members whose type is decided at runtime.
- ReplaceableBase -- As a base instead of Configurable, labels a class to say that
any child class can be used instead.
- registry -- A global store of named child classes of ReplaceableBase classes.
Used as `@registry.register` decorator on class definition.
Additional utility functions:
- remove_unused_components -- used for simplifying a DictConfig instance.
- get_default_args_field -- default for DictConfig member of another configurable.
1. The simplest usage of this functionality is as follows. First a schema is defined
in dataclass style.
class A(Configurable):
n: int = 9
class B(Configurable):
a: A
def __post_init__(self):
run_auto_creation(self)
It can be used like
b_args = get_default_args(B)
b = B(**b_args)
In this case, get_default_args(B) returns an omegaconf.DictConfig with the right
members {"a_args": {"n": 9}}. It also modifies the definitions of the classes to
something like the following. (The modification itself is done by the function
`expand_args_fields`, which is called inside `get_default_args`.)
@dataclasses.dataclass
class A:
n: int = 9
@dataclasses.dataclass
class B:
a_args: DictConfig = dataclasses.field(default_factory=lambda: DictConfig({"n": 9}))
def __post_init__(self):
self.a = A(**self.a_args)
2. Pluggability. Instead of a dataclass-style member being given a concrete class,
you can give a base class and the implementation is looked up by name in the global
`registry` in this module. E.g.
class A(ReplaceableBase):
k: int = 1
@registry.register
class A1(A):
m: int = 3
@registry.register
class A2(A):
n: str = "2"
class B(Configurable):
a: A
a_class_type: str = "A2"
def __post_init__(self):
run_auto_creation(self)
will expand to
@dataclasses.dataclass
class A:
k: int = 1
@dataclasses.dataclass
class A1(A):
m: int = 3
@dataclasses.dataclass
class A2(A):
n: str = "2"
@dataclasses.dataclass
class B:
a_class_type: str = "A2"
    a_A1_args: DictConfig = dataclasses.field(
        default_factory=lambda: DictConfig({"k": 1, "m": 3})
    )
    a_A2_args: DictConfig = dataclasses.field(
        default_factory=lambda: DictConfig({"k": 1, "n": "2"})
    )
def __post_init__(self):
if self.a_class_type == "A1":
self.a = A1(**self.a_A1_args)
elif self.a_class_type == "A2":
self.a = A2(**self.a_A2_args)
else:
raise ValueError(...)
3. Aside from these classes, the members of these classes should be things
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
can be built from them with DictConfigs and lists of them.
In addition, you can call get_default_args on a function or class to get
the DictConfig of its defaulted arguments, assuming those are all things
which DictConfig is happy with. If you want to use such a thing as a member
of another configured class, `get_default_args_field` is a helper.
"""
_unprocessed_warning: str = (
" must be processed before it can be used."
+ " This is done by calling expand_args_fields "
+ "or get_default_args on it."
)
TYPE_SUFFIX: str = "_class_type"
ARGS_SUFFIX: str = "_args"
class ReplaceableBase:
"""
Base class for dataclass-style classes which
can be stored in the registry.
"""
def __new__(cls, *args, **kwargs):
"""
This function only exists to raise a
warning if class construction is attempted
without processing.
"""
obj = super().__new__(cls)
if cls is not ReplaceableBase and not _is_actually_dataclass(cls):
warnings.warn(cls.__name__ + _unprocessed_warning)
return obj
class Configurable:
"""
This indicates a class which is not ReplaceableBase
but still needs to be
expanded into a dataclass with expand_args_fields.
This expansion is delayed.
"""
def __new__(cls, *args, **kwargs):
"""
This function only exists to raise a
warning if class construction is attempted
without processing.
"""
obj = super().__new__(cls)
if cls is not Configurable and not _is_actually_dataclass(cls):
warnings.warn(cls.__name__ + _unprocessed_warning)
return obj
_X = TypeVar("_X", bound=ReplaceableBase)
class _Registry:
"""
Register from names to classes. In particular, we say that direct subclasses of
ReplaceableBase are "base classes" and we register subclasses of each base class
in a separate namespace.
"""
def __init__(self) -> None:
self._mapping: Dict[
Type[ReplaceableBase], Dict[str, Type[ReplaceableBase]]
] = defaultdict(dict)
def register(self, some_class: Type[_X]) -> Type[_X]:
"""
A class decorator, to register a class in self.
"""
name = some_class.__name__
self._register(some_class, name=name)
return some_class
def _register(
self,
some_class: Type[ReplaceableBase],
*,
base_class: Optional[Type[ReplaceableBase]] = None,
name: str,
) -> None:
"""
Register a new member.
Args:
            some_class: the new member
base_class: (optional) what the new member is a type for
name: name for the new member
"""
if base_class is None:
base_class = self._base_class_from_class(some_class)
if base_class is None:
raise ValueError(
f"Cannot register {some_class}. Cannot tell what it is."
)
if some_class is base_class:
raise ValueError(f"Attempted to register the base class {some_class}")
self._mapping[base_class][name] = some_class
def get(
self, base_class_wanted: Type[ReplaceableBase], name: str
) -> Type[ReplaceableBase]:
"""
Retrieve a class from the registry by name
Args:
base_class_wanted: parent type of type we are looking for.
It determines the namespace.
This will typically be a direct subclass of ReplaceableBase.
name: what to look for
Returns:
class type
"""
if self._is_base_class(base_class_wanted):
base_class = base_class_wanted
else:
base_class = self._base_class_from_class(base_class_wanted)
if base_class is None:
raise ValueError(
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
)
result = self._mapping[base_class].get(name)
if result is None:
raise ValueError(f"{name} has not been registered.")
if not issubclass(result, base_class_wanted):
raise ValueError(
f"{name} resolves to {result} which does not subclass {base_class_wanted}"
)
return result
def get_all(
self, base_class_wanted: Type[ReplaceableBase]
) -> List[Type[ReplaceableBase]]:
"""
Retrieve all registered implementations from the registry
Args:
base_class_wanted: parent type of type we are looking for.
It determines the namespace.
This will typically be a direct subclass of ReplaceableBase.
Returns:
list of class types
"""
if self._is_base_class(base_class_wanted):
return list(self._mapping[base_class_wanted].values())
base_class = self._base_class_from_class(base_class_wanted)
if base_class is None:
raise ValueError(
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
)
return [
class_
for class_ in self._mapping[base_class].values()
if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted
]
@staticmethod
def _is_base_class(some_class: Type[ReplaceableBase]) -> bool:
"""
Return whether the given type is a direct subclass of ReplaceableBase
and so gets used as a namespace.
"""
return ReplaceableBase in some_class.__bases__
@staticmethod
def _base_class_from_class(
some_class: Type[ReplaceableBase],
) -> Optional[Type[ReplaceableBase]]:
"""
Find the parent class of some_class which inherits ReplaceableBase, or None
"""
for base in some_class.mro()[-3::-1]:
if base is not ReplaceableBase and issubclass(base, ReplaceableBase):
return base
return None
# Global instance of the registry
registry = _Registry()
def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any], None]:
"""
Return the default creation function for a member. This is a function which
could be called in __post_init__ to initialise the member, and will be called
from run_auto_creation.
Args:
name: name of the member
type_: declared type of the member
pluggable: True if the member's declared type inherits ReplaceableBase,
in which case the actual type to be created is decided at
runtime.
Returns:
Function taking one argument, the object whose member should be
initialized.
"""
def inner(self):
expand_args_fields(type_)
args = getattr(self, name + ARGS_SUFFIX)
setattr(self, name, type_(**args))
def inner_pluggable(self):
type_name = getattr(self, name + TYPE_SUFFIX)
chosen_class = registry.get(type_, type_name)
if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
# If this warning is raised, it means that a new definition of
# the chosen class has been registered since our class was processed
# (i.e. expanded). A DictConfig which comes from our get_default_args
# (which might have triggered the processing) will contain the old default
# values for the members of the chosen class. Changes to those defaults which
# were made in the redefinition will not be reflected here.
warnings.warn(f"New implementation of {type_name} is being chosen.")
expand_args_fields(chosen_class)
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
setattr(self, name, chosen_class(**args))
return inner_pluggable if pluggable else inner
def run_auto_creation(self: Any) -> None:
"""
Run all the functions named in self._creation_functions.
"""
for create_function in self._creation_functions:
getattr(self, create_function)()
def _is_configurable_class(C) -> bool:
return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
"""
Get the DictConfig of args to call C - which might be a type or a function.
If C is a subclass of Configurable or ReplaceableBase, we make sure
it has been processed with expand_args_fields. If C is a dataclass,
including a subclass of Configurable or ReplaceableBase, the output
will be a typed DictConfig.
Args:
C: the class or function to be processed
_do_not_process: (internal use) When this function is called from
expand_args_fields, we specify any class currently being
processed, to make sure we don't try to process a class
while it is already being processed.
Returns:
new DictConfig object
"""
if C is None:
return DictConfig({})
if _is_configurable_class(C):
if C in _do_not_process:
raise ValueError(
f"Internal recursion error. Need processed {C},"
f" but cannot get it. _do_not_process={_do_not_process}"
)
# This is safe to run multiple times. It will return
# straight away if C has already been processed.
expand_args_fields(C, _do_not_process=_do_not_process)
kwargs = {}
if dataclasses.is_dataclass(C):
# Note that if get_default_args_field is used somewhere in C,
# this call is recursive. No special care is needed,
# because in practice get_default_args_field is used for
# separate types than the outer type.
out = OmegaConf.structured(C)
exclude = getattr(C, "_processed_members", ())
with open_dict(out):
for field in exclude:
out.pop(field, None)
return out
if _is_configurable_class(C):
raise ValueError(f"Failed to process {C}")
# returns dict of keyword args of a callable C
sig = inspect.signature(C)
for pname, defval in dict(sig.parameters).items():
if defval.default == inspect.Parameter.empty:
# print('skipping %s' % pname)
continue
else:
kwargs[pname] = copy.deepcopy(defval.default)
return DictConfig(kwargs)
def _is_actually_dataclass(some_class) -> bool:
# Return whether the class some_class has been processed with
# the dataclass annotation. This is more specific than
# dataclasses.is_dataclass which returns True on anything
# deriving from a dataclass.
# Checking for __init__ would also work for our purpose.
return "__dataclass_fields__" in some_class.__dict__
def expand_args_fields(
some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = ()
) -> Type[_X]:
"""
This expands a class which inherits Configurable or ReplaceableBase classes,
including dataclass processing. some_class is modified in place by this function.
For classes of type ReplaceableBase, you can add some_class to the registry before
or after calling this function. But potential inner classes need to be registered
before this function is run on the outer class.
The transformations this function makes, before the concluding
dataclasses.dataclass, are as follows. if X is a base class with registered
subclasses Y and Z, replace
x: X
and optionally
x_class_type: str = "Y"
def create_x(self):...
with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
def create_x(self):
self.x = registry.get(X, self.x_class_type)(
**self.getattr(f"x_{self.x_class_type}_args)
)
x_class_type: str = "UNDEFAULTED"
without adding the optional things if they are already there.
Similarly, if X is a subclass of Configurable,
x: X
and optionally
def create_x(self):...
will be replaced with
x_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
def create_x(self):
        self.x = X(**self.x_args)
Also adds the following class members, unannotated so that dataclass
ignores them.
- _creation_functions: Tuple[str] of all the create_ functions,
including those from base classes.
- _known_implementations: Dict[str, Type] containing the classes which
have been found from the registry.
        (used only to raise a warning if one has been overwritten)
- _processed_members: a Set[str] of all the members which have been transformed.
Args:
some_class: the class to be processed
_do_not_process: Internal use for get_default_args: Because get_default_args calls
and is called by this function, we let it specify any class currently
being processed, to make sure we don't try to process a class while
it is already being processed.
Returns:
some_class itself, which has been modified in place. This
allows this function to be used as a class decorator.
"""
if _is_actually_dataclass(some_class):
return some_class
# The functions this class's run_auto_creation will run.
creation_functions: List[str] = []
# The classes which this type knows about from the registry
# We could use a weakref.WeakValueDictionary here which would mean
# that we don't warn if the class we should have expected is elsewhere
# unused.
known_implementations: Dict[str, Type] = {}
# Names of members which have been processed.
processed_members: Set[str] = set()
# For all bases except ReplaceableBase and Configurable and object,
# we need to process them before our own processing. This is
# because dataclasses expect to inherit dataclasses and not unprocessed
# dataclasses.
for base in some_class.mro()[-3:0:-1]:
if base is ReplaceableBase:
continue
if base is Configurable:
continue
if not issubclass(base, (Configurable, ReplaceableBase)):
continue
expand_args_fields(base, _do_not_process=_do_not_process)
if "_creation_functions" in base.__dict__:
creation_functions.extend(base._creation_functions)
if "_known_implementations" in base.__dict__:
known_implementations.update(base._known_implementations)
if "_processed_members" in base.__dict__:
processed_members.update(base._processed_members)
to_process: List[Tuple[str, Type, bool]] = []
if "__annotations__" in some_class.__dict__:
for name, type_ in some_class.__annotations__.items():
if not isinstance(type_, type):
# type_ could be something like typing.Tuple
continue
if (
issubclass(type_, ReplaceableBase)
and ReplaceableBase in type_.__bases__
):
to_process.append((name, type_, True))
elif issubclass(type_, Configurable):
to_process.append((name, type_, False))
for name, type_, pluggable in to_process:
_process_member(
name=name,
type_=type_,
pluggable=pluggable,
some_class=cast(type, some_class),
creation_functions=creation_functions,
_do_not_process=_do_not_process,
known_implementations=known_implementations,
)
processed_members.add(name)
for key, count in Counter(creation_functions).items():
if count > 1:
warnings.warn(f"Clash with {key} in a base class.")
some_class._creation_functions = tuple(creation_functions)
some_class._processed_members = processed_members
some_class._known_implementations = known_implementations
dataclasses.dataclass(eq=False)(some_class)
return some_class
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
"""
Get a dataclass field which defaults to get_default_args(...)
Args:
As for get_default_args.
Returns:
function to return new DictConfig object
"""
def create():
return get_default_args(C, _do_not_process=_do_not_process)
return dataclasses.field(default_factory=create)
def _process_member(
*,
name: str,
type_: Type,
pluggable: bool,
some_class: Type,
creation_functions: List[str],
_do_not_process: Tuple[type, ...],
known_implementations: Dict[str, Type],
) -> None:
"""
Make the modification (of expand_args_fields) to some_class for a single member.
Args:
name: member name
type_: member declared type
        pluggable: whether the member has a dynamic type
some_class: (MODIFIED IN PLACE) the class being processed
creation_functions: (MODIFIED IN PLACE) the names of the create functions
_do_not_process: as for expand_args_fields.
known_implementations: (MODIFIED IN PLACE) known types from the registry
"""
# Because we are adding defaultable members, make
# sure they go at the end of __annotations__ in case
# there are non-defaulted standard class members.
del some_class.__annotations__[name]
if pluggable:
type_name = name + TYPE_SUFFIX
if type_name not in some_class.__annotations__:
some_class.__annotations__[type_name] = str
setattr(some_class, type_name, "UNDEFAULTED")
for derived_type in registry.get_all(type_):
if derived_type in _do_not_process:
continue
if issubclass(derived_type, some_class):
# When derived_type is some_class we have a simple
# recursion to avoid. When it's a strict subclass the
# situation is even worse.
continue
known_implementations[derived_type.__name__] = derived_type
args_name = f"{name}_{derived_type.__name__}{ARGS_SUFFIX}"
if args_name in some_class.__annotations__:
raise ValueError(
f"Cannot generate {args_name} because it is already present."
)
some_class.__annotations__[args_name] = DictConfig
setattr(
some_class,
args_name,
get_default_args_field(
derived_type, _do_not_process=_do_not_process + (some_class,)
),
)
else:
args_name = name + ARGS_SUFFIX
if args_name in some_class.__annotations__:
raise ValueError(
f"Cannot generate {args_name} because it is already present."
)
if issubclass(type_, some_class) or type_ in _do_not_process:
raise ValueError(f"Cannot process {type_} inside {some_class}")
some_class.__annotations__[args_name] = DictConfig
setattr(
some_class,
args_name,
get_default_args_field(
type_,
_do_not_process=_do_not_process + (some_class,),
),
)
creation_function_name = f"create_{name}"
if not hasattr(some_class, creation_function_name):
setattr(
some_class,
creation_function_name,
_default_create(name, type_, pluggable),
)
creation_functions.append(creation_function_name)
def remove_unused_components(dict_: DictConfig) -> None:
"""
Assuming dict_ represents the state of a configurable,
modify it to remove all the portions corresponding to
pluggable parts which are not in use.
For example, if renderer_class_type is SignedDistanceFunctionRenderer,
the renderer_MultiPassEmissionAbsorptionRenderer_args will be
removed.
Args:
dict_: (MODIFIED IN PLACE) a DictConfig instance
"""
keys = [key for key in dict_ if isinstance(key, str)]
suffix_length = len(TYPE_SUFFIX)
replaceables = [key[:-suffix_length] for key in keys if key.endswith(TYPE_SUFFIX)]
args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)]
for replaceable in replaceables:
selected_type = dict_[replaceable + TYPE_SUFFIX]
expect = replaceable + "_" + selected_type + ARGS_SUFFIX
with open_dict(dict_):
for key in args_keys:
if key.startswith(replaceable + "_") and key != expect:
del dict_[key]
for key in dict_:
if isinstance(dict_.get(key), DictConfig):
remove_unused_components(dict_[key])
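# A minimal end-to-end sketch of the system described in the module
# docstring (all class names here are illustrative, not part of the API):
#   class Renderer(ReplaceableBase):
#       quality: int = 1
#
#   @registry.register
#   class FastRenderer(Renderer):
#       speed: float = 2.0
#
#   class Pipeline(Configurable):
#       renderer: Renderer
#       renderer_class_type: str = "FastRenderer"
#       def __post_init__(self):
#           run_auto_creation(self)
#
#   cfg = get_default_args(Pipeline)  # {"renderer_class_type": "FastRenderer",
#                                     #  "renderer_FastRenderer_args":
#                                     #      {"quality": 1, "speed": 2.0}}
#   pipeline = Pipeline(**cfg)        # pipeline.renderer is a FastRenderer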
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as Fu
from pytorch3d.ops import wmean
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
def cleanup_eval_depth(
point_cloud: Pointclouds,
camera: CamerasBase,
depth: torch.Tensor,
mask: torch.Tensor,
sigma: float = 0.01,
image=None,
):
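    """
    Compute a binary mask of the depth maps `depth` marking the pixels whose
    depth values agree with the depths of the 3D points in `point_cloud`
    that project onto them.

    A point agrees with the depth map if the difference between its depth
    and the depth sampled at its projection is below `sigma` times the
    standard deviation of the masked depth map.

    Args:
        point_cloud: A batch of `ba` point clouds.
        camera: A batch of `ba` cameras.
        depth: A batch of depth maps of shape (ba, 1, H, W).
        mask: A batch of binary masks of shape (ba, 1, H, W) marking the
            valid pixels of `depth`.
        sigma: The multiple of the depth standard deviation defining the
            depth-agreement threshold.
        image: Unused; kept for the (commented-out) debug visualization.

    Returns:
        good_depth_mask: A float mask of shape (ba, 1, H, W) marking pixels
            supported by at least one depth-consistent point.
    """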
ba, _, H, W = depth.shape
pcl = point_cloud.points_padded()
n_pts = point_cloud.num_points_per_cloud()
pcl_mask = (
torch.arange(pcl.shape[1], dtype=torch.int64, device=pcl.device)[None]
< n_pts[:, None]
).type_as(pcl)
pcl_proj = camera.transform_points(pcl, eps=1e-2)[..., :-1]
pcl_depth = camera.get_world_to_view_transform().transform_points(pcl)[..., -1]
depth_and_idx = torch.cat(
(
depth,
torch.arange(H * W).view(1, 1, H, W).expand(ba, 1, H, W).type_as(depth),
),
dim=1,
)
depth_and_idx_sampled = Fu.grid_sample(
depth_and_idx, -pcl_proj[:, None], mode="nearest"
)[:, :, 0].view(ba, 2, -1)
depth_sampled, idx_sampled = depth_and_idx_sampled.split([1, 1], dim=1)
df = (depth_sampled[:, 0] - pcl_depth).abs()
# the threshold is a sigma-multiple of the standard deviation of the depth
mu = wmean(depth.view(ba, -1, 1), mask.view(ba, -1)).view(ba, 1)
std = (
wmean((depth.view(ba, -1) - mu).view(ba, -1, 1) ** 2, mask.view(ba, -1))
.clamp(1e-4)
.sqrt()
.view(ba, -1)
)
good_df_thr = std * sigma
good_depth = (df <= good_df_thr).float() * pcl_mask
perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
# print(f'Kept {100.0 * perc_kept.mean():1.3f} % points')
good_depth_raster = torch.zeros_like(depth).view(ba, -1)
# pyre-ignore[16]: scatter_add_
good_depth_raster.scatter_add_(1, torch.round(idx_sampled[:, 0]).long(), good_depth)
good_depth_mask = (good_depth_raster.view(ba, 1, H, W) > 0).float()
# if float(torch.rand(1)) > 0.95:
# depth_ok = depth * good_depth_mask
# # visualize
# visdom_env = 'depth_cleanup_dbg'
# from visdom import Visdom
# # from tools.vis_utils import make_depth_image
# from pytorch3d.vis.plotly_vis import plot_scene
# viz = Visdom()
# show_pcls = {
# 'pointclouds': point_cloud,
# }
# for d, nm in zip(
# (depth, depth_ok),
# ('pointclouds_unproj', 'pointclouds_unproj_ok'),
# ):
# pointclouds_unproj = get_rgbd_point_cloud(
# camera, image, d,
# )
# if int(pointclouds_unproj.num_points_per_cloud()) > 0:
# show_pcls[nm] = pointclouds_unproj
# scene_dict = {'1': {
# **show_pcls,
# 'cameras': camera,
# }}
# scene = plot_scene(
# scene_dict,
# pointcloud_max_points=5000,
# pointcloud_marker_size=1.5,
# camera_scale=1.0,
# )
# viz.plotlyplot(scene, env=visdom_env, win='scene')
# # depth_image_ok = make_depth_image(depths_ok, masks)
# # viz.images(depth_image_ok, env=visdom_env, win='depth_ok')
# # depth_image = make_depth_image(depths, masks)
# # viz.images(depth_image, env=visdom_env, win='depth')
# # # viz.images(rgb_rendered, env=visdom_env, win='images_render')
# # viz.images(images, env=visdom_env, win='images')
# import pdb; pdb.set_trace()
return good_depth_mask
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional, Tuple
import torch
from pytorch3d.common.compat import eigh
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
from pytorch3d.transforms import Scale
def generate_eval_video_cameras(
train_cameras,
n_eval_cams: int = 100,
trajectory_type: str = "figure_eight",
trajectory_scale: float = 0.2,
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
focal_length: Optional[torch.FloatTensor] = None,
principal_point: Optional[torch.FloatTensor] = None,
time: Optional[torch.FloatTensor] = None,
infer_up_as_plane_normal: bool = True,
traj_offset: Optional[Tuple[float, float, float]] = None,
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
) -> PerspectiveCameras:
"""
    Generate a camera trajectory for rendering a scene from multiple viewpoints.
Args:
        train_cameras: The set of cameras from the training dataset.
n_eval_cams: Number of cameras in the trajectory.
trajectory_type: The type of the camera trajectory. Can be one of:
circular_lsq_fit: Camera centers follow a trajectory obtained
by fitting a 3D circle to train_cameras centers.
All cameras are looking towards scene_center.
figure_eight: Figure-of-8 trajectory around the center of the
central camera of the training dataset.
trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
of a figure-eight knot
(https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
trajectory_scale: The extent of the trajectory.
up: The "up" vector of the scene (=the normal of the scene floor).
Active for the `trajectory_type="circular"`.
scene_center: The center of the scene in world coordinates which all
the cameras from the generated trajectory look at.
Returns:
        A batch of PerspectiveCameras which can be used as the test dataset.
"""
if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
cam_centers = train_cameras.get_camera_center()
# get the nearest camera center to the mean of centers
mean_camera_idx = (
((cam_centers - cam_centers.mean(dim=0)[None]) ** 2)
.sum(dim=1)
.min(dim=0)
.indices
)
# generate the knot trajectory in canonical coords
if time is None:
time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1)[:n_eval_cams]
else:
assert time.numel() == n_eval_cams
if trajectory_type == "trefoil_knot":
traj = _trefoil_knot(time)
elif trajectory_type == "figure_eight_knot":
traj = _figure_eight_knot(time)
elif trajectory_type == "figure_eight":
traj = _figure_eight(time)
else:
raise ValueError(f"bad trajectory type: {trajectory_type}")
traj[:, 2] -= traj[:, 2].max()
# transform the canonical knot to the coord frame of the mean camera
mean_camera = PerspectiveCameras(
**{
k: getattr(train_cameras, k)[[int(mean_camera_idx)]]
for k in ("focal_length", "principal_point", "R", "T")
}
)
traj_trans = Scale(cam_centers.std(dim=0).mean() * trajectory_scale).compose(
mean_camera.get_world_to_view_transform().inverse()
)
if traj_offset_canonical is not None:
traj_trans = traj_trans.translate(
torch.FloatTensor(traj_offset_canonical)[None].to(traj)
)
traj = traj_trans.transform_points(traj)
plane_normal = _fit_plane(cam_centers)[:, 0]
if infer_up_as_plane_normal:
up = _disambiguate_normal(plane_normal, up)
elif trajectory_type == "circular_lsq_fit":
        # fit a 3D circle to the camera centers
cam_centers = train_cameras.get_camera_center()
if time is not None:
angle = time
else:
angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams).to(cam_centers)
fit = fit_circle_in_3d(
cam_centers,
angles=angle,
offset=angle.new_tensor(traj_offset_canonical)
if traj_offset_canonical is not None
else None,
up=angle.new_tensor(up),
)
traj = fit.generated_points
        # scale the trajectory
_t_mu = traj.mean(dim=0, keepdim=True)
traj = (traj - _t_mu) * trajectory_scale + _t_mu
plane_normal = fit.normal
if infer_up_as_plane_normal:
up = _disambiguate_normal(plane_normal, up)
else:
raise ValueError(f"Uknown trajectory_type {trajectory_type}.")
if traj_offset is not None:
traj = traj + torch.FloatTensor(traj_offset)[None].to(traj)
# point all cameras towards the center of the scene
R, T = look_at_view_transform(
eye=traj,
at=(scene_center,), # (1, 3)
up=(up,), # (1, 3)
device=traj.device,
)
# get the average focal length and principal point
if focal_length is None:
focal_length = train_cameras.focal_length.mean(dim=0).repeat(n_eval_cams, 1)
if principal_point is None:
principal_point = train_cameras.principal_point.mean(dim=0).repeat(
n_eval_cams, 1
)
test_cameras = PerspectiveCameras(
focal_length=focal_length,
principal_point=principal_point,
R=R,
T=T,
device=focal_length.device,
)
# _visdom_plot_scene(
# train_cameras,
# test_cameras,
# )
return test_cameras
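# A minimal usage sketch (hypothetical training cameras; illustrative only):
#   train_cameras = PerspectiveCameras(R=..., T=...)  # e.g. from a dataset
#   test_cameras = generate_eval_video_cameras(
#       train_cameras, n_eval_cams=40, trajectory_type="circular_lsq_fit"
#   )
#   # test_cameras is a batch of 40 PerspectiveCameras whose centers follow
#   # a circle fit to the training camera centers, all looking at scene_center.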
def _disambiguate_normal(normal, up):
up_t = torch.tensor(up).to(normal)
flip = (up_t * normal).sum().sign()
up = normal * flip
up = up.tolist()
return up
def _fit_plane(x):
x = x - x.mean(dim=0)[None]
cov = (x.t() @ x) / x.shape[0]
_, e_vec = eigh(cov)
return e_vec
def _visdom_plot_scene(
train_cameras,
test_cameras,
) -> None:
from pytorch3d.vis.plotly_vis import plot_scene
p = plot_scene(
{
"scene": {
"train_cams": train_cameras,
"test_cams": test_cameras,
}
}
)
from visdom import Visdom
viz = Visdom()
viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs")
import pdb
pdb.set_trace()
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
x = (2 + (2 * t).cos()) * (3 * t).cos()
y = (2 + (2 * t).cos()) * (3 * t).sin()
z = (4 * t).sin() * z_scale
return torch.stack((x, y, z), dim=-1)
def _trefoil_knot(t: torch.Tensor, z_scale: float = 0.5):
x = t.sin() + 2 * (2 * t).sin()
y = t.cos() - 2 * (2 * t).cos()
z = -(3 * t).sin() * z_scale
return torch.stack((x, y, z), dim=-1)
def _figure_eight(t: torch.Tensor, z_scale: float = 0.5):
x = t.cos()
y = (2 * t).sin() / 2
z = t.sin() * z_scale
return torch.stack((x, y, z), dim=-1)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union
import torch
def mask_background(
image_rgb: torch.Tensor,
mask_fg: torch.Tensor,
dim_color: int = 1,
bg_color: Union[torch.Tensor, str, float] = 0.0,
) -> torch.Tensor:
"""
    Mask the background of the input image tensor `image_rgb` with `bg_color`.
The background regions are obtained from the binary foreground segmentation
mask `mask_fg`.
"""
tgt_view = [1, 1, 1, 1]
tgt_view[dim_color] = 3
# obtain the background color tensor
if isinstance(bg_color, torch.Tensor):
bg_color_t = bg_color.view(1, 3, 1, 1).clone().to(image_rgb)
elif isinstance(bg_color, float):
bg_color_t = torch.tensor(
[bg_color] * 3, device=image_rgb.device, dtype=image_rgb.dtype
).view(*tgt_view)
elif isinstance(bg_color, str):
if bg_color == "white":
bg_color_t = image_rgb.new_ones(tgt_view)
elif bg_color == "black":
bg_color_t = image_rgb.new_zeros(tgt_view)
else:
raise ValueError(_invalid_color_error_msg(bg_color))
else:
raise ValueError(_invalid_color_error_msg(bg_color))
# cast to the image_rgb's type
mask_fg = mask_fg.type_as(image_rgb)
# mask the bg
image_masked = mask_fg * image_rgb + (1 - mask_fg) * bg_color_t
return image_masked
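# A minimal usage sketch (random inputs; shapes are illustrative):
#   image_rgb = torch.rand(2, 3, 8, 8)
#   mask_fg = (torch.rand(2, 1, 8, 8) > 0.5).float()
#   on_white = mask_background(image_rgb, mask_fg, bg_color="white")
#   on_gray = mask_background(image_rgb, mask_fg, bg_color=0.5)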
def _invalid_color_error_msg(bg_color) -> str:
return (
f"Invalid bg_color={bg_color}. Plese set bg_color to a 3-element"
+ " tensor. or a string (white | black), or a float."
)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional, Tuple
import torch
from torch.nn import functional as F
def eval_depth(
pred: torch.Tensor,
gt: torch.Tensor,
crop: int = 1,
mask: Optional[torch.Tensor] = None,
get_best_scale: bool = True,
mask_thr: float = 0.5,
best_scale_clamp_thr: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Evaluate the depth error between the prediction `pred` and the ground
truth `gt`.
Args:
pred: A tensor of shape (N, 1, H, W) denoting the predicted depth maps.
gt: A tensor of shape (N, 1, H, W) denoting the ground truth depth maps.
crop: The number of pixels to crop from the border.
mask: A mask denoting the valid regions of the gt depth.
get_best_scale: If `True`, estimates a scaling factor of the predicted depth
that yields the best mean squared error between `pred` and `gt`.
This is typically enabled for cases where predicted reconstructions
are inherently defined up to an arbitrary scaling factor.
mask_thr: A constant used to threshold the `mask` to specify the valid
regions.
best_scale_clamp_thr: The threshold for clamping the divisor in best
scale estimation.
Returns:
mse_depth: Mean squared error between `pred` and `gt`.
abs_depth: Mean absolute difference between `pred` and `gt`.
"""
# chuck out the border
if crop > 0:
gt = gt[:, :, crop:-crop, crop:-crop]
pred = pred[:, :, crop:-crop, crop:-crop]
if mask is not None:
# mult gt by mask
if crop > 0:
mask = mask[:, :, crop:-crop, crop:-crop]
gt = gt * (mask > mask_thr).float()
dmask = (gt > 0.0).float()
dmask_mass = torch.clamp(dmask.sum((1, 2, 3)), 1e-4)
if get_best_scale:
# mult preds by a scalar "scale_best"
# s.t. we get best possible mse error
scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
pred = pred * scale_best[:, None, None, None]
df = gt - pred
mse_depth = (dmask * (df ** 2)).sum((1, 2, 3)) / dmask_mass
abs_depth = (dmask * df.abs()).sum((1, 2, 3)) / dmask_mass
return mse_depth, abs_depth
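# A minimal usage sketch (random depth maps; illustrative values):
#   gt = torch.rand(2, 1, 32, 32) + 0.5
#   pred = 1.3 * gt  # prediction correct up to an unknown scale factor
#   mse, abs_err = eval_depth(pred, gt, crop=1)
#   # with get_best_scale=True (the default) the scale is factored out,
#   # so both errors are close to zero.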
def estimate_depth_scale_factor(pred, gt, mask, clamp_thr):
xy = pred * gt * mask
xx = pred * pred * mask
scale_best = xy.mean((1, 2, 3)) / torch.clamp(xx.mean((1, 2, 3)), clamp_thr)
return scale_best
def calc_psnr(
x: torch.Tensor,
y: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Calculates the Peak-signal-to-noise ratio between tensors `x` and `y`.
"""
mse = calc_mse(x, y, mask=mask)
psnr = torch.log10(mse.clamp(1e-10)) * (-10.0)
return psnr
def calc_mse(
x: torch.Tensor,
y: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Calculates the mean square error between tensors `x` and `y`.
"""
if mask is None:
return torch.mean((x - y) ** 2)
else:
return (((x - y) ** 2) * mask).sum() / mask.expand_as(x).sum().clamp(1e-5)
def calc_bce(
pred: torch.Tensor,
gt: torch.Tensor,
equal_w: bool = True,
pred_eps: float = 0.01,
mask: Optional[torch.Tensor] = None,
lerp_bound: Optional[float] = None,
) -> torch.Tensor:
"""
Calculates the binary cross entropy.
"""
if pred_eps > 0.0:
# up/low bound the predictions
pred = torch.clamp(pred, pred_eps, 1.0 - pred_eps)
if mask is None:
mask = torch.ones_like(gt)
if equal_w:
mask_fg = (gt > 0.5).float() * mask
mask_bg = (1 - mask_fg) * mask
weight = mask_fg / mask_fg.sum().clamp(1.0) + mask_bg / mask_bg.sum().clamp(1.0)
# weight sum should be at this point ~2
weight = weight * (weight.numel() / weight.sum().clamp(1.0))
else:
weight = torch.ones_like(gt) * mask
if lerp_bound is not None:
return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound)
else:
return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight)
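# A minimal usage sketch (random inputs; illustrative values):
#   pred = torch.rand(2, 1, 8, 8)
#   gt = (torch.rand(2, 1, 8, 8) > 0.5).float()
#   loss = calc_bce(pred, gt)  # fg/bg-balanced binary cross entropy
#   loss_safe = calc_bce(pred, gt, lerp_bound=0.01)  # gradient-safe variant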
def binary_cross_entropy_lerp(
pred: torch.Tensor,
gt: torch.Tensor,
weight: torch.Tensor,
lerp_bound: float,
):
"""
Binary cross entropy which avoids exploding gradients by linearly
    extrapolating the log function for log(1-pred) and log(pred) whenever
pred or 1-pred is smaller than lerp_bound.
"""
loss = log_lerp(1 - pred, lerp_bound) * (1 - gt) + log_lerp(pred, lerp_bound) * gt
loss_reduced = -(loss * weight).sum() / weight.sum().clamp(1e-4)
return loss_reduced
def log_lerp(x: torch.Tensor, b: float):
"""
Linearly extrapolated log for x < b.
"""
assert b > 0
return torch.where(x >= b, x.log(), math.log(b) + (x - b) / b)
def rgb_l1(
pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Calculates the mean absolute error between the predicted colors `pred`
and ground truth colors `target`.
"""
if mask is None:
mask = torch.ones_like(pred[:, :1])
return ((pred - target).abs() * mask).sum(dim=(1, 2, 3)) / mask.sum(
dim=(1, 2, 3)
).clamp(1)
def huber(dfsq: torch.Tensor, scaling: float = 0.03) -> torch.Tensor:
"""
Calculates the huber function of the input squared error `dfsq`.
The function smoothly transitions from a region with unit gradient
to a hyperbolic function at `dfsq=scaling`.
"""
loss = (safe_sqrt(1 + dfsq / (scaling * scaling), eps=1e-4) - 1) * scaling
return loss
def neg_iou_loss(
predict: torch.Tensor,
target: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
    Negative intersection-over-union loss. It emphasizes the active
    regions of `predict` and `target`.
"""
return 1.0 - iou(predict, target, mask=mask)
def safe_sqrt(A: torch.Tensor, eps: float = float(1e-4)) -> torch.Tensor:
"""
    Performs a safe differentiable sqrt.
"""
return (torch.clamp(A, float(0)) + eps).sqrt()
def iou(
predict: torch.Tensor,
target: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
    Computes the intersection-over-union between `predict` and `target`,
    emphasizing the active regions of both.
"""
dims = tuple(range(predict.dim())[1:])
if mask is not None:
predict = predict * mask
target = target * mask
intersect = (predict * target).sum(dims)
union = (predict + target - predict * target).sum(dims) + 1e-4
return (intersect / union).sum() / intersect.numel()
def beta_prior(pred: torch.Tensor, cap: float = 0.1) -> torch.Tensor:
if cap <= 0.0:
raise ValueError("capping should be positive to avoid unbound loss")
min_value = math.log(cap) + math.log(cap + 1.0)
return (torch.log(pred + cap) + torch.log(1.0 - pred + cap)).mean() - min_value
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import glob
import os
import shutil
import tempfile
import torch
def load_stats(flstats):
from pytorch3d.implicitron.tools.stats import Stats
try:
stats = Stats.load(flstats)
    except Exception:
        print("Cannot load stats: %s" % flstats)
stats = None
return stats
def get_model_path(fl) -> str:
fl = os.path.splitext(fl)[0]
flmodel = "%s.pth" % fl
return flmodel
def get_optimizer_path(fl) -> str:
fl = os.path.splitext(fl)[0]
flopt = "%s_opt.pth" % fl
return flopt
def get_stats_path(fl, eval_results: bool = False):
fl = os.path.splitext(fl)[0]
if eval_results:
for postfix in ("_2", ""):
flstats = os.path.join(os.path.dirname(fl), f"stats_test{postfix}.jgz")
if os.path.isfile(flstats):
break
else:
flstats = "%s_stats.jgz" % fl
# pyre-fixme[61]: `flstats` is undefined, or not always defined.
return flstats
def safe_save_model(model, stats, fl, optimizer=None, cfg=None) -> None:
"""
    This function stores model files safely so that no partially-written model
    files are left on the file system if the saving procedure is interrupted.
    This is done by first saving the model files to a temporary directory,
    followed by (atomic) moves to the target location. Note that this can
    still result in an inconsistent set of model files if the interruption
    happens while performing the moves; it is, however, quite improbable that
    a crash would occur right at that time.
"""
print(f"saving model files safely to {fl}")
# first store everything to a tmpdir
with tempfile.TemporaryDirectory() as tmpdir:
tmpfl = os.path.join(tmpdir, os.path.split(fl)[-1])
stored_tmp_fls = save_model(model, stats, tmpfl, optimizer=optimizer, cfg=cfg)
tgt_fls = [
(
os.path.join(os.path.split(fl)[0], os.path.split(tmpfl)[-1])
if (tmpfl is not None)
else None
)
for tmpfl in stored_tmp_fls
]
# then move from the tmpdir to the right location
for tmpfl, tgt_fl in zip(stored_tmp_fls, tgt_fls):
if tgt_fl is None:
continue
# print(f'Moving {tmpfl} --> {tgt_fl}\n')
shutil.move(tmpfl, tgt_fl)
def save_model(model, stats, fl, optimizer=None, cfg=None):
flstats = get_stats_path(fl)
flmodel = get_model_path(fl)
print("saving model to %s" % flmodel)
torch.save(model.state_dict(), flmodel)
flopt = None
if optimizer is not None:
flopt = get_optimizer_path(fl)
print("saving optimizer to %s" % flopt)
torch.save(optimizer.state_dict(), flopt)
print("saving model stats to %s" % flstats)
stats.save(flstats)
return flstats, flmodel, flopt
def load_model(fl):
flstats = get_stats_path(fl)
flmodel = get_model_path(fl)
flopt = get_optimizer_path(fl)
model_state_dict = torch.load(flmodel)
stats = load_stats(flstats)
if os.path.isfile(flopt):
optimizer = torch.load(flopt)
else:
optimizer = None
return model_state_dict, stats, optimizer
def parse_epoch_from_model_path(model_path) -> int:
return int(
os.path.split(model_path)[-1].replace(".pth", "").replace("model_epoch_", "")
)
def get_checkpoint(exp_dir, epoch):
fl = os.path.join(exp_dir, "model_epoch_%08d.pth" % epoch)
return fl
def find_last_checkpoint(
exp_dir, any_path: bool = False, all_checkpoints: bool = False
):
if any_path:
exts = [".pth", "_stats.jgz", "_opt.pth"]
else:
exts = [".pth"]
for ext in exts:
fls = sorted(
glob.glob(
os.path.join(glob.escape(exp_dir), "model_epoch_" + "[0-9]" * 8 + ext)
)
)
if len(fls) > 0:
break
# pyre-fixme[61]: `fls` is undefined, or not always defined.
if len(fls) == 0:
fl = None
else:
if all_checkpoints:
# pyre-fixme[61]: `fls` is undefined, or not always defined.
fl = [f[0 : -len(ext)] + ".pth" for f in fls]
else:
fl = fls[-1][0 : -len(ext)] + ".pth"
return fl
def purge_epoch(exp_dir, epoch) -> None:
model_path = get_checkpoint(exp_dir, epoch)
for file_path in [
model_path,
get_optimizer_path(model_path),
get_stats_path(model_path),
]:
if os.path.isfile(file_path):
print("deleting %s" % file_path)
os.remove(file_path)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple, cast
import torch
import torch.nn.functional as Fu
from pytorch3d.renderer import (
AlphaCompositor,
NDCMultinomialRaysampler,
PointsRasterizationSettings,
PointsRasterizer,
ray_bundle_to_ray_points,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
def get_rgbd_point_cloud(
camera: CamerasBase,
image_rgb: torch.Tensor,
depth_map: torch.Tensor,
mask: Optional[torch.Tensor] = None,
mask_thr: float = 0.5,
mask_points: bool = True,
) -> Pointclouds:
"""
    Given a batch of images, depths, masks and cameras, generate a colored
    point cloud by unprojecting the depth maps to 3D and coloring each point
    with the corresponding source pixel color.
"""
imh, imw = image_rgb.shape[2:]
# convert the depth maps to point clouds using the grid ray sampler
pts_3d = ray_bundle_to_ray_points(
NDCMultinomialRaysampler(
image_width=imw,
image_height=imh,
n_pts_per_ray=1,
min_depth=1.0,
max_depth=1.0,
)(camera)._replace(lengths=depth_map[:, 0, ..., None])
)
pts_mask = depth_map > 0.0
if mask is not None:
pts_mask *= mask > mask_thr
pts_mask = pts_mask.reshape(-1)
pts_3d = pts_3d.reshape(-1, 3)[pts_mask]
pts_colors = torch.nn.functional.interpolate(
image_rgb,
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
# `List[typing.Any]`.
size=[imh, imw],
mode="bilinear",
align_corners=False,
)
pts_colors = pts_colors.permute(0, 2, 3, 1).reshape(-1, 3)[pts_mask]
return Pointclouds(points=pts_3d[None], features=pts_colors[None])
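# A minimal usage sketch (random RGB-D inputs; camera values are illustrative):
#   from pytorch3d.renderer import PerspectiveCameras
#   camera = PerspectiveCameras(R=torch.eye(3)[None], T=torch.zeros(1, 3))
#   image_rgb = torch.rand(1, 3, 16, 16)
#   depth_map = torch.rand(1, 1, 16, 16) + 1.0
#   pcl = get_rgbd_point_cloud(camera, image_rgb, depth_map)
#   # pcl is a Pointclouds object with one colored point per valid depth pixel.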
def render_point_cloud_pytorch3d(
camera,
point_cloud,
render_size: Tuple[int, int],
point_radius: float = 0.03,
topk: int = 10,
eps: float = 1e-2,
bg_color=None,
**kwargs
):
# feature dimension
featdim = point_cloud.features_packed().shape[-1]
# move to the camera coordinates; using identity cameras in the renderer
point_cloud = _transform_points(camera, point_cloud, eps, **kwargs)
camera_trivial = camera.clone()
camera_trivial.R[:] = torch.eye(3)
camera_trivial.T *= 0.0
rasterizer = PointsRasterizer(
cameras=camera_trivial,
raster_settings=PointsRasterizationSettings(
image_size=render_size,
radius=point_radius,
points_per_pixel=topk,
bin_size=64 if int(max(render_size)) > 1024 else None,
),
)
fragments = rasterizer(point_cloud, **kwargs)
# Construct weights based on the distance of a point to the true point.
# However, this could be done differently: e.g. predicted as opposed
# to a function of the weights.
r = rasterizer.raster_settings.radius
# set up the blending weights
dists2 = fragments.dists
weights = 1 - dists2 / (r * r)
ok = cast(torch.BoolTensor, (fragments.idx >= 0)).float()
weights = weights * ok
fragments_prm = fragments.idx.long().permute(0, 3, 1, 2)
weights_prm = weights.permute(0, 3, 1, 2)
images = AlphaCompositor()(
fragments_prm,
weights_prm,
point_cloud.features_packed().permute(1, 0),
background_color=bg_color if bg_color is not None else [0.0] * featdim,
**kwargs,
)
# get the depths ...
# weighted_fs[b,c,i,j] = sum_k cum_alpha_k * features[c,pointsidx[b,k,i,j]]
# cum_alpha_k = alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])
cumprod = torch.cumprod(1 - weights, dim=-1)
cumprod = torch.cat((torch.ones_like(cumprod[..., :1]), cumprod[..., :-1]), dim=-1)
depths = (weights * cumprod * fragments.zbuf).sum(dim=-1)
# add the rendering mask
render_mask = 1.0 - torch.prod(1.0 - weights, dim=-1)
# cat depths and render mask
rendered_blob = torch.cat((images, depths[:, None], render_mask[:, None]), dim=1)
# reshape back
rendered_blob = Fu.interpolate(
rendered_blob,
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got `Tuple[int,
# ...]`.
size=tuple(render_size),
mode="bilinear",
)
data_rendered, depth_rendered, render_mask = rendered_blob.split(
[rendered_blob.shape[1] - 2, 1, 1],
dim=1,
)
return data_rendered, render_mask, depth_rendered
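def _example_render_point_cloud_pytorch3d() -> None:
    # Minimal usage sketch with synthetic inputs: render a random colored
    # point cloud placed ~3 units in front of a default camera.
    from pytorch3d.renderer import PerspectiveCameras

    points = torch.randn(1, 100, 3) * 0.3 + torch.tensor([0.0, 0.0, 3.0])
    cloud = Pointclouds(points=points, features=torch.rand(1, 100, 3))
    data, mask, depth = render_point_cloud_pytorch3d(
        PerspectiveCameras(), cloud, render_size=(128, 128)
    )
    # data: (1, 3, 128, 128); mask and depth: (1, 1, 128, 128)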
def _signed_clamp(x, eps):
sign = x.sign() + (x == 0.0).type_as(x)
x_clamp = sign * torch.clamp(x.abs(), eps)
return x_clamp
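# E.g. with eps=0.01, _signed_clamp(torch.tensor([-0.001, 0.0, 0.5]), 0.01)
# gives [-0.01, 0.01, 0.5]: values near zero are pushed away from zero while
# keeping their sign (exact zeros count as positive).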
def _transform_points(cameras, point_clouds, eps, **kwargs):
pts_world = point_clouds.points_padded()
pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
pts_world, eps=eps
)
# it is crucial to actually clamp the points as well ...
pts_view = torch.cat(
(pts_view[..., :-1], _signed_clamp(pts_view[..., -1:], eps)), dim=-1
)
point_clouds = point_clouds.update_padded(pts_view)
return point_clouds
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple
import torch
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.structures import Pointclouds
from .point_cloud_utils import render_point_cloud_pytorch3d
def rasterize_mc_samples(
xys: torch.Tensor,
feats: torch.Tensor,
image_size_hw: Tuple[int, int],
radius: float = 0.03,
topk: int = 5,
masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Rasterizes Monte-Carlo sampled features back onto the image.
Specifically, the code uses the PyTorch3D point rasterizer to render
a z-flat point cloud composed of the xy MC locations and their features.
Args:
xys: B x N x 2 2D point locations in PyTorch3D NDC convention
feats: B x N x dim tensor containing per-point rendered features.
image_size_hw: Tuple[image_height, image_width] containing
the size of rasterized image.
radius: Rasterization point radius.
topk: The maximum z-buffer size for the PyTorch3D point cloud rasterizer.
masks: B x N x 1 tensor containing the alpha mask of the
rendered features.
"""
if masks is None:
masks = torch.ones_like(xys[..., :1])
feats = torch.cat((feats, masks), dim=-1)
pointclouds = Pointclouds(
points=torch.cat([xys, torch.ones_like(xys[..., :1])], dim=-1),
features=feats,
)
data_rendered, render_mask, _ = render_point_cloud_pytorch3d(
PerspectiveCameras(device=feats.device),
pointclouds,
render_size=image_size_hw,
point_radius=radius,
topk=topk,
)
data_rendered, masks_pt = data_rendered.split(
[data_rendered.shape[1] - 1, 1], dim=1
)
render_mask = masks_pt * render_mask
return data_rendered, render_mask
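def _example_rasterize_mc_samples() -> None:
    # Minimal usage sketch with synthetic inputs: splat 500 random NDC
    # sample locations carrying random 3-channel features onto a 64x64 image.
    xys = torch.rand(1, 500, 2) * 2.0 - 1.0  # NDC coords in [-1, 1]
    feats = torch.rand(1, 500, 3)
    image, mask = rasterize_mc_samples(xys, feats, image_size_hw=(64, 64))
    # image: (1, 3, 64, 64); mask: (1, 1, 64, 64)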
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import gzip
import json
import time
import warnings
from collections.abc import Iterable
from itertools import cycle
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors as mcolors
from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.history = []
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1, epoch=0):
# make sure the history is of the same len as epoch
while len(self.history) <= epoch:
self.history.append([])
self.history[epoch].append(val / n)
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def get_epoch_averages(self, epoch=-1):
if len(self.history) == 0: # no stats here
return None
elif epoch == -1:
return [
(float(np.array(x).mean()) if len(x) > 0 else float("NaN"))
for x in self.history
]
else:
return float(np.array(self.history[epoch]).mean())
def get_all_values(self):
all_vals = [np.array(x) for x in self.history]
all_vals = np.concatenate(all_vals)
return all_vals
def get_epoch(self):
return len(self.history)
@staticmethod
def from_json_str(json_str):
self = AverageMeter()
self.__dict__.update(json.loads(json_str))
return self
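# Minimal usage sketch for AverageMeter (values are made up):
#
#     meter = AverageMeter()
#     for epoch in range(2):
#         for loss in (0.5, 0.3):
#             meter.update(loss, n=1, epoch=epoch)
#     meter.get_epoch_averages()  # -> [0.4, 0.4]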
class Stats(object):
# TODO: update this with context manager
"""
Stats logging object useful for gathering statistics of training a deep net in PyTorch.
Example:
# init stats structure that logs statistics 'objective' and 'top1e'
stats = Stats( ('objective','top1e') )
network = init_net() # init a pytorch module (=neural network)
dataloader = init_dataloader() # init a dataloader
for epoch in range(10):
# start of epoch -> call new_epoch
stats.new_epoch()
# iterate over batches
for batch in dataloader:
output = network(batch) # run the network and collect outputs in the dict "output"
# stats.update() automatically parses 'objective' and 'top1e' from
# the "output" dict and logs them
stats.update(output)
stats.print() # prints the averages over given epoch
# stores the training plots into '/tmp/epoch_stats.pdf'
# and plots into a visdom server running at localhost (if running)
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
"""
def __init__(
self,
log_vars,
verbose=False,
epoch=-1,
visdom_env="main",
do_plot=True,
plot_file=None,
visdom_server="http://localhost",
visdom_port=8097,
):
self.verbose = verbose
self.log_vars = log_vars
self.visdom_env = visdom_env
self.visdom_server = visdom_server
self.visdom_port = visdom_port
self.plot_file = plot_file
self.do_plot = do_plot
self.hard_reset(epoch=epoch)
@staticmethod
def from_json_str(json_str):
self = Stats([])
# load the global state
self.__dict__.update(json.loads(json_str))
# recover the AverageMeters
for stat_set in self.stats:
self.stats[stat_set] = {
log_var: AverageMeter.from_json_str(log_vals_json_str)
for log_var, log_vals_json_str in self.stats[stat_set].items()
}
return self
@staticmethod
def load(flpath, postfix=".jgz"):
flpath = _get_postfixed_filename(flpath, postfix)
with gzip.open(flpath, "r") as fin:
data = json.loads(fin.read().decode("utf-8"))
return Stats.from_json_str(data)
def save(self, flpath, postfix=".jgz"):
flpath = _get_postfixed_filename(flpath, postfix)
# store into a gzipped-json
with gzip.open(flpath, "w") as fout:
fout.write(json.dumps(self, cls=StatsJSONEncoder).encode("utf-8"))
# some sugar to be used with "with stats:" at the beginning of the epoch
def __enter__(self):
if self.do_plot and self.epoch >= 0:
self.plot_stats(self.visdom_env)
self.new_epoch()
def __exit__(self, type, value, traceback):
iserr = type is not None and issubclass(type, Exception)
iserr = iserr or (type is KeyboardInterrupt)
if iserr:
print("error inside 'with' block")
return
if self.do_plot:
self.plot_stats(self.visdom_env)
def reset(self): # to be called after each epoch
stat_sets = list(self.stats.keys())
if self.verbose:
print("stats: epoch %d - reset" % self.epoch)
self.it = {k: -1 for k in stat_sets}
for stat_set in stat_sets:
for stat in self.stats[stat_set]:
self.stats[stat_set][stat].reset()
def hard_reset(self, epoch=-1): # to be called during object __init__
self.epoch = epoch
if self.verbose:
print("stats: epoch %d - hard reset" % self.epoch)
self.stats = {}
# reset
self.reset()
def new_epoch(self):
if self.verbose:
print("stats: new epoch %d" % (self.epoch + 1))
self.epoch += 1
self.reset() # zero the stats of the new epoch
def gather_value(self, val):
if isinstance(val, (float, int)):
val = float(val)
else:
val = val.data.cpu().numpy()
val = float(val.sum())
return val
def add_log_vars(self, added_log_vars, verbose=True):
for add_log_var in added_log_vars:
if add_log_var not in self.stats:
if verbose:
print(f"Adding {add_log_var}")
self.log_vars.append(add_log_var)
# self.synchronize_logged_vars(self.log_vars, verbose=verbose)
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
if self.epoch == -1: # uninitialized
print(
"warning: epoch==-1 means uninitialized stats structure -> new_epoch() called"
)
self.new_epoch()
if stat_set not in self.stats:
self.stats[stat_set] = {}
self.it[stat_set] = -1
if not freeze_iter:
self.it[stat_set] += 1
epoch = self.epoch
it = self.it[stat_set]
for stat in self.log_vars:
if stat not in self.stats[stat_set]:
self.stats[stat_set][stat] = AverageMeter()
if stat == "sec/it": # compute speed
if time_start is None:
elapsed = 0.0
else:
elapsed = time.time() - time_start
time_per_it = float(elapsed) / float(it + 1)
val = time_per_it
# self.stats[stat_set]['sec/it'].update(time_per_it,epoch=epoch,n=1)
else:
if stat in preds:
try:
val = self.gather_value(preds[stat])
except KeyError:
raise ValueError(
"could not extract prediction %s"
" from the prediction dictionary" % stat
)
else:
val = None
if val is not None:
self.stats[stat_set][stat].update(val, epoch=epoch, n=1)
def get_epoch_averages(self, epoch=None):
stat_sets = list(self.stats.keys())
if epoch is None:
epoch = self.epoch
if epoch == -1:
epoch = list(range(self.epoch))
outvals = {}
for stat_set in stat_sets:
outvals[stat_set] = {
"epoch": epoch,
"it": self.it[stat_set],
"epoch_max": self.epoch,
}
for stat in self.stats[stat_set].keys():
if self.stats[stat_set][stat].count == 0:
continue
if isinstance(epoch, Iterable):
avgs = self.stats[stat_set][stat].get_epoch_averages()
avgs = [avgs[e] for e in epoch]
else:
avgs = self.stats[stat_set][stat].get_epoch_averages(epoch=epoch)
outvals[stat_set][stat] = avgs
return outvals
def print(
self,
max_it=None,
stat_set="train",
vars_print=None,
get_str=False,
skip_nan=False,
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
):
epoch = self.epoch
stats = self.stats
str_out = ""
it = self.it[stat_set]
stat_str = ""
stats_print = sorted(stats[stat_set].keys())
for stat in stats_print:
if stats[stat_set][stat].count == 0:
continue
if skip_nan and not np.isfinite(stats[stat_set][stat].avg):
continue
stat_str += " {0:.12}: {1:1.3f} |".format(
stat_format(stat), stats[stat_set][stat].avg
)
head_str = "[%s] | epoch %3d | it %5d" % (stat_set, epoch, it)
if max_it:
head_str += "/ %d" % max_it
str_out = "%s | %s" % (head_str, stat_str)
if get_str:
return str_out
else:
print(str_out)
def plot_stats(
self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
):
# use the cached visdom env if none supplied
if visdom_env is None:
visdom_env = self.visdom_env
if visdom_server is None:
visdom_server = self.visdom_server
if visdom_port is None:
visdom_port = self.visdom_port
if plot_file is None:
plot_file = self.plot_file
stat_sets = list(self.stats.keys())
print(
"printing charts to visdom env '%s' (%s:%d)"
% (visdom_env, visdom_server, visdom_port)
)
novisdom = False
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
if not viz.check_connection():
print("no visdom server! -> skipping visdom plots")
novisdom = True
lines = []
# plot metrics
if not novisdom:
viz.close(env=visdom_env, win=None)
for stat in self.log_vars:
vals = []
stat_sets_now = []
for stat_set in stat_sets:
val = self.stats[stat_set][stat].get_epoch_averages()
if val is None:
continue
else:
val = np.array(val).reshape(-1)
stat_sets_now.append(stat_set)
vals.append(val)
if len(vals) == 0:
continue
lines.append((stat_sets_now, stat, vals))
if not novisdom:
for tmodes, stat, vals in lines:
title = "%s" % stat
opts = {"title": title, "legend": list(tmodes)}
for i, (tmode, val) in enumerate(zip(tmodes, vals)):
update = "append" if i > 0 else None
valid = np.where(np.isfinite(val))[0]
if len(valid) == 0:
continue
x = np.arange(len(val))
viz.line(
Y=val[valid],
X=x[valid],
env=visdom_env,
opts=opts,
win=f"stat_plot_{title}",
name=tmode,
update=update,
)
if plot_file:
print("exporting stats to %s" % plot_file)
ncol = 3
nrow = int(np.ceil(float(len(lines)) / ncol))
matplotlib.rcParams.update({"font.size": 5})
color = cycle(plt.cm.tab10(np.linspace(0, 1, 10)))
fig = plt.figure(1)
plt.clf()
for idx, (tmodes, stat, vals) in enumerate(lines):
c = next(color)
plt.subplot(nrow, ncol, idx + 1)
plt.gca()
for vali, vals_ in enumerate(vals):
c_ = c * (1.0 - float(vali) * 0.3)
valid = np.where(np.isfinite(vals_))[0]
if len(valid) == 0:
continue
x = np.arange(len(vals_))
plt.plot(x[valid], vals_[valid], c=c_, linewidth=1)
plt.ylabel(stat)
plt.xlabel("epoch")
plt.gca().yaxis.label.set_color(c[0:3] * 0.75)
plt.legend(tmodes)
gcolor = np.array(mcolors.to_rgba("lightgray"))
plt.grid(
b=True, which="major", color=gcolor, linestyle="-", linewidth=0.4
)
plt.grid(
b=True, which="minor", color=gcolor, linestyle="--", linewidth=0.2
)
plt.minorticks_on()
plt.tight_layout()
plt.show()
try:
fig.savefig(plot_file)
except PermissionError:
warnings.warn("Cant dump stats due to insufficient permissions!")
def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True):
stat_sets = list(self.stats.keys())
# remove the additional log_vars
for stat_set in stat_sets:
for stat in self.stats[stat_set].keys():
if stat not in log_vars:
print("additional stat %s:%s -> removing" % (stat_set, stat))
self.stats[stat_set] = {
stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars
}
self.log_vars = log_vars # !!!
for stat_set in stat_sets:
reference_stat = list(self.stats[stat_set].keys())[0]
for stat in log_vars:
if stat not in self.stats[stat_set]:
if verbose:
print(
"missing stat %s:%s -> filling with default values (%1.2f)"
% (stat_set, stat, default_val)
)
elif len(self.stats[stat_set][stat].history) != self.epoch + 1:
h = self.stats[stat_set][stat].history
if len(h) == 0: # just never updated stat ... skip
continue
else:
if verbose:
print(
"incomplete stat %s:%s -> reseting with default values (%1.2f)"
% (stat_set, stat, default_val)
)
else:
continue
self.stats[stat_set][stat] = AverageMeter()
self.stats[stat_set][stat].reset()
lastep = self.epoch + 1
for ep in range(lastep):
self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
epoch_self = self.stats[stat_set][reference_stat].get_epoch()
epoch_generated = self.stats[stat_set][stat].get_epoch()
assert (
epoch_self == epoch_generated
), "bad epoch of synchronized log_var! %d vs %d" % (
epoch_self,
epoch_generated,
)
class StatsJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, (AverageMeter, Stats)):
enc = self.encode(o.__dict__)
return enc
else:
raise TypeError(
f"Object of type {o.__class__.__name__} " f"is not JSON serializable"
)
def _get_postfixed_filename(fl, postfix):
return fl if fl.endswith(postfix) else fl + postfix
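# Minimal round-trip sketch for Stats serialization (the path is made up):
#
#     stats = Stats(("objective",))
#     stats.new_epoch()
#     stats.update({"objective": 0.5})
#     stats.save("/tmp/stats_example")  # writes /tmp/stats_example.jgz
#     restored = Stats.load("/tmp/stats_example")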
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import collections
import dataclasses
import time
from contextlib import contextmanager
from typing import Any, Callable, Dict
import torch
@contextmanager
def evaluating(net: torch.nn.Module):
"""Temporarily switch to evaluation mode."""
istrain = net.training
try:
net.eval()
yield net
finally:
if istrain:
net.train()
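# Minimal usage sketch for `evaluating` (the module is made up):
#
#     net = torch.nn.Linear(4, 2).train()
#     with evaluating(net):
#         assert not net.training  # eval mode inside the block
#     assert net.training  # training mode restored on exit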
def try_to_cuda(t: Any) -> Any:
"""
Try to move the input variable `t` to a cuda device.
Args:
t: Input.
Returns:
t_cuda: `t` moved to a cuda device, if supported.
"""
try:
t = t.cuda()
except AttributeError:
pass
return t
def try_to_cpu(t: Any) -> Any:
"""
Try to move the input variable `t` to a cpu device.
Args:
t: Input.
Returns:
t_cpu: `t` moved to a cpu device, if supported.
"""
try:
t = t.cpu()
except AttributeError:
pass
return t
def dict_to_cuda(batch: Dict[Any, Any]) -> Dict[Any, Any]:
"""
Move all values in a dictionary to cuda if supported.
Args:
batch: Input dict.
Returns:
batch_cuda: `batch` moved to a cuda device, if supported.
"""
return {k: try_to_cuda(v) for k, v in batch.items()}
def dict_to_cpu(batch):
"""
Move all values in a dictionary to cpu if supported.
Args:
batch: Input dict.
Returns:
batch_cpu: `batch` moved to a cpu device, if supported.
"""
return {k: try_to_cpu(v) for k, v in batch.items()}
def dataclass_to_cuda_(obj):
"""
Move all contents of a dataclass to cuda inplace if supported.
Args:
obj: Input dataclass.
Returns:
obj_cuda: `obj` moved to a cuda device, if supported.
"""
for f in dataclasses.fields(obj):
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
return obj
def dataclass_to_cpu_(obj):
"""
Move all contents of a dataclass to cpu inplace if supported.
Args:
obj: Input dataclass.
Returns:
obj_cpu: `obj` moved to a cpu device, if supported.
"""
for f in dataclasses.fields(obj):
setattr(obj, f.name, try_to_cpu(getattr(obj, f.name)))
return obj
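def _example_dataclass_to_cuda_() -> None:
    # Minimal usage sketch; the `Frame` dataclass is made up for illustration.
    @dataclasses.dataclass
    class Frame:
        image: torch.Tensor
        depth: torch.Tensor

    frame = Frame(image=torch.rand(3, 4, 4), depth=torch.rand(1, 4, 4))
    if torch.cuda.is_available():
        dataclass_to_cuda_(frame)  # all tensor fields now live on the GPU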
# TODO: test it
def cat_dataclass(batch, tensor_collator: Callable):
"""
Concatenate all fields of a list of dataclasses `batch` to a single
dataclass object using `tensor_collator`.
Args:
batch: Input list of dataclasses.
tensor_collator: The function used to concatenate tensor fields.
Returns:
concatenated_batch: All elements of `batch` concatenated to a single
dataclass object.
"""
elem = batch[0]
collated = {}
for f in dataclasses.fields(elem):
elem_f = getattr(elem, f.name)
if elem_f is None:
collated[f.name] = None
elif torch.is_tensor(elem_f):
collated[f.name] = tensor_collator([getattr(e, f.name) for e in batch])
elif dataclasses.is_dataclass(elem_f):
collated[f.name] = cat_dataclass(
[getattr(e, f.name) for e in batch], tensor_collator
)
elif isinstance(elem_f, collections.abc.Mapping):
collated[f.name] = {
k: tensor_collator([getattr(e, f.name)[k] for e in batch])
if elem_f[k] is not None
else None
for k in elem_f
}
else:
raise ValueError("Unsupported field type for concatenation")
return type(elem)(**collated)
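def _example_cat_dataclass() -> None:
    # Minimal usage sketch; the `Batch` dataclass is made up for illustration.
    @dataclasses.dataclass
    class Batch:
        xs: torch.Tensor

    a, b = Batch(xs=torch.zeros(2, 3)), Batch(xs=torch.ones(4, 3))
    merged = cat_dataclass([a, b], lambda ts: torch.cat(ts, dim=0))
    assert merged.xs.shape == (6, 3)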
class Timer:
"""
A simple class for timing execution.
Example:
```
with Timer():
print("This print statement is timed.")
```
"""
def __init__(self, name="timer", quiet=False):
self.name = name
self.quiet = quiet
def __enter__(self):
self.start = time.time()
return self
def __exit__(self, *args):
self.end = time.time()
self.interval = self.end - self.start
if not self.quiet:
print("%20s: %1.6f sec" % (self.name, self.interval))
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import shutil
import tempfile
import warnings
from typing import Optional, Tuple, Union
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
matplotlib.use("Agg")
class VideoWriter:
"""
A class for exporting videos.
"""
def __init__(
self,
cache_dir: Optional[str] = None,
ffmpeg_bin: str = "ffmpeg",
out_path: str = "/tmp/video.mp4",
fps: int = 20,
output_format: str = "visdom",
rmdir_allowed: bool = False,
**kwargs,
):
"""
Args:
cache_dir: A directory for storing the video frames. If `None`,
a temporary directory will be used.
ffmpeg_bin: The path to an `ffmpeg` executable.
out_path: The path to the output video.
fps: The speed of the generated video in frames-per-second.
output_format: Format of the output video. Currently only `"visdom"`
is supported.
rmdir_allowed: If `True`, delete and re-create `cache_dir` in case
it is not empty.
"""
self.rmdir_allowed = rmdir_allowed
self.output_format = output_format
self.fps = fps
self.out_path = out_path
self.cache_dir = cache_dir
self.ffmpeg_bin = ffmpeg_bin
self.frames = []
self.regexp = "frame_%08d.png"
self.frame_num = 0
if self.cache_dir is not None:
self.tmp_dir = None
if os.path.isdir(self.cache_dir):
if rmdir_allowed:
shutil.rmtree(self.cache_dir)
else:
warnings.warn(
f"Cache directory not empty ({self.cache_dir})."
)
os.makedirs(self.cache_dir, exist_ok=True)
else:
self.tmp_dir = tempfile.TemporaryDirectory()
self.cache_dir = self.tmp_dir.name
def write_frame(
self,
frame: Union[matplotlib.figure.Figure, np.ndarray, Image.Image, str],
resize: Optional[Union[float, Tuple[int, int]]] = None,
):
"""
Write a frame to the video.
Args:
frame: An object containing the frame image.
resize: Either a float defining the image rescaling factor
or a 2-tuple defining the size of the output image.
"""
outfile = os.path.join(self.cache_dir, self.regexp % self.frame_num)
if isinstance(frame, matplotlib.figure.Figure):
plt.savefig(outfile)
im = Image.open(outfile)
elif isinstance(frame, np.ndarray):
if frame.dtype in (np.float64, np.float32, float):
frame = (np.transpose(frame, (1, 2, 0)) * 255.0).astype(np.uint8)
im = Image.fromarray(frame)
elif isinstance(frame, Image.Image):
im = frame
elif isinstance(frame, str):
im = Image.open(frame).convert("RGB")
else:
raise ValueError("Cant convert type %s" % str(type(frame)))
if im is not None:
if resize is not None:
if isinstance(resize, float):
resize = [int(resize * s) for s in im.size]
else:
resize = im.size
# make sure size is divisible by 2
resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
im = im.resize(resize, Image.ANTIALIAS)
im.save(outfile)
self.frames.append(outfile)
self.frame_num += 1
def get_video(self, quiet: bool = True):
"""
Generate the video from the written frames.
Args:
quiet: If `True`, suppresses logging messages.
Returns:
video_path: The path to the generated video.
"""
regexp = os.path.join(self.cache_dir, self.regexp)
if self.output_format == "visdom": # works for ppt too
ffmcmd_ = (
"%s -r %d -i %s -vcodec h264 -f mp4 \
-y -crf 18 -b 2000k -pix_fmt yuv420p '%s'"
% (self.ffmpeg_bin, self.fps, regexp, self.out_path)
)
else:
raise ValueError("no such output type %s" % str(self.output_format))
if quiet:
ffmcmd_ += " > /dev/null 2>&1"
else:
print(ffmcmd_)
os.system(ffmcmd_)
return self.out_path
def __del__(self):
if self.tmp_dir is not None:
self.tmp_dir.cleanup()
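# Minimal usage sketch for VideoWriter, assuming `ffmpeg` is on the PATH
# (the frames here are random noise):
#
#     writer = VideoWriter(out_path="/tmp/example.mp4", fps=10)
#     for _ in range(20):
#         writer.write_frame(np.random.rand(3, 64, 64).astype(np.float32))
#     video_path = writer.get_video()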
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List
import torch
from visdom import Visdom
def get_visdom_env(cfg):
"""
Parse out visdom environment name from the input config.
Args:
cfg: The global config file.
Returns:
visdom_env: The name of the visdom environment.
"""
if len(cfg.visdom_env) == 0:
visdom_env = cfg.exp_dir.split("/")[-1]
else:
visdom_env = cfg.visdom_env
return visdom_env
# TODO: a proper singleton
_viz_singleton = None
def get_visdom_connection(
server: str = "http://localhost",
port: int = 8097,
) -> Visdom:
"""
Obtain a connection to a visdom server.
Args:
server: Server address.
port: Server port.
Returns:
connection: The connection object.
"""
global _viz_singleton
if _viz_singleton is None:
_viz_singleton = Visdom(server=server, port=port)
return _viz_singleton
def visualize_basics(
viz: Visdom,
preds: Dict[str, Any],
visdom_env_imgs: str,
title: str = "",
visualize_preds_keys: List[str] = [
"image_rgb",
"images_render",
"fg_probability",
"masks_render",
"depths_render",
"depth_map",
],
store_history: bool = False,
) -> None:
"""
Visualize basic outputs of a `GenericModel` to visdom.
Args:
viz: The visdom object.
preds: A dictionary containing `GenericModel` outputs.
visdom_env_imgs: Target visdom environment name.
title: The title of produced visdom window.
visualize_preds_keys: The list of keys of `preds` for visualization.
store_history: Store the history buffer in visdom windows.
"""
imout = {}
for k in visualize_preds_keys:
if k not in preds or preds[k] is None:
print(f"cant show {k}")
continue
v = preds[k].cpu().detach().clone()
if k.startswith("depth"):
# divide by 95th percentile
normfac = (
v.view(v.shape[0], -1)
.topk(k=int(0.05 * (v.numel() // v.shape[0])), dim=-1)
.values[:, -1]
)
v = v / normfac[:, None, None, None].clamp(1e-4)
if v.shape[1] == 1:
v = v.repeat(1, 3, 1, 1)
v = torch.nn.functional.interpolate(
v,
# pyre-fixme[6]: Expected `Optional[typing.List[float]]` for 2nd param
# but got `float`.
scale_factor=(
600.0
if (
"_eval" in visdom_env_imgs
and k in ("images_render", "depths_render")
)
else 200.0
)
/ v.shape[2],
mode="bilinear",
)
imout[k] = v
# TODO: handle errors on the outside
try:
imout = {"all": torch.cat(list(imout.values()), dim=2)}
except Exception:
print("can't concatenate the images!")
for k, v in imout.items():
viz.images(
v.clamp(0.0, 1.0),
win=k,
env=visdom_env_imgs,
opts={"title": title + "_" + k, "store_history": store_history},
)
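# Minimal usage sketch for visualize_basics, assuming a running visdom
# server and `preds`, a dict of GenericModel outputs (both hypothetical):
#
#     viz = get_visdom_connection()
#     visualize_basics(viz, preds, "my_experiment_imgs", title="val_batch")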
def make_depth_image(
depths: torch.Tensor,
masks: torch.Tensor,
max_quantile: float = 0.98,
min_quantile: float = 0.02,
min_out_depth: float = 0.1,
max_out_depth: float = 0.9,
) -> torch.Tensor:
"""
Convert a batch of depth maps to a grayscale image.
Args:
depths: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
masks: A tensor of shape `(B, 1, H, W)` containing a batch of foreground masks.
max_quantile: The quantile of the input depth values which will
be mapped to `max_out_depth`.
min_quantile: The quantile of the input depth values which will
be mapped to `min_out_depth`.
min_out_depth: The minimal value in each depth map will be assigned this color.
max_out_depth: The maximal value in each depth map will be assigned this color.
Returns:
depth_image: A tensor of shape `(B, 1, H, W)` a batch of grayscale
depth images.
"""
normfacs = []
for d, m in zip(depths, masks):
ok = (d.view(-1) > 1e-6) * (m.view(-1) > 0.5)
if ok.sum() <= 1:
print("empty depth!")
normfacs.append(torch.zeros(2).type_as(depths))
continue
dok = d.view(-1)[ok].view(-1)
_maxk = max(int(round((1 - max_quantile) * (dok.numel()))), 1)
_mink = max(int(round(min_quantile * (dok.numel()))), 1)
normfac_max = dok.topk(k=_maxk, dim=-1).values[-1]
normfac_min = dok.topk(k=_mink, dim=-1, largest=False).values[-1]
normfacs.append(torch.stack([normfac_min, normfac_max]))
normfacs = torch.stack(normfacs)
_min, _max = (normfacs[:, 0].view(-1, 1, 1, 1), normfacs[:, 1].view(-1, 1, 1, 1))
depths = (depths - _min) / (_max - _min).clamp(1e-4)
depths = (
(depths * (max_out_depth - min_out_depth) + min_out_depth) * masks.float()
).clamp(0.0, 1.0)
return depths
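def _example_make_depth_image() -> torch.Tensor:
    # Minimal usage sketch with synthetic inputs: convert two random raw
    # depth maps with all-foreground masks into displayable grayscale images.
    depths = torch.rand(2, 1, 32, 32) * 5.0 + 0.5
    masks = torch.ones(2, 1, 32, 32)
    return make_depth_image(depths, masks)  # grayscale values in [0, 1]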
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import os
import tempfile
import unittest
from pathlib import Path
from typing import Generator, Tuple
from zipfile import ZipFile
from iopath.common.file_io import PathManager
@contextlib.contextmanager
def get_skateboard_data(
avoid_manifold: bool = False, silence_logs: bool = False
) -> Generator[Tuple[str, PathManager], None, None]:
"""
Context manager for tests to access the Co3D dataset, at least for
the first 5 skateboards. Internally, we want this to exercise the
normal way to access the data directly on manifold, but on an RE
worker this is impossible, so we use a workaround.
Args:
avoid_manifold: Use the method used by RE workers even locally.
silence_logs: Whether to reduce log output from iopath library.
Yields:
dataset_root: (str) path to dataset root.
path_manager: path_manager to access it with.
"""
path_manager = PathManager()
if silence_logs:
logging.getLogger("iopath.fb.manifold").setLevel(logging.CRITICAL)
logging.getLogger("iopath.common.file_io").setLevel(logging.CRITICAL)
if not os.environ.get("FB_TEST", False):
if os.getenv("FAIR_ENV_CLUSTER", "") == "":
raise unittest.SkipTest("Unknown environment. Data not available.")
yield "/checkpoint/dnovotny/datasets/co3d/download_aws_22_02_18", path_manager
elif avoid_manifold or os.environ.get("INSIDE_RE_WORKER", False):
from libfb.py.parutil import get_file_path
par_path = "skateboard_first_5"
source = get_file_path(par_path)
assert Path(source).is_file()
with tempfile.TemporaryDirectory() as dest:
with ZipFile(source) as f:
f.extractall(dest)
yield os.path.join(dest, "extracted"), path_manager
else:
from iopath.fb.manifold import ManifoldPathHandler
path_manager.register_handler(ManifoldPathHandler())
yield "manifold://co3d/tree/extracted", path_manager
def provide_lpips_vgg():
"""
Ensure the weights files are available for lpips.LPIPS(net="vgg")
to be called. Specifically, torchvision's vgg16 weights.
"""
# In OSS, torchvision looks for vgg16 weights in
# https://download.pytorch.org/models/vgg16-397923af.pth
# Inside fbcode, this is replaced by asking iopath for
# manifold://torchvision/tree/models/vgg16-397923af.pth
# (the code for this replacement is in
# fbcode/pytorch/vision/fb/_internally_replaced_utils.py )
#
# iopath does this by looking for the file at the cache location
# and if it is not there getting it from manifold.
# (the code for this is in
# fbcode/fair_infra/data/iopath/iopath/fb/manifold.py )
#
# On the remote execution worker, manifold is inaccessible.
# We solve this by making the cached file available before iopath
# looks.
#
# By default the cache location is
# ~/.torch/iopath_cache/manifold_cache/tree/models/vgg16-397923af.pth
# But we can't write to the home directory on the RE worker.
# We define FVCORE_CACHE to change the cache location to
# iopath_cache/manifold_cache/tree/models/vgg16-397923af.pth
# (Without it, manifold caches in unstable temporary locations on RE.)
#
# The file we want has been copied from
# tree/models/vgg16-397923af.pth in the torchvision bucket
# to
# tree/testing/vgg16-397923af.pth in the co3d bucket
# and the TARGETS file copies it somewhere in the PAR which we
# recover with get_file_path.
# (It can't copy straight to a nested location, see
# https://fb.workplace.com/groups/askbuck/posts/2644615728920359/)
# Here we symlink it to the new cache location.
if os.environ.get("INSIDE_RE_WORKER") is not None:
from libfb.py.parutil import get_file_path
os.environ["FVCORE_CACHE"] = "iopath_cache"
par_path = "vgg_weights_for_lpips"
source = Path(get_file_path(par_path))
assert source.is_file()
dest = Path("iopath_cache/manifold_cache/tree/models")
if not dest.exists():
dest.mkdir(parents=True)
(dest / "vgg16-397923af.pth").symlink_to(source)
mask_images: true
mask_depths: true
render_image_width: 400
render_image_height: 400
mask_threshold: 0.5
output_rasterized_mc: false
bg_color:
- 0.0
- 0.0
- 0.0
view_pool: false
num_passes: 1
chunk_size_grid: 4096
render_features_dimensions: 3
tqdm_trigger_threshold: 16
n_train_target_views: 1
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
renderer_class_type: LSTMRenderer
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
implicit_function_class_type: IdrFeatureField
loss_weights:
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
loss_mask_bce: 0.0
loss_prev_stage_mask_bce: 0.0
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
sequence_autodecoder_args:
encoding_dim: 0
n_instances: 0
init_scale: 1.0
ignore_input: false
raysampler_args:
image_width: 400
image_height: 400
scene_center:
- 0.0
- 0.0
- 0.0
scene_extent: 0.0
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
min_depth: 0.1
max_depth: 8.0
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
renderer_LSTMRenderer_args:
num_raymarch_steps: 10
init_depth: 17.0
init_depth_noise_std: 0.0005
hidden_size: 16
n_feature_channels: 256
verbose: false
image_feature_extractor_args:
name: resnet34
pretrained: true
stages:
- 1
- 2
- 3
- 4
normalize_image: true
image_rescale: 0.16
first_max_pool: true
proj_dim: 32
l2_norm: true
add_masks: true
add_images: true
global_average_pool: false
feature_rescale: 1.0
view_sampler_args:
masked_sampling: false
sampling_mode: bilinear
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
implicit_function_IdrFeatureField_args:
feature_vector_size: 3
d_in: 3
d_out: 1
dims:
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
bias: 1.0
skip_in: []
weight_norm: true
n_harmonic_functions_xyz: 0
pooled_feature_dim: 0
encoding_dim: 0
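# The YAML block above is an implicitron experiment config. A minimal sketch
# of inspecting such a config with OmegaConf (the file name is hypothetical;
# the actual experiment runner wires configs up differently):
#
#     from omegaconf import OmegaConf
#     cfg = OmegaConf.load("experiment.yaml")
#     assert cfg.renderer_class_type == "LSTMRenderer"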
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from collections import defaultdict
from dataclasses import dataclass
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@dataclass
class MockFrameAnnotation:
frame_number: int
frame_timestamp: float = 0.0
class MockDataset:
def __init__(self, num_seq, max_frame_gap=1):
"""
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
"""
self.seq_annots = {f"seq_{i}": None for i in range(num_seq)}
self.seq_to_idx = {
f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq)
}
# frame numbers within sequence: [0, ..., 4, n, ..., n+4]
# where n - 4 == max_frame_gap
frame_nos = list(range(5)) + list(range(4 + max_frame_gap, 9 + max_frame_gap))
self.frame_annots = [
{"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
]
def get_frame_numbers_and_timestamps(self, idxs):
out = []
for idx in idxs:
frame_annotation = self.frame_annots[idx]["frame_annotation"]
out.append(
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
)
return out
class TestSceneBatchSampler(unittest.TestCase):
def setUp(self):
self.dataset_overfit = MockDataset(1)
def test_overfit(self):
num_batches = 3
batch_size = 10
sampler = SceneBatchSampler(
self.dataset_overfit,
batch_size=batch_size,
num_batches=num_batches,
images_per_seq_options=[10], # will try to sample batch_size anyway
)
self.assertEqual(len(sampler), num_batches)
it = iter(sampler)
for _ in range(num_batches):
batch = next(it)
self.assertIsNotNone(batch)
self.assertEqual(len(batch), batch_size) # true for our examples
self.assertTrue(all(idx // 10 == 0 for idx in batch))
with self.assertRaises(StopIteration):
batch = next(it)
def test_multiseq(self):
for ips_options in [[10], [2], [3], [2, 3, 4]]:
for sample_consecutive_frames in [True, False]:
for consecutive_frames_max_gap in [0, 1, 3]:
self._test_multiseq_flavour(
ips_options,
sample_consecutive_frames,
consecutive_frames_max_gap,
)
def test_multiseq_gaps(self):
num_batches = 16
batch_size = 10
dataset_multiseq = MockDataset(5, max_frame_gap=3)
for ips_options in [[10], [2], [3], [2, 3, 4]]:
debug_info = f" Images per sequence: {ips_options}."
sampler = SceneBatchSampler(
dataset_multiseq,
batch_size=batch_size,
num_batches=num_batches,
images_per_seq_options=ips_options,
sample_consecutive_frames=True,
consecutive_frames_max_gap=1,
)
self.assertEqual(len(sampler), num_batches, msg=debug_info)
it = iter(sampler)
for _ in range(num_batches):
batch = next(it)
self.assertIsNotNone(batch, "batch is None in" + debug_info)
if max(ips_options) > 5:
# true for our examples
self.assertEqual(len(batch), 5, msg=debug_info)
else:
# true for our examples
self.assertEqual(len(batch), batch_size, msg=debug_info)
self._check_frames_are_consecutive(
batch, dataset_multiseq.frame_annots, debug_info
)
def _test_multiseq_flavour(
self,
ips_options,
sample_consecutive_frames,
consecutive_frames_max_gap,
num_batches=16,
batch_size=10,
):
debug_info = (
f" Images per sequence: {ips_options}, "
f"sample_consecutive_frames: {sample_consecutive_frames}, "
f"consecutive_frames_max_gap: {consecutive_frames_max_gap}, "
)
# in this test, either consecutive_frames_max_gap == max_frame_gap,
# or consecutive_frames_max_gap == 0, so segments consist of full sequences
frame_gap = consecutive_frames_max_gap if consecutive_frames_max_gap > 0 else 3
dataset_multiseq = MockDataset(5, max_frame_gap=frame_gap)
sampler = SceneBatchSampler(
dataset_multiseq,
batch_size=batch_size,
num_batches=num_batches,
images_per_seq_options=ips_options,
sample_consecutive_frames=sample_consecutive_frames,
consecutive_frames_max_gap=consecutive_frames_max_gap,
)
self.assertEqual(len(sampler), num_batches, msg=debug_info)
it = iter(sampler)
typical_counts = set()
for _ in range(num_batches):
batch = next(it)
self.assertIsNotNone(batch, "batch is None in" + debug_info)
# true for our examples
self.assertEqual(len(batch), batch_size, msg=debug_info)
# find distribution over sequences
counts = _count_by_quotient(batch, 10)
freqs = _count_by_quotient(counts.values(), 1)
self.assertLessEqual(
len(freqs),
2,
msg="We should have maximum of 2 different "
"frequences of sequences in the batch." + debug_info,
)
if len(freqs) == 2:
most_seq_count = max(*freqs.keys())
last_seq = min(*freqs.keys())
self.assertEqual(
freqs[last_seq],
1,
msg="Only one odd sequence allowed." + debug_info,
)
else:
self.assertEqual(len(freqs), 1)
most_seq_count = next(iter(freqs))
self.assertIn(most_seq_count, ips_options)
typical_counts.add(most_seq_count)
if sample_consecutive_frames:
self._check_frames_are_consecutive(
batch,
dataset_multiseq.frame_annots,
debug_info,
max_gap=consecutive_frames_max_gap,
)
self.assertTrue(
all(i in typical_counts for i in ips_options),
"Some of the frequency options did not occur among "
f"the {num_batches} batches (could be just bad luck)." + debug_info,
)
with self.assertRaises(StopIteration):
batch = next(it)
def _check_frames_are_consecutive(self, batch, annots, debug_info, max_gap=1):
# make sure that sampled frames are consecutive
for i in range(len(batch) - 1):
curr_idx, next_idx = batch[i : i + 2]
if curr_idx // 10 == next_idx // 10: # same sequence
if max_gap > 0:
curr_idx, next_idx = [
annots[idx]["frame_annotation"].frame_number
for idx in (curr_idx, next_idx)
]
gap = max_gap
else:
gap = 1 # we'll check that raw dataset indices are consecutive
self.assertLessEqual(next_idx - curr_idx, gap, msg=debug_info)
def _count_by_quotient(indices, divisor):
counter = defaultdict(int)
for i in indices:
counter[i // divisor] += 1
return counter
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from math import pi
import torch
from pytorch3d.implicitron.tools.circle_fitting import (
_signed_area,
fit_circle_in_2d,
fit_circle_in_3d,
)
from pytorch3d.transforms import random_rotation
if os.environ.get("FB_TEST", False):
from common_testing import TestCaseMixin
else:
from tests.common_testing import TestCaseMixin
class TestCircleFitting(TestCaseMixin, unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
def _assertParallel(self, a, b, **kwargs):
"""
Given a and b of shape (..., 3) each containing 3D vectors,
assert that corresponding vectors are parallel. A changed sign is ok.
"""
self.assertClose(torch.cross(a, b, dim=-1), torch.zeros_like(a), **kwargs)
def test_simple_3d(self):
device = torch.device("cuda:0")
for _ in range(7):
radius = 10 * torch.rand(1, device=device)[0]
center = 10 * torch.rand(3, device=device)
rot = random_rotation(device=device)
offset = torch.rand(3, device=device)
up = torch.rand(3, device=device)
self._simple_3d_test(radius, center, rot, offset, up)
def _simple_3d_test(self, radius, center, rot, offset, up):
# angles are increasing so the points move in a well-defined direction.
angles = torch.cumsum(torch.rand(17, device=rot.device), dim=0)
many = torch.stack(
[torch.cos(angles), torch.sin(angles), torch.zeros_like(angles)], dim=1
)
source_points = (many * radius) @ rot + center[None]
# case with no generation
result = fit_circle_in_3d(source_points)
self.assertClose(result.radius, radius)
self.assertClose(result.center, center)
self._assertParallel(result.normal, rot[2], atol=1e-5)
self.assertEqual(result.generated_points.shape, (0, 3))
# Generate 5 points around the circle
n_new_points = 5
result2 = fit_circle_in_3d(source_points, n_points=n_new_points)
self.assertClose(result2.radius, radius)
self.assertClose(result2.center, center)
self.assertClose(result2.normal, result.normal)
self.assertEqual(result2.generated_points.shape, (5, 3))
observed_points = result2.generated_points
self.assertClose(observed_points[0], observed_points[4], atol=1e-4)
self.assertClose(observed_points[0], source_points[0], atol=1e-5)
observed_normal = torch.cross(
observed_points[0] - observed_points[2],
observed_points[1] - observed_points[3],
dim=-1,
)
self._assertParallel(observed_normal, result.normal, atol=1e-4)
diameters = observed_points[:2] - observed_points[2:4]
self.assertClose(
torch.norm(diameters, dim=1), diameters.new_full((2,), 2 * radius)
)
# Regenerate the input points
result3 = fit_circle_in_3d(source_points, angles=angles - angles[0])
self.assertClose(result3.radius, radius)
self.assertClose(result3.center, center)
self.assertClose(result3.normal, result.normal)
self.assertClose(result3.generated_points, source_points, atol=1e-5)
# Test with offset
result4 = fit_circle_in_3d(
source_points, angles=angles - angles[0], offset=offset, up=up
)
self.assertClose(result4.radius, radius)
self.assertClose(result4.center, center)
self.assertClose(result4.normal, result.normal)
observed_offsets = result4.generated_points - source_points
# observed_offset is constant
self.assertClose(
observed_offsets.min(0).values, observed_offsets.max(0).values, atol=1e-5
)
# observed_offset has the right length
self.assertClose(observed_offsets[0].norm(), offset.norm())
self.assertClose(result.normal.norm(), torch.ones(()))
# component of observed_offset along normal
component = torch.dot(observed_offsets[0], result.normal)
self.assertClose(component.abs(), offset[2].abs(), atol=1e-5)
agree_normal = torch.dot(result.normal, up) > 0
agree_signs = component * offset[2] > 0
self.assertEqual(agree_normal, agree_signs)
def test_simple_2d(self):
radius = 7.0
center = torch.tensor([9, 2.5])
angles = torch.cumsum(torch.rand(17), dim=0)
many = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
source_points = (many * radius) + center[None]
result = fit_circle_in_2d(source_points)
self.assertClose(result.radius, torch.tensor(radius))
self.assertClose(result.center, center)
self.assertEqual(result.generated_points.shape, (0, 2))
# Generate 5 points around the circle
n_new_points = 5
result2 = fit_circle_in_2d(source_points, n_points=n_new_points)
self.assertClose(result2.radius, torch.tensor(radius))
self.assertClose(result2.center, center)
self.assertEqual(result2.generated_points.shape, (5, 2))
observed_points = result2.generated_points
self.assertClose(observed_points[0], observed_points[4])
self.assertClose(observed_points[0], source_points[0], atol=1e-5)
diameters = observed_points[:2] - observed_points[2:4]
self.assertClose(torch.norm(diameters, dim=1), torch.full((2,), 2 * radius))
# Regenerate the input points
result3 = fit_circle_in_2d(source_points, angles=angles - angles[0])
self.assertClose(result3.radius, torch.tensor(radius))
self.assertClose(result3.center, center)
self.assertClose(result3.generated_points, source_points, atol=1e-5)
def test_minimum_inputs(self):
fit_circle_in_3d(torch.rand(3, 3), n_points=10)
with self.assertRaisesRegex(
ValueError, "2 points are not enough to determine a circle"
):
fit_circle_in_3d(torch.rand(2, 3))
def test_signed_area(self):
n_points = 1001
angles = torch.linspace(0, 2 * pi, n_points)
radius = 0.85
center = torch.rand(2)
circle = center + radius * torch.stack(
[torch.cos(angles), torch.sin(angles)], dim=1
)
circle_area = torch.tensor(pi * radius * radius)
self.assertClose(_signed_area(circle), circle_area)
# clockwise is negative
self.assertClose(_signed_area(circle.flip(0)), -circle_area)
# Semicircles
self.assertClose(_signed_area(circle[: (n_points + 1) // 2]), circle_area / 2)
self.assertClose(_signed_area(circle[n_points // 2 :]), circle_area / 2)
# A straight line bounds no area
self.assertClose(_signed_area(torch.rand(2, 2)), torch.tensor(0.0))
# Letter 'L' written anticlockwise.
L_shape = [[0, 1], [0, 0], [1, 0]]
# Triangle area is 0.5 * b * h.
self.assertClose(_signed_area(torch.tensor(L_shape)), torch.tensor(0.5))