Commit dff973ab authored by Geoffrey Yu
Browse files

Update get_least_asym_entity_or_longest_length and add split_ground_truth_labels

parent ab09ded4
......@@ -1770,8 +1770,8 @@ def get_optimal_transform(
else:
src_atoms = src_atoms[mask, :]
tgt_atoms = tgt_atoms[mask, :]
src_center = src_atoms.mean(-2, keepdim=True)
tgt_center = tgt_atoms.mean(-2, keepdim=True)
src_center = src_atoms.mean(-2, keepdim=True,dtype=src_atoms.dtype)
tgt_center = tgt_atoms.mean(-2, keepdim=True,dtype=src_atoms.dtype)
r = kabsch_rotation(src_atoms,tgt_atoms)
del src_atoms,tgt_atoms,
gc.collect()
......@@ -1792,6 +1792,12 @@ def get_least_asym_entity_or_longest_length(batch):
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor
"""
REQUIRED_FEATURES = ['entity_id','asym_id']
seq_length = batch['seq_length'].item()
# remove padding part before selecting candidate
remove_padding = lambda t: torch.index_select(t,dim=1,index=torch.arange(seq_length,device=t.device))
batch = {k:tensor_tree_map(remove_padding,batch[k]) for k in REQUIRED_FEATURES}
unique_entity_ids = torch.unique(batch["entity_id"])
entity_asym_count = {}
entity_length = {}
......@@ -1842,6 +1848,8 @@ def greedy_align(
used = [False for _ in range(len(true_ca_poses))]
align = []
for cur_asym_id in unique_asym_ids:
if cur_asym_id==0:
continue
i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id
cur_entity_ids = batch["entity_id"][asym_mask][0]
......@@ -2021,7 +2029,26 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
self.config = config
@staticmethod
def multi_chain_perm_align(out, batch, labels, permutate_chains=True):
def split_ground_truth_labels(batch,REQUIRED_FEATURES):
    """
    Splits ground truth features according to chains

    Args:
        batch: dictionary of tensors. Must contain "asym_id"; every entry
            whose key is listed in REQUIRED_FEATURES is split along dim=1.
        REQUIRED_FEATURES: collection of keys of ``batch`` to keep and split.

    Returns:
        A list of feature dictionaries with only necessary ground truth
        features required to finish multi-chain permutation. Entries are
        ordered by ascending asym_id; if asym_id 0 is present, its chunk
        is placed last. The list is empty when none of REQUIRED_FEATURES
        is present in ``batch``.

    Note:
        Counts are taken over the whole "asym_id" tensor while the split is
        done along dim=1, so the sizes only add up when all other dimensions
        of "asym_id" have size 1.
        NOTE(review): presumably residues are laid out chain by chain in
        ascending asym_id with padding (asym_id 0) at the end -- confirm
        against the data pipeline.
    """
    # sorted=True is required: the counts are used below as *consecutive*
    # split sizes, so they must follow a well-defined (ascending id) order.
    # With sorted=False the order of the returned ids/counts is unspecified.
    unique_asym_ids, asym_id_counts = torch.unique(
        batch["asym_id"], sorted=True, return_counts=True
    )
    unique_asym_ids = unique_asym_ids.tolist()
    asym_id_counts = asym_id_counts.tolist()

    if 0 in unique_asym_ids:
        # asym_id 0 is treated as padding: move its count to the end so the
        # real chains come first in the split.
        pop_idx = unique_asym_ids.index(0)
        asym_id_counts.append(asym_id_counts.pop(pop_idx))

    # One tuple of per-chain chunks for every required feature.
    per_feature_chunks = {
        key: torch.split(value, asym_id_counts, dim=1)
        for key, value in batch.items()
        if key in REQUIRED_FEATURES
    }
    # Regroup "feature -> chunks" into one dictionary per chain.
    labels = [
        dict(zip(per_feature_chunks, chunk_group))
        for chunk_group in zip(*per_feature_chunks.values())
    ]
    return labels
@staticmethod
def multi_chain_perm_align(out, batch, permutate_chains=True):
"""
A class method that first permutate chains in ground truth first
before calculating the loss.
......@@ -2029,6 +2056,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
......@@ -2049,6 +2078,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
if permutate_chains:
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1
unique_entity_ids = torch.unique(batch["entity_id"])
......@@ -2074,7 +2104,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del anchor_pred_mask
del anchor_true_mask
gc.collect()
aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
del true_ca_poses
gc.collect()
align = greedy_align(
......@@ -2114,7 +2144,16 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
# permutated_labels = merge_labels(per_asym_residue_index, labels, align)
# permutated_labels.pop('aatype')
# features.update(permutated_labels)
print(f"########## line 2154 loss.py features is {type(features)}")
for k,v in features.items():
print(f"{k}:{v.shape}")
# permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,
permutate_chains=permutate_chains)
import sys
sys.exit()
# permutated_labels = merge_labels(per_asym_residue_index, labels, align)
if (not _return_breakdown):
cum_loss = self.loss(out, features, _return_breakdown)
print(f"cum_loss: {cum_loss}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment