"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "0df04f33b3c09e3d5f0b33f71d69d2c9180f2cbb"
Commit 82895ec3 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

swtich to scipy's kabsch algorithm

parent 5621ac05
...@@ -37,6 +37,7 @@ from openfold.utils.tensor_utils import ( ...@@ -37,6 +37,7 @@ from openfold.utils.tensor_utils import (
import random import random
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
import logging import logging
from scipy import spatial as sp_spatial
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def softmax_cross_entropy(logits, labels): def softmax_cross_entropy(logits, labels):
...@@ -1676,47 +1677,21 @@ def chain_center_of_mass_loss( ...@@ -1676,47 +1677,21 @@ def chain_center_of_mass_loss(
# # # #
def kabsch_rotation(P, Q): def kabsch_rotation(P, Q):
""" """
Using the Kabsch algorithm with two sets of paired point P and Q, centered Use scipy.spatial package to calculate best rotation that minimises
around the centroid. Each vector set is represented as an NxD the RMSD betwee P and Q
matrix, where D is the the dimension of the space.
The algorithm works in three steps: Args:
- a centroid translation of P and Q (assumed done before this function P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
call) Q: [N * 3] the same dimension as P
- the computation of a covariance matrix C
- computation of the optimal rotation matrix U
For more info see http://en.wikipedia.org/wiki/Kabsch_algorithm
Parameters
----------
P : array
(N,D) matrix, where N is points and D is dimension.
Q : array
(N,D) matrix, where N is points and D is dimension.
Returns
-------
U : matrix
Rotation matrix (D,D)
"""
# Computation of the covariance matrix return:
P,Q = P.to('cpu'),Q.to('cpu') # move to cpu memory just in case it takes up too much gpu mem A 3*3 rotation matrix
C = P.transpose(-1, -2) @ Q """
assert P.shape == torch.size([Q.shape[0],Q.shape[1]])
# Computation of the optimal rotation matrix rotation,_ = sp_spatial.transform.Rotation.align_vectors(P.numpy(),Q.numpy())
# This can be done using singular value decomposition (SVD) rotation = torch.tensor(rotation,dtype=torch.float64)
# Getting the sign of the det(V)*(W) to decide assert rotation.shape == torch.size([3,3])
# whether we need to correct our rotation matrix to ensure a return rotation
# right-handed coordinate system.
# And finally calculating the optimal rotation matrix U
# see http://en.wikipedia.org/wiki/Kabsch_algorithm
V, _, W = torch.linalg.svd(C.to('cpu'))
d = (torch.linalg.det(V) * torch.linalg.det(W)) < 0.0
if d:
V[:, -1] = -V[:, -1]
# Create Rotation matrix U
U = V @ W
return U
def get_optimal_transform( def get_optimal_transform(
src_atoms: torch.Tensor, src_atoms: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment