Commit e3e8a681 authored by Geoffrey Yu

Finished working on selecting the best anchors; now starting on get_optimal_transform

parent 1008f61d
@@ -23,16 +23,14 @@ from openfold.data import data_transforms
 from openfold.model.model import AlphaFold
 from openfold.utils.tensor_utils import tensor_tree_map
 from tests.config import consts
+from .unifold_permutation import multi_chain_perm_align
+import logging
+logger = logging.getLogger(__name__)
+import os
 from tests.data_utils import (
     random_template_feats,
     random_extra_msa_feats,
 )
-from tests.data_utils import load_labels
-from openfold.data.data_transforms import make_msa_feat
-import logging
-logger = logging.getLogger(__name__)
-import os
 class TestPermutation(unittest.TestCase):
     def setUp(self):
         """
@@ -41,50 +39,67 @@ class TestPermutation(unittest.TestCase):
         In the test case, use PDB ID 1e4k as the label
         """
-        self.multimer_feature_path = os.path.join(os.getcwd(), "tests/test_data/example_multimer_processed_feature.pkl")
-        self.label_dir = os.path.join(os.getcwd(), "tests/test_data")
+        self.test_data_dir = os.path.join(os.getcwd(), "tests/test_data")
+        self.label_ids = ['label_1', 'label_2']
     def test_dry_run(self):
+        n_seq = consts.n_seq
+        n_templ = consts.n_templ
+        n_res = consts.n_res
+        n_extra_seq = consts.n_extra
         c = model_config(consts.model, train=True)
         c.model.evoformer_stack.no_blocks = 4  # no need to go overboard here
         c.model.evoformer_stack.blocks_per_ckpt = None  # don't want to set up
                                                         # deepspeed for this test
         model = AlphaFold(c)
-        label_ids = ["1e4k_A", "1e4k_B", "1e4k_C"]
-        sequence_ids = ["P01857", "P01857", "O75015"]
-        features = pickle.load(open(self.multimer_feature_path, "rb"))
-        # if consts.is_multimer:
-        # I suppose between_segment_residues are always 0 ?
-        # #
-        num_res = features['aatype'].shape[0]
-        protein = {'between_segment_residues': torch.tensor([0]*num_res, dtype=torch.int32),
-                   'msa': torch.tensor(features['msa'], dtype=torch.int64),
-                   'deletion_matrix': torch.tensor(features['deletion_matrix']),
-                   'aatype': torch.tensor(features['aatype'], dtype=torch.int64)}
-        protein = make_msa_feat.__wrapped__(protein)
-        print(f"protein now is {type(protein)}")
-        for k, v in protein.items():
-            print(f"{k},{v.size()}")
-        # if consts.is_multimer:
-        # #
-        # # Modify asym_id, entity_id and sym_id so that it encodes
-        # # 2 chains
-        # # #
-        # asym_id = [1]*11 + [2]*11
-        # batch["asym_id"] = torch.tensor(asym_id, dtype=torch.float64)
-        # batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
-        # batch["sym_id"] = torch.tensor(asym_id, dtype=torch.float64)
-        # batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
-        # add_recycling_dims = lambda t: (
-        #     t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
-        # )
-        # print(f"max_recycling_iters is {c.data.common.max_recycling_iters}")
-        # batch = tensor_tree_map(add_recycling_dims, batch)
-        # with torch.no_grad():
-        #     out = model(batch)
-        # print("finished running multimer forward")
-        # print(f"out is {type(out)} and has keys {out.keys()}")
-        # print(f"final_atom_positions is {out['final_atom_positions'].shape}")
\ No newline at end of file
+        example_label = [pickle.load(open(os.path.join(self.test_data_dir, f"{i}.pkl"), 'rb'))
+                         for i in self.label_ids]
+        batch = {}
+        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
+        batch["target_feat"] = nn.functional.one_hot(
+            tf, c.model.input_embedder.tf_dim
+        ).float()
+        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
+        print(f"target_feat shape is {batch['target_feat'].size()}")
+        print(f"batch_dim is {batch['target_feat'].shape[:-2]}")
+        batch["residue_index"] = torch.arange(n_res)
+        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
+        t_feats = random_template_feats(n_templ, n_res)
+        batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
+        extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
+        batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
+        batch["msa_mask"] = torch.randint(
+            low=0, high=2, size=(n_seq, n_res)
+        ).float()
+        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
+        batch.update(data_transforms.make_atom14_masks(batch))
+        batch["no_recycling_iters"] = torch.tensor(2.)
+        #
+        # Modify asym_id, entity_id and sym_id so that it encodes
+        # 2 chains
+        #
+        asym_id = [1]*9 + [2]*13
+        batch["asym_id"] = torch.tensor(asym_id, dtype=torch.float64)
+        # batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
+        batch['entity_id'] = torch.tensor(asym_id, dtype=torch.float64)
+        batch["sym_id"] = torch.tensor(asym_id, dtype=torch.float64)
+        batch["num_sym"] = torch.tensor([2]*22, dtype=torch.int64)  # currently there are just 2 chains
+        batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
+        add_recycling_dims = lambda t: (
+            t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
+        )
+        print(f"max_recycling_iters is {c.data.common.max_recycling_iters}")
+        batch = tensor_tree_map(add_recycling_dims, batch)
+        with torch.no_grad():
+            out = model(batch)
+        print("finished running multimer forward")
+        print(f"out is {type(out)} and has keys {out.keys()}")
+        print(f"final_atom_positions is {out['final_atom_positions'].shape}")
+        print(f"out iptm score is {out['iptm_score']}")
+        multi_chain_perm_align(out, batch, example_label)
\ No newline at end of file
tests/unifold_permutation.py 0 → 100644
import torch
from openfold.np import residue_constants as rc
import logging
logger = logging.getLogger(__name__)
import sys
def kabsch_rotation(P, Q):
"""
    Using the Kabsch algorithm with two sets of paired points P and Q, centered
    around the centroid. Each vector set is represented as an NxD
    matrix, where D is the dimension of the space.
The algorithm works in three steps:
- a centroid translation of P and Q (assumed done before this function
call)
- the computation of a covariance matrix C
- computation of the optimal rotation matrix U
For more info see http://en.wikipedia.org/wiki/Kabsch_algorithm
Parameters
----------
P : array
(N,D) matrix, where N is points and D is dimension.
Q : array
(N,D) matrix, where N is points and D is dimension.
Returns
-------
U : matrix
Rotation matrix (D,D)
"""
# Computation of the covariance matrix
C = P.transpose(-1, -2) @ Q
# Computation of the optimal rotation matrix
# This can be done using singular value decomposition (SVD)
    # Getting the sign of det(V)*det(W) to decide
# whether we need to correct our rotation matrix to ensure a
# right-handed coordinate system.
# And finally calculating the optimal rotation matrix U
# see http://en.wikipedia.org/wiki/Kabsch_algorithm
V, _, W = torch.linalg.svd(C)
d = (torch.linalg.det(V) * torch.linalg.det(W)) < 0.0
if d:
V[:, -1] = -V[:, -1]
# Create Rotation matrix U
U = V @ W
return U
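# A minimal sanity-check sketch, assuming centered inputs as the docstring
# requires: aligning a point set to itself should recover the identity.
#
#   P = torch.randn(10, 3)
#   P = P - P.mean(dim=0, keepdim=True)  # center around the centroid
#   U = kabsch_rotation(P, P)
#   assert torch.allclose(U, torch.eye(3), atol=1e-5)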
def get_optimal_transform(
src_atoms: torch.Tensor,
tgt_atoms: torch.Tensor,
mask: torch.Tensor = None,
):
assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
assert src_atoms.shape[-1] == 3
if mask is not None:
assert mask.dtype == torch.bool
assert mask.shape[-1] == src_atoms.shape[-2]
if mask.sum() == 0:
src_atoms = torch.zeros((1, 3), device=src_atoms.device).float()
tgt_atoms = src_atoms
else:
src_atoms = src_atoms[mask, :]
tgt_atoms = tgt_atoms[mask, :]
src_center = src_atoms.mean(-2, keepdim=True)
tgt_center = tgt_atoms.mean(-2, keepdim=True)
r = kabsch_rotation(src_atoms - src_center, tgt_atoms - tgt_center)
x = tgt_center - src_center @ r
return r, x
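# A minimal usage sketch: recover a known rigid transform. The 90-degree
# rotation about z below is chosen purely for illustration.
#
#   rot = torch.tensor([[0., -1., 0.],
#                       [1., 0., 0.],
#                       [0., 0., 1.]])  # det(rot) == +1
#   src = torch.randn(20, 3)
#   tgt = src @ rot + torch.tensor([1., 2., 3.])
#   r, x = get_optimal_transform(src, tgt)
#   assert torch.allclose(src @ r + x, tgt, atol=1e-4)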
def compute_rmsd(
true_atom_pos: torch.Tensor,
pred_atom_pos: torch.Tensor,
atom_mask: torch.Tensor = None,
eps: float = 1e-6,
) -> torch.Tensor:
    # squared deviation per atom, summed over the (x, y, z) coordinate dimension
    sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
if atom_mask is not None:
sq_diff = sq_diff[atom_mask]
msd = torch.mean(sq_diff)
msd = torch.nan_to_num(msd, nan=1e8)
return torch.sqrt(msd + eps)
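# Sketch: identical coordinates give an RMSD of about sqrt(eps), and atom_mask
# drops the masked atoms from the calculation entirely.
#
#   a = torch.randn(5, 3)
#   b = a.clone()
#   b[0] += 100.0                                   # perturb one atom
#   mask = torch.tensor([False, True, True, True, True])
#   assert compute_rmsd(a, b, mask) < 1e-2          # the outlier is masked away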
def kabsch_rmsd(
true_atom_pos: torch.Tensor,
pred_atom_pos: torch.Tensor,
atom_mask: torch.Tensor,
):
r, x = get_optimal_transform(
true_atom_pos,
pred_atom_pos,
atom_mask,
)
aligned_true_atom_pos = true_atom_pos @ r + x
return compute_rmsd(aligned_true_atom_pos, pred_atom_pos, atom_mask)
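# Sketch: when pred differs from true only by a rigid transform, kabsch_rmsd
# aligns the two first, so the result is near zero.
#
#   rot = torch.tensor([[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]])
#   true = torch.randn(15, 3)
#   pred = true @ rot + torch.tensor([1., 2., 3.])
#   mask = torch.ones(15, dtype=torch.bool)
#   assert kabsch_rmsd(true, pred, mask) < 1e-2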
def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :].float() # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].float() # [bsz, nres]
true_ca_poses = [
l["all_atom_positions"][..., ca_idx, :].float() for l in labels
] # list([nres, 3])
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].float() for l in labels
] # list([nres,])
unique_asym_ids = torch.unique(batch["asym_id"])
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
        asym_mask = asym_mask[:, 0]  # drop the recycling dimension so the mask is [n_res]
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
anchor_gt_asym, anchor_pred_asym = get_anchor_candidates(
batch, per_asym_residue_index, true_ca_masks
)
print(f"anchor_gt_asym is {anchor_gt_asym}, anchor_pred_asym is {anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1
best_rmsd = 1e9
best_labels = None
unique_entity_ids = torch.unique(batch["entity_id"])
entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
for cur_asym_id in anchor_pred_asym:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)]
print(f"anchor_residue_idx:{anchor_residue_idx},anchor_gt_idx:{anchor_gt_idx}\n")
anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
        asym_mask = asym_mask[:, 0]  # drop the recycling dimension so the mask is [n_res]
anchor_pred_pos = pred_ca_pos[asym_mask]
print(f"true_ca_masks:\n")
print(true_ca_masks[anchor_gt_idx].bool())
print(f"pred_ca_mask\n")
print(pred_ca_mask.bool())
anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx[:,0]]
anchor_pred_mask = pred_ca_mask[asym_mask]
print(f"anchor_true_mask:\n")
print(anchor_true_mask.shape)
print(f"anchor_pred_mask:\n")
print(anchor_pred_mask.shape)
r, x = get_optimal_transform(
anchor_true_pos,
anchor_pred_pos,
(anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
)
aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms
for _ in range(shuffle_times):
shuffle_idx = torch.randperm(
unique_asym_ids.shape[0], device=unique_asym_ids.device
)
shuffled_asym_ids = unique_asym_ids[shuffle_idx]
align = greedy_align(
batch,
per_asym_residue_index,
shuffled_asym_ids,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
)
merged_labels = merge_labels(
batch,
per_asym_residue_index,
labels,
align,
)
rmsd = kabsch_rmsd(
merged_labels["all_atom_positions"][..., ca_idx, :] @ r + x,
pred_ca_pos,
(pred_ca_mask * merged_labels["all_atom_mask"][..., ca_idx]).bool(),
)
if rmsd < best_rmsd:
best_rmsd = rmsd
best_labels = merged_labels
return best_labels
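# Overall flow of multi_chain_perm_align, as implemented above:
#   1. get_anchor_candidates picks a ground-truth anchor chain (fewest symmetry
#      copies, most resolved CA atoms) and the predicted chains sharing its
#      entity_id;
#   2. for each candidate predicted anchor, the rigid transform mapping the true
#      anchor onto it is computed and applied to all true CA positions;
#   3. the chain order is shuffled `shuffle_times` times and greedy_align
#      assigns every predicted chain to its best-RMSD unused label;
#   4. merge_labels stitches the per-chain labels together in that order, and
#      the permutation with the lowest overall Kabsch RMSD wins.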
def get_anchor_candidates(batch, per_asym_residue_index, true_masks):
def find_by_num_sym(min_num_sym):
best_len = -1
best_gt_asym = None
asym_ids = torch.unique(batch["asym_id"][batch["num_sym"] == min_num_sym])
for cur_asym_id in asym_ids:
assert cur_asym_id > 0
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
            j = int(cur_asym_id - 2)  # NOTE: the label index offset had to change from -1 to -2 here
cur_true_mask = true_masks[j][cur_residue_index]
cur_len = cur_true_mask.sum()
if cur_len > best_len:
best_len = cur_len
best_gt_asym = cur_asym_id
print(f"finished selected the best anchor\n best_gt_asym is {best_gt_asym} and best_len is {best_len}")
return best_gt_asym, best_len
sorted_num_sym = batch["num_sym"][batch["num_sym"] > 0].sort()[0]
best_gt_asym = None
best_len = -1
for cur_num_sym in sorted_num_sym:
if cur_num_sym <= 0:
continue
cur_gt_sym, cur_len = find_by_num_sym(cur_num_sym)
if cur_len > best_len:
best_len = cur_len
best_gt_asym = cur_gt_sym
if best_len >= 3:
break
best_entity = batch["entity_id"][batch["asym_id"] == best_gt_asym][0]
print(f"best_entity is {best_entity}\n")
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == best_entity])
return best_gt_asym, best_pred_asym
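# For example, with num_sym = [2]*22 as in the dry-run test, every chain has
# two symmetry copies, so find_by_num_sym(2) compares the two ground-truth
# chains and the anchor becomes the one whose CA mask covers more residues.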
def greedy_align(
batch,
per_asym_residue_index,
unique_asym_ids,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
true_ca_poses,
true_ca_masks,
):
used = [False for _ in range(len(true_ca_poses))]
align = []
for cur_asym_id in unique_asym_ids:
# skip padding
if cur_asym_id == 0:
continue
i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id
num_sym = batch["num_sym"][asym_mask][0]
# don't need to align
if (num_sym) == 1:
align.append((i, i))
assert used[i] == False
used[i] = True
continue
cur_entity_ids = batch["entity_id"][asym_mask][0]
best_rmsd = 1e10
best_idx = None
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask]
cur_pred_mask = pred_ca_mask[asym_mask]
for next_asym_id in cur_asym_list:
if next_asym_id == 0:
continue
j = int(next_asym_id - 1)
            if not used[j]:  # possible candidate
cropped_pos = true_ca_poses[j][cur_residue_index]
mask = true_ca_masks[j][cur_residue_index]
rmsd = compute_rmsd(
cropped_pos, cur_pred_pos, (cur_pred_mask * mask).bool()
)
if rmsd < best_rmsd:
best_rmsd = rmsd
best_idx = j
assert best_idx is not None
used[best_idx] = True
align.append((i, best_idx))
return align
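# Toy example of the output format: for a homodimer with asym_ids {1, 2},
# greedy_align returns pairs of (predicted chain index, label chain index),
# e.g. [(0, 0), (1, 1)] or [(0, 1), (1, 0)], whichever pairing gives the
# lower per-chain RMSD.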
def merge_labels(batch, per_asym_residue_index, labels, align):
"""
batch:
labels: list of label dicts, each with shape [nk, *]
align: list of int, such as [2, None, 0, 1], each entry specify the corresponding label of the asym.
"""
num_res = batch["msa_mask"].shape[-1]
outs = {}
for k, v in labels[0].items():
if k in [
"resolution",
]:
continue
cur_out = {}
for i, j in align:
label = labels[j][k]
# to 1-based
cur_residue_index = per_asym_residue_index[i + 1]
cur_out[i] = label[cur_residue_index]
cur_out = [x[1] for x in sorted(cur_out.items())]
new_v = torch.concat(cur_out, dim=0)
merged_nres = new_v.shape[0]
assert (
merged_nres <= num_res
), f"bad merged num res: {merged_nres} > {num_res}. something is wrong."
if merged_nres < num_res: # must pad
pad_dim = new_v.shape[1:]
pad_v = new_v.new_zeros((num_res - merged_nres, *pad_dim))
new_v = torch.concat((new_v, pad_v), dim=0)
outs[k] = new_v
return outs
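# Sketch: for a two-chain complex with align = [(0, 1), (1, 0)], merge_labels
# builds each output tensor by taking chain 0's residues from labels[1] and
# chain 1's residues from labels[0], then zero-pads back to num_res if the
# merged labels are shorter.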