Commit 2a1028f0 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update AlphaFoldMultimerLoss to accommodate the new data_module loading procedure

parent dff973ab
......@@ -1848,8 +1848,6 @@ def greedy_align(
used = [False for _ in range(len(true_ca_poses))]
align = []
for cur_asym_id in unique_asym_ids:
if cur_asym_id==0:
continue
i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id
cur_entity_ids = batch["entity_id"][asym_mask][0]
......@@ -1878,7 +1876,15 @@ def greedy_align(
return align
def merge_labels(per_asym_residue_index, labels, align):
def pad_features(feature_tensor, nres_pad, pad_dim):
    """Zero-pad a feature tensor along a single dimension.

    Appends ``nres_pad`` zero entries to ``feature_tensor`` along axis
    ``pad_dim``. The padding inherits the input's dtype and device via
    ``Tensor.new_zeros``.

    Args:
        feature_tensor: tensor to be padded.
        nres_pad: number of zero slices to append along ``pad_dim``.
        pad_dim: axis along which the padding is appended.

    Returns:
        A new tensor whose size along ``pad_dim`` is larger by ``nres_pad``;
        all other dimensions are unchanged.
    """
    zeros_shape = list(feature_tensor.shape)
    zeros_shape[pad_dim] = nres_pad
    # new_zeros already matches the source tensor's dtype and device.
    zeros = feature_tensor.new_zeros(zeros_shape, device=feature_tensor.device)
    return torch.cat((feature_tensor, zeros), dim=pad_dim)
def merge_labels(per_asym_residue_index, labels, align,original_nres):
"""
per_asym_residue_index: A dictionary that records which asym_id corresponds to which regions of residues in the multimer complex.
labels: list of original ground truth feats
......@@ -1905,6 +1911,10 @@ def merge_labels(per_asym_residue_index, labels, align):
cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0:
new_v = torch.concat(cur_out, dim=dimension_to_merge)
# below check whether padding is needed
if new_v.shape[dimension_to_merge]!=original_nres:
nres_pad = original_nres - new_v.shape[dimension_to_merge]
new_v = pad_features(new_v,nres_pad,pad_dim=dimension_to_merge)
outs[k] = new_v
return outs
......@@ -2027,9 +2037,15 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
def __init__(self, config):
super(AlphaFoldMultimerLoss, self).__init__(config)
self.config = config
@staticmethod
def determine_split_dim(batch)->dict:
"""A method to determine which dimension to split in split_ground_truth_labels"""
padded_dim = batch['aatype'].shape[-1]
dim_dict = {k:list(v.shape).index(padded_dim) for k,v in batch.items() if padded_dim in v.shape}
return dim_dict
@staticmethod
def split_ground_truth_labels(batch,REQUIRED_FEATURES):
def split_ground_truth_labels(batch,REQUIRED_FEATURES,dim_dict):
"""
Splits ground truth features according to chains
......@@ -2044,11 +2060,12 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
padding_asym_counts = asym_id_counts.pop(pop_idx)
unique_asym_ids.append(padding_asym_id)
asym_id_counts.append(padding_asym_counts)
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=1)] for k, value in batch.items() if k in REQUIRED_FEATURES])))
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=dim_dict[k])] for k, value in batch.items() if k in REQUIRED_FEATURES])))
return labels
@staticmethod
def multi_chain_perm_align(out, batch, permutate_chains=True):
def multi_chain_perm_align(out, batch, dim_dict,permutate_chains=True):
"""
A static method that permutes the ground-truth chains
before calculating the loss.
......@@ -2056,7 +2073,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"]
......@@ -2070,7 +2087,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,])
unique_asym_ids = torch.unique(batch["asym_id"])
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
......@@ -2129,7 +2146,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return align, per_asym_residue_index
def forward(self, out, features, _return_breakdown=False, permutate_chains=True):
def forward(self, out, features, _return_breakdown=False):
"""
Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation
......@@ -2138,22 +2155,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure
"""
# permutate ground truth chains before calculating the loss
# align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features, labels,
# permutate_chains=permutate_chains)
# permutated_labels = merge_labels(per_asym_residue_index, labels, align)
# permutated_labels.pop('aatype')
# features.update(permutated_labels)
print(f"########## line 2154 loss.py features is {type(features)}")
for k,v in features.items():
print(f"{k}:{v.shape}")
# permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,
# first determine which dimension in the tensor to split into individual ground truth labels
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
# Then permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
permutate_chains=permutate_chains)
import sys
sys.exit()
# permutated_labels = merge_labels(per_asym_residue_index, labels, align)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
if (not _return_breakdown):
cum_loss = self.loss(out, features, _return_breakdown)
print(f"cum_loss: {cum_loss}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment