Commit 14853379 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merge branch 'multimer' of https://github.com/aqlaboratory/openfold into multimer

parents 0cf1541c ec75fe22
...@@ -31,7 +31,6 @@ from openfold.utils.tensor_utils import ( ...@@ -31,7 +31,6 @@ from openfold.utils.tensor_utils import (
import random import random
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
import logging import logging
import procrustes
from openfold.utils.tensor_utils import tensor_tree_map from openfold.utils.tensor_utils import tensor_tree_map
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -1704,12 +1703,10 @@ def compute_rmsd( ...@@ -1704,12 +1703,10 @@ def compute_rmsd(
def kabsch_rotation(P, Q): def kabsch_rotation(P, Q):
""" """
Use procrustes package to calculate the best rotation that minimises Calculate the best rotation that minimises the RMSD between P and Q.
the RMSD between P and Q
The optimal rotation matrix was calculated using The optimal rotation matrix was calculated using Kabsch algorithm:
the rotational() function from procrustes package. Details can be found here: https://en.wikipedia.org/wiki/Kabsch_algorithm
https://procrustes.qcdevs.org/api/rotational.html#rotational
Args: Args:
P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
...@@ -1718,15 +1715,19 @@ def kabsch_rotation(P, Q): ...@@ -1718,15 +1715,19 @@ def kabsch_rotation(P, Q):
return: return:
A 3*3 rotation matrix A 3*3 rotation matrix
""" """
assert P.shape == torch.Size([Q.shape[0], Q.shape[1]]) assert P.shape == torch.Size([Q.shape[0], Q.shape[1]])
rotation = procrustes.rotational(P.detach().cpu().float().numpy(),
Q.detach().cpu().float().numpy(), translate=False, scale=False)
# Rotation.t doesn't mean transpose, t only means get the matrix out of the procruste object
rotation = torch.tensor(rotation.t, dtype=torch.float)
assert rotation.shape == torch.Size([3, 3])
return rotation.to(device=P.device, dtype=P.dtype)
# Firstly, compute SVD of P.T * Q
u, _, vt = torch.linalg.svd(torch.matmul(P.to(torch.float32).T,
Q.to(torch.float32)))
# Then construct s matrix
s = torch.eye(P.shape[1], device=P.device)
# correct the rotation matrix to ensure a right-handed coordinate
s[-1, -1] = torch.sign(torch.linalg.det(torch.matmul(u, vt)))
# finally compute the rotation matrix
r_opt = torch.matmul(torch.matmul(u, s), vt)
assert r_opt.shape == torch.Size([3,3])
return r_opt.to(device=P.device, dtype=P.dtype)
def get_optimal_transform( def get_optimal_transform(
src_atoms: torch.Tensor, src_atoms: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment