"docs/guides/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "e0cd84890b47d7fc65cdf5b807f5df78c0009cb1"
Unverified Commit 61cbc7d3 authored by Christina Floristean's avatar Christina Floristean Committed by GitHub
Browse files

Merge pull request #343 from dingquanyu/permutation

Update multi-chain permutation
parents 8e522aca e097da95
......@@ -1770,8 +1770,8 @@ def get_optimal_transform(
else:
src_atoms = src_atoms[mask, :]
tgt_atoms = tgt_atoms[mask, :]
src_center = src_atoms.mean(-2, keepdim=True)
tgt_center = tgt_atoms.mean(-2, keepdim=True)
src_center = src_atoms.mean(-2, keepdim=True,dtype=src_atoms.dtype)
tgt_center = tgt_atoms.mean(-2, keepdim=True,dtype=src_atoms.dtype)
r = kabsch_rotation(src_atoms,tgt_atoms)
del src_atoms,tgt_atoms,
gc.collect()
......@@ -1792,6 +1792,12 @@ def get_least_asym_entity_or_longest_length(batch):
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor
"""
REQUIRED_FEATURES = ['entity_id','asym_id']
seq_length = batch['seq_length'].item()
# remove padding part before selecting candidate
remove_padding = lambda t: torch.index_select(t,dim=1,index=torch.arange(seq_length,device=t.device))
batch = {k:tensor_tree_map(remove_padding,batch[k]) for k in REQUIRED_FEATURES}
unique_entity_ids = torch.unique(batch["entity_id"])
entity_asym_count = {}
entity_length = {}
......@@ -1870,7 +1876,15 @@ def greedy_align(
return align
def merge_labels(per_asym_residue_index, labels, align):
def pad_features(feature_tensor,nres_pad,pad_dim):
    """Append `nres_pad` zero-filled entries to `feature_tensor` along `pad_dim`.

    The padding block inherits the input tensor's dtype and device (via
    `new_zeros`), so the result can be concatenated and compared directly
    with other tensors derived from the same feature.

    Args:
        feature_tensor: tensor to pad.
        nres_pad: number of zero entries to append along `pad_dim`.
        pad_dim: axis along which to pad (the residue axis).

    Returns:
        A new tensor whose size along `pad_dim` is larger by `nres_pad`.
    """
    zeros_shape = list(feature_tensor.shape)
    zeros_shape[pad_dim] = nres_pad
    # new_zeros already matches dtype/device of feature_tensor.
    zero_block = feature_tensor.new_zeros(zeros_shape)
    return torch.cat((feature_tensor, zero_block), dim=pad_dim)
def merge_labels(per_asym_residue_index, labels, align,original_nres):
"""
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
labels: list of original ground truth feats
......@@ -1897,6 +1911,10 @@ def merge_labels(per_asym_residue_index, labels, align):
cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0:
new_v = torch.concat(cur_out, dim=dimension_to_merge)
# below check whether padding is needed
if new_v.shape[dimension_to_merge]!=original_nres:
nres_pad = original_nres - new_v.shape[dimension_to_merge]
new_v = pad_features(new_v,nres_pad,pad_dim=dimension_to_merge)
outs[k] = new_v
return outs
......@@ -2019,9 +2037,35 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
def __init__(self, config):
super(AlphaFoldMultimerLoss, self).__init__(config)
self.config = config
@staticmethod
def determine_split_dim(batch)->dict:
    """Determine which axis of each feature to split in split_ground_truth_labels.

    Uses `batch['aatype'].shape[-1]` as the padded sequence length and, for
    every feature whose shape contains that length, records the first axis
    of that size (the residue axis). Features without such an axis are
    omitted from the result.

    NOTE(review): if a tensor happens to have more than one axis of exactly
    the padded length, the first match wins — assumed unambiguous in practice.

    Returns:
        dict mapping feature name -> axis index to split along.
    """
    padded_len = batch['aatype'].shape[-1]
    dims = {}
    for name, tensor in batch.items():
        shape = list(tensor.shape)
        if padded_len in shape:
            dims[name] = shape.index(padded_len)
    return dims
@staticmethod
def multi_chain_perm_align(out, batch, labels, permutate_chains=True):
def split_ground_truth_labels(batch,REQUIRED_FEATURES,dim_dict):
    """
    Splits ground truth features according to chains.

    Args:
        batch: feature dict; must contain "asym_id" plus the tensors named
            in REQUIRED_FEATURES.
        REQUIRED_FEATURES: names of the features to split per chain.
        dim_dict: maps each feature name to the axis holding the residue
            dimension (the axis to split along).

    Returns a list of feature dictionaries, one per chain, with only the
    necessary ground truth features required to finish multi-chain
    permutation.
    """
    # NOTE(review): sorted=False does not guarantee order of first
    # appearance (torch.unique may still return sorted ids); this relies on
    # asym_ids appearing in ascending order along the sequence — confirm.
    unique_asym_ids, asym_id_counts = torch.unique(batch["asym_id"], sorted=False,return_counts=True)
    unique_asym_ids, asym_id_counts= unique_asym_ids.tolist(),asym_id_counts.tolist()
    if 0 in unique_asym_ids:
        # asym_id 0 marks padding; move its id/count to the end so the
        # padding region (appended after the real chains) becomes the
        # last split chunk instead of the first.
        pop_idx = unique_asym_ids.index(0)
        padding_asym_id = unique_asym_ids.pop(pop_idx)
        padding_asym_counts = asym_id_counts.pop(pop_idx)
        unique_asym_ids.append(padding_asym_id)
        asym_id_counts.append(padding_asym_counts)
    # Split each required feature into per-chain chunks along its residue
    # axis, then regroup chunk-wise (via zip) into one dict per chain.
    labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=dim_dict[k])] for k, value in batch.items() if k in REQUIRED_FEATURES])))
    return labels
@staticmethod
def multi_chain_perm_align(out, batch, dim_dict,permutate_chains=True):
"""
A class method that first permutate chains in ground truth first
before calculating the loss.
......@@ -2029,6 +2073,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
......@@ -2041,7 +2087,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,])
unique_asym_ids = torch.unique(batch["asym_id"])
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
......@@ -2049,6 +2095,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
if permutate_chains:
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1
unique_entity_ids = torch.unique(batch["entity_id"])
......@@ -2074,7 +2121,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del anchor_pred_mask
del anchor_true_mask
gc.collect()
aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
del true_ca_poses
gc.collect()
align = greedy_align(
......@@ -2099,7 +2146,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return align, per_asym_residue_index
def forward(self, out, features, _return_breakdown=False, permutate_chains=True):
def forward(self, out, features, _return_breakdown=False,permutate_chains=True):
"""
Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation
......@@ -2108,12 +2155,23 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure
"""
# permutate ground truth chains before calculating the loss
# align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features, labels,
# permutate_chains=permutate_chains)
# permutated_labels = merge_labels(per_asym_residue_index, labels, align)
# permutated_labels.pop('aatype')
# features.update(permutated_labels)
# first check if it is a monomer
is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
if not is_monomer:
permutate_chains = True
# first determin which dimension in the tensor to split into individual ground truth labels
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
# Then permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
permutate_chains=permutate_chains)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
if (not _return_breakdown):
cum_loss = self.loss(out, features, _return_breakdown)
......@@ -2122,4 +2180,4 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
else:
cum_loss, losses = self.loss(out, features, _return_breakdown)
print(f"cum_loss: {cum_loss} losses: {losses}")
return cum_loss, losses
return cum_loss, losses
\ No newline at end of file
# Copyright 2021 AlQuraishi Laboratory
# Dingquan Yu @ EMBL-Hamburg Kosinski group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import unittest
from openfold.utils.loss import AlphaFoldMultimerLoss
from openfold.utils.loss import get_least_asym_entity_or_longest_length,merge_labels,pad_features
from openfold.utils.tensor_utils import tensor_tree_map
import math
class TestPermutation(unittest.TestCase):
    """Unit tests for the multi-chain permutation alignment utilities,
    using a fake hetero-pentamer A2B3 complex on CUDA."""

    def setUp(self):
        """
        create fake input structure features
        and rotation matrices
        """
        theta = math.pi/4
        # 45-degree rotation about the z axis.
        self.rotation_matrix_z = torch.tensor([
            [math.cos(theta),-math.sin(theta),0],
            [math.sin(theta),math.cos(theta),0],
            [0,0,1]
        ],device='cuda')
        # 45-degree rotation about the x axis.
        self.rotation_matrix_x = torch.tensor([
            [1,0,0],
            [0,math.cos(theta),-math.sin(theta)],
            [0,math.sin(theta),math.cos(theta)],
        ],device='cuda')
        # 45-degree rotation about the y axis.
        # FIX: the last row of a y-rotation is [-sin, 0, cos]; the original
        # had [-sin, 1, cos], which is not an orthogonal (rigid) rotation,
        # so chain B2 was not a rigid transform of chain B1 as intended.
        self.rotation_matrix_y = torch.tensor([
            [math.cos(theta),0,math.sin(theta)],
            [0,1,0],
            [-math.sin(theta),0,math.cos(theta)],
        ],device='cuda')
        self.chain_a_num_res=9
        self.chain_b_num_res=13
        # below create default fake ground truth structures for a hetero-pentamer A2B3
        self.residue_index=list(range(self.chain_a_num_res))*2 + list(range(self.chain_b_num_res))*3
        self.num_res = self.chain_a_num_res*2 + self.chain_b_num_res*3
        self.asym_id = torch.tensor([[1]*self.chain_a_num_res+[2]*self.chain_a_num_res+[3]*self.chain_b_num_res+[4]*self.chain_b_num_res+[5]*self.chain_b_num_res],device='cuda')
        self.sym_id = self.asym_id
        self.entity_id = torch.tensor([[1]*(self.chain_a_num_res*2)+[2]*(self.chain_b_num_res*3)],device='cuda')

    def test_1_selecting_anchors(self):
        """The anchor chain must come from entity 1 (the A chains, the
        least-frequent entity), never from the B chains."""
        self.batch = {
            'asym_id':self.asym_id,
            'sym_id':self.sym_id,
            'entity_id':self.entity_id,
            'seq_length':torch.tensor([57])
        }
        anchor_gt_asym, anchor_pred_asym=get_least_asym_entity_or_longest_length(self.batch)
        self.assertIn(int(anchor_gt_asym),[1,2])
        self.assertNotIn(int(anchor_gt_asym),[3,4,5])
        self.assertIn(int(anchor_pred_asym),[1,2])
        self.assertNotIn(int(anchor_pred_asym),[3,4,5])

    def test_2_permutation_pentamer(self):
        """Predicted chains are a permutation of ground-truth chains
        (A chains swapped, B chains cycled); the alignment must recover
        a permutation consistent with that, up to symmetric equivalents."""
        batch = {
            'asym_id':self.asym_id,
            'sym_id':self.sym_id,
            'entity_id':self.entity_id,
            'seq_length':torch.tensor([57]),
            'aatype':torch.randint(21,size=(1,57))
        }
        batch['asym_id'] = batch['asym_id'].reshape(1,self.num_res)
        batch["residue_index"] = torch.tensor([self.residue_index],device='cuda')
        # create fake ground truth atom positions
        chain_a1_pos = torch.randint(15,(self.chain_a_num_res,3*37),
                                     device='cuda',dtype=torch.float).reshape(1,self.chain_a_num_res,37,3)
        chain_a2_pos = torch.matmul(chain_a1_pos,self.rotation_matrix_x)+10
        chain_b1_pos = torch.randint(low=15,high=30,size=(self.chain_b_num_res,3*37),
                                     device='cuda',dtype=torch.float).reshape(1,self.chain_b_num_res,37,3)
        chain_b2_pos = torch.matmul(chain_b1_pos,self.rotation_matrix_y)+10
        chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos,self.rotation_matrix_z),self.rotation_matrix_x)+30
        # Below permutate predicted chain positions
        pred_atom_position = torch.cat((chain_a2_pos,chain_a1_pos,chain_b2_pos,chain_b3_pos,chain_b1_pos),dim=1)
        pred_atom_mask = torch.ones((1,self.num_res,37),device='cuda')
        out = {
            'final_atom_positions':pred_atom_position,
            'final_atom_mask':pred_atom_mask
        }
        true_atom_position = torch.cat((chain_a1_pos,chain_a2_pos,chain_b1_pos,chain_b2_pos,chain_b3_pos),dim=1)
        true_atom_mask = torch.cat((torch.ones((1,self.chain_a_num_res,37),device='cuda'),
                                    torch.ones((1,self.chain_a_num_res,37),device='cuda'),
                                    torch.ones((1,self.chain_b_num_res,37),device='cuda'),
                                    torch.ones((1,self.chain_b_num_res,37),device='cuda'),
                                    torch.ones((1,self.chain_b_num_res,37),device='cuda')),dim=1)
        batch['all_atom_positions'] = true_atom_position
        batch['all_atom_mask'] = true_atom_mask
        dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
        aligns,_ = AlphaFoldMultimerLoss.multi_chain_perm_align(out,batch,
                                                                dim_dict,
                                                                permutate_chains=True)
        # Either A assignment is acceptable (A1/A2 are identical entities);
        # the B cycle must be recovered as (2->3, 3->4, 4->2).
        possible_outcome = [[(0,1),(1,0),(2,3),(3,4),(4,2)],[(0,0),(1,1),(2,3),(3,4),(4,2)]]
        wrong_outcome = [[(0,1),(1,0),(2,4),(3,2),(4,3)],[(0,0),(1,1),(2,2),(3,3),(4,4)]]
        self.assertIn(aligns,possible_outcome)
        self.assertNotIn(aligns,wrong_outcome)

    def test_3_merge_labels(self):
        """With padded (cropped-to-325) features, merge_labels must return
        ground-truth labels reordered by the permutation and re-padded to
        the original padded length."""
        nres_pad = 325 - 57 # suppose the cropping size is 325
        batch = {
            'asym_id':pad_features(self.asym_id,nres_pad,pad_dim=1),
            'sym_id':pad_features(self.sym_id,nres_pad,pad_dim=1),
            'entity_id':pad_features(self.entity_id,nres_pad,pad_dim=1),
            'aatype':torch.randint(21,size=(1,325)),
            'seq_length':torch.tensor([57])
        }
        batch['asym_id'] = batch['asym_id'].reshape(1,325)
        batch["residue_index"] = pad_features(torch.tensor(self.residue_index).reshape(1,57),nres_pad,pad_dim=1)
        # create fake ground truth atom positions
        chain_a1_pos = torch.randint(15,(self.chain_a_num_res,3*37),
                                     device='cuda',dtype=torch.float).reshape(1,self.chain_a_num_res,37,3)
        chain_a2_pos = torch.matmul(chain_a1_pos,self.rotation_matrix_x)+10
        chain_b1_pos = torch.randint(low=15,high=30,size=(self.chain_b_num_res,3*37),
                                     device='cuda',dtype=torch.float).reshape(1,self.chain_b_num_res,37,3)
        chain_b2_pos = torch.matmul(chain_b1_pos,self.rotation_matrix_y)+10
        chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos,self.rotation_matrix_z),self.rotation_matrix_x)+30
        # Below permutate predicted chain positions
        pred_atom_position = torch.cat((chain_a2_pos,chain_a1_pos,chain_b2_pos,chain_b3_pos,chain_b1_pos),dim=1)
        pred_atom_mask = torch.ones((1,self.num_res,37),device='cuda')
        pred_atom_position = pad_features(pred_atom_position,nres_pad,pad_dim=1)
        pred_atom_mask = pad_features(pred_atom_mask,nres_pad,pad_dim=1)
        out = {
            'final_atom_positions':pred_atom_position,
            'final_atom_mask':pred_atom_mask
        }
        true_atom_position = torch.cat((chain_a1_pos,chain_a2_pos,chain_b1_pos,chain_b2_pos,chain_b3_pos),dim=1)
        true_atom_mask = torch.cat((torch.ones((1,self.chain_a_num_res,37),device='cuda'),
                                    torch.ones((1,self.chain_a_num_res,37),device='cuda'),
                                    torch.ones((1,self.chain_b_num_res,37),device='cuda'),
                                    torch.ones((1,self.chain_b_num_res,37),device='cuda'),
                                    torch.ones((1,self.chain_b_num_res,37),device='cuda')),dim=1)
        batch['all_atom_positions'] = pad_features(true_atom_position,nres_pad,pad_dim=1)
        batch['all_atom_mask'] = pad_features(true_atom_mask,nres_pad=nres_pad,pad_dim=1)
        tensor_to_cuda = lambda t: t.to('cuda')
        batch = tensor_tree_map(tensor_to_cuda,batch)
        dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
        aligns,per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
                                                                                    batch,
                                                                                    dim_dict,
                                                                                    permutate_chains=True)
        labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
                                                                 REQUIRED_FEATURES=[i for i in batch.keys() if i in dim_dict])
        labels = merge_labels(per_asym_residue_index,labels,aligns,
                              original_nres=batch['aatype'].shape[-1])
        # residue_index is identical within entities, so it must survive
        # permutation unchanged.
        self.assertTrue(torch.equal(labels['residue_index'],batch['residue_index']))
        expected_permutated_gt_pos = torch.cat((chain_a2_pos,chain_a1_pos,chain_b2_pos,chain_b3_pos,chain_b1_pos),dim=1)
        expected_permutated_gt_pos = pad_features(expected_permutated_gt_pos,nres_pad,pad_dim=1)
        self.assertTrue(torch.equal(labels['all_atom_positions'],expected_permutated_gt_pos))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment