Unverified Commit ec75fe22 authored by Christina Floristean's avatar Christina Floristean Committed by GitHub
Browse files

Merge pull request #359 from dingquanyu/update-kabsch_rotation

move the kabsch rotation step to gpu
parents 377f854c cdfb0c75
...@@ -37,7 +37,6 @@ from openfold.utils.tensor_utils import ( ...@@ -37,7 +37,6 @@ from openfold.utils.tensor_utils import (
import random import random
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
import logging import logging
import procrustes
from openfold.utils.tensor_utils import tensor_tree_map from openfold.utils.tensor_utils import tensor_tree_map
import gc import gc
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -1716,9 +1715,8 @@ def kabsch_rotation(P, Q): ...@@ -1716,9 +1715,8 @@ def kabsch_rotation(P, Q):
Use procrustes package to calculate best rotation that minimises Use procrustes package to calculate best rotation that minimises
the RMSD betwee P and Q the RMSD betwee P and Q
The optimal rotation matrix was calculated using The optimal rotation matrix was calculated using Kabsch algorithm:
the rotational() function from procrustes package. Details can be found here: https://en.wikipedia.org/wiki/Kabsch_algorithm
https://procrustes.qcdevs.org/api/rotational.html#rotational
Args: Args:
P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
...@@ -1727,15 +1725,18 @@ def kabsch_rotation(P, Q): ...@@ -1727,15 +1725,18 @@ def kabsch_rotation(P, Q):
return: return:
A 3*3 rotation matrix A 3*3 rotation matrix
""" """
assert P.shape == torch.Size([Q.shape[0],Q.shape[1]]) assert P.shape == torch.Size([Q.shape[0],Q.shape[1]])
rotation = procrustes.rotational(P.detach().cpu().float().numpy(),
Q.detach().cpu().float().numpy(),translate=False,scale=False)
# Rotation.t doesn't mean transpose, t only means get the matrix out of the procruste object
rotation = torch.tensor(rotation.t,dtype=torch.float)
assert rotation.shape == torch.Size([3,3])
return rotation.to(device=P.device, dtype=P.dtype)
# Firstly, compute SVD of P.T * Q
u,_,vt = torch.linalg.svd(torch.matmul(P.to(torch.float32).T,Q.to(torch.float32)),driver='gesvd')
# Then construct s matrix
s = torch.eye(P.shape[1],device=P.device)
# correct the rotation matrix to ensure a right-handed coordinate
s[-1, -1] = torch.sign(torch.linalg.det(torch.matmul(u, vt)))
# finally compute the rotation matrix
r_opt = torch.matmul(torch.matmul(u, s), vt)
assert r_opt.shape == torch.Size([3,3])
return r_opt.to(device=P.device, dtype=P.dtype)
def get_optimal_transform( def get_optimal_transform(
src_atoms: torch.Tensor, src_atoms: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment