".github/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "5944dbed0b6e08c6eeba9d8ade9bcbbc432da2f9"
Commit 458a62f7 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

switch optimal alignment method to procustes package

parent 82895ec3
...@@ -37,7 +37,7 @@ from openfold.utils.tensor_utils import ( ...@@ -37,7 +37,7 @@ from openfold.utils.tensor_utils import (
import random import random
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
import logging import logging
from scipy import spatial as sp_spatial import procrustes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def softmax_cross_entropy(logits, labels): def softmax_cross_entropy(logits, labels):
...@@ -1680,6 +1680,10 @@ def kabsch_rotation(P, Q): ...@@ -1680,6 +1680,10 @@ def kabsch_rotation(P, Q):
Use scipy.spatial package to calculate best rotation that minimises Use scipy.spatial package to calculate best rotation that minimises
the RMSD betwee P and Q the RMSD betwee P and Q
The optimal rotation matrix was calculated using
the rotational() function from procrustes package. Details can be found here:
https://procrustes.qcdevs.org/api/rotational.html#rotational
Args: Args:
P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
Q: [N * 3] the same dimension as P Q: [N * 3] the same dimension as P
...@@ -1687,10 +1691,24 @@ def kabsch_rotation(P, Q): ...@@ -1687,10 +1691,24 @@ def kabsch_rotation(P, Q):
return: return:
A 3*3 rotation matrix A 3*3 rotation matrix
""" """
assert P.shape == torch.size([Q.shape[0],Q.shape[1]])
rotation,_ = sp_spatial.transform.Rotation.align_vectors(P.numpy(),Q.numpy()) assert P.shape == torch.Size([Q.shape[0],Q.shape[1]])
rotation = torch.tensor(rotation,dtype=torch.float64) finished_rotation = False
assert rotation.shape == torch.size([3,3]) while not finished_rotation:
#
# Add a try-except block cuz sometimes SVD fails to converge and crashes the programme
# Will continue trying SVD until the optimal rotaion is calculated
# #
try:
rotation = procrustes.rotational(P.to('cpu').numpy(),
Q.to('cpu').numpy(),translate=True)
finished_rotation = True
except:
print(f"svd failed.")
import sys
sys.exit()
rotation = torch.tensor(rotation.t,dtype=torch.float)
assert rotation.shape == torch.Size([3,3])
return rotation return rotation
def get_optimal_transform( def get_optimal_transform(
...@@ -1700,6 +1718,12 @@ def get_optimal_transform( ...@@ -1700,6 +1718,12 @@ def get_optimal_transform(
): ):
assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape) assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
assert src_atoms.shape[-1] == 3 assert src_atoms.shape[-1] == 3
if torch.isnan(src_atoms).any():
#
# sometimes using fake test inputs generates NaN in the predicted atom positions
# #
src_atoms = torch.nan_to_num(src_atoms,nan=0.0)
if mask is not None: if mask is not None:
assert mask.dtype == torch.bool assert mask.dtype == torch.bool
assert mask.shape[-1] == src_atoms.shape[-2] assert mask.shape[-1] == src_atoms.shape[-2]
...@@ -1711,7 +1735,7 @@ def get_optimal_transform( ...@@ -1711,7 +1735,7 @@ def get_optimal_transform(
tgt_atoms = tgt_atoms[mask, :] tgt_atoms = tgt_atoms[mask, :]
src_center = src_atoms.mean(-2, keepdim=True) src_center = src_atoms.mean(-2, keepdim=True)
tgt_center = tgt_atoms.mean(-2, keepdim=True) tgt_center = tgt_atoms.mean(-2, keepdim=True)
r = kabsch_rotation(src_atoms - src_center, tgt_atoms - tgt_center) r = kabsch_rotation(src_atoms,tgt_atoms)
tgt_center,src_center = tgt_center.to('cpu'),src_center.to('cpu') # load to cpu memory just in case tgt_center,src_center = tgt_center.to('cpu'),src_center.to('cpu') # load to cpu memory just in case
x = tgt_center - src_center @ r x = tgt_center - src_center @ r
return r, x return r, x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment