Commit 7f2a3267 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update kabsch rotation calculation to avoid svd not converge error

parent 5348936a
...@@ -1710,36 +1710,30 @@ def kabsch_rotation(P, Q): ...@@ -1710,36 +1710,30 @@ def kabsch_rotation(P, Q):
""" """
assert P.shape == torch.Size([Q.shape[0],Q.shape[1]]) assert P.shape == torch.Size([Q.shape[0],Q.shape[1]])
finished_rotation = False rotation = procrustes.rotational(P.detach().cpu().numpy(),
while not finished_rotation: Q.detach().cpu().numpy(),translate=False,scale=False)
# rotation = torch.tensor(rotation.t,dtype=torch.float) # rotation.t doesn't mean transpose, t only means get the matrix out of the procruste object
# Add a try-except block cuz sometimes SVD fails to converge and crashes the programme
# Will continue trying SVD until the optimal rotaion is calculated
# #
try:
# first need to load P and Q to cpu otherwise cannot extract the numpy matrices
rotation = procrustes.rotational(P.to('cpu').numpy(),
Q.to('cpu').numpy(),translate=True)
finished_rotation = True
except:
print(f"svd failed.")
import sys
sys.exit()
rotation = torch.tensor(rotation.t,dtype=torch.float)
assert rotation.shape == torch.Size([3,3]) assert rotation.shape == torch.Size([3,3])
return rotation return rotation.to('cuda')
def get_optimal_transform( def get_optimal_transform(
src_atoms: torch.Tensor, src_atoms: torch.Tensor,
tgt_atoms: torch.Tensor, tgt_atoms: torch.Tensor,
mask: torch.Tensor = None, mask: torch.Tensor = None,
): ):
"""
src_atoms: predicted CA positions, shape:[num_res,3]
tgt_atoms: ground-truth CA positions, shape:[num_res,3]
mask: a vector of boolean values, shape:[num_res]
"""
assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape) assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
assert src_atoms.shape[-1] == 3 assert src_atoms.shape[-1] == 3
if torch.isnan(src_atoms).any(): assert len(mask.shape) ==1,"mask should have the shape of [num_res]"
if torch.isnan(src_atoms).any() or torch.isinf(src_atoms).any():
# #
# sometimes using fake test inputs generates NaN in the predicted atom positions # sometimes using fake test inputs generates NaN in the predicted atom positions
# # # #
logging.warning(f"src_atom has nan or inf")
src_atoms = torch.nan_to_num(src_atoms,nan=0.0,posinf=1.0,neginf=1.0) src_atoms = torch.nan_to_num(src_atoms,nan=0.0,posinf=1.0,neginf=1.0)
if mask is not None: if mask is not None:
...@@ -1749,8 +1743,8 @@ def get_optimal_transform( ...@@ -1749,8 +1743,8 @@ def get_optimal_transform(
src_atoms = torch.zeros((1, 3), device=src_atoms.device).float() src_atoms = torch.zeros((1, 3), device=src_atoms.device).float()
tgt_atoms = src_atoms tgt_atoms = src_atoms
else: else:
src_atoms = src_atoms.to('cuda:0')[mask, :] src_atoms = src_atoms[mask, :]
tgt_atoms = tgt_atoms.to('cuda:0')[mask, :] tgt_atoms = tgt_atoms[mask, :]
src_center = src_atoms.mean(-2, keepdim=True) src_center = src_atoms.mean(-2, keepdim=True)
tgt_center = tgt_atoms.mean(-2, keepdim=True) tgt_center = tgt_atoms.mean(-2, keepdim=True)
r = kabsch_rotation(src_atoms,tgt_atoms) r = kabsch_rotation(src_atoms,tgt_atoms)
...@@ -2069,8 +2063,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2069,8 +2063,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
# anchor_pred_mask = anchor_pred_mask.to('cuda') # anchor_pred_mask = anchor_pred_mask.to('cuda')
input_mask = (anchor_true_mask * anchor_pred_mask).bool() input_mask = (anchor_true_mask * anchor_pred_mask).bool()
r, x = get_optimal_transform( r, x = get_optimal_transform(
anchor_true_pos[0], anchor_pred_pos,anchor_true_pos[0],
anchor_pred_pos,mask=input_mask mask=input_mask[0]
) )
del input_mask # just to save memory del input_mask # just to save memory
del anchor_pred_mask del anchor_pred_mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment