Commit db1f7c45 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

avoid symeig

Summary: Use the newer eigh to avoid deprecation warnings in newer pytorch.

Reviewed By: patricklabatut

Differential Revision: D34375784

fbshipit-source-id: 40efe0d33fdfa071fba80fc97ed008cbfd2ef249
parent 59972b12
......@@ -49,3 +49,12 @@ def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cove
# PyTorch version >= 1.9
return torch.linalg.qr(A)
return torch.qr(A)
def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
"""
Like torch.linalg.eigh, assuming the argument is a symmetric real matrix.
"""
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
return torch.linalg.eigh(A)
return torch.symeig(A, eigenvalues=True)
......@@ -16,6 +16,7 @@ from typing import NamedTuple, Optional
import torch
import torch.nn.functional as F
from pytorch3d.common.compat import eigh
from pytorch3d.ops import points_alignment, utils as oputil
......@@ -105,7 +106,7 @@ def _null_space(m, kernel_dim):
kernel vectors, of size B x kernel_dim
"""
mTm = torch.bmm(m.transpose(1, 2), m)
s, v = torch.symeig(mTm, eigenvectors=True)
s, v = eigh(mTm)
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]
......
......@@ -7,8 +7,9 @@
from typing import TYPE_CHECKING, Tuple, Union
import torch
from pytorch3d.common.compat import eigh
from pytorch3d.common.workaround import symeig3x3
from ..common.workaround import symeig3x3
from .utils import convert_pointclouds_to_tensor, get_point_covariances
......@@ -139,14 +140,14 @@ def estimate_pointcloud_local_coord_frames(
# get the local coord frames as principal directions of
# the per-point covariance
# this is done with torch.symeig, which returns the
# this is done with torch.symeig / torch.linalg.eigh, which returns the
# eigenvectors (=principal directions) in an ascending order of their
# corresponding eigenvalues, while the smallest eigenvalue's eigenvector
# corresponds to the normal direction
# corresponding eigenvalues, and the smallest eigenvalue's eigenvector
# corresponds to the normal direction; or with a custom equivalent.
if use_symeig_workaround:
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
else:
curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)
curvatures, local_coord_frames = eigh(cov)
# disambiguate the directions of individual principal vectors
if disambiguate_directions:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment