Commit 2a1028f0 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update AlphaFoldMultimerLoss to accomodate new way of data_module loading procudure

parent dff973ab
...@@ -1848,8 +1848,6 @@ def greedy_align( ...@@ -1848,8 +1848,6 @@ def greedy_align(
used = [False for _ in range(len(true_ca_poses))] used = [False for _ in range(len(true_ca_poses))]
align = [] align = []
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
if cur_asym_id==0:
continue
i = int(cur_asym_id - 1) i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id asym_mask = batch["asym_id"] == cur_asym_id
cur_entity_ids = batch["entity_id"][asym_mask][0] cur_entity_ids = batch["entity_id"][asym_mask][0]
...@@ -1878,7 +1876,15 @@ def greedy_align( ...@@ -1878,7 +1876,15 @@ def greedy_align(
return align return align
def merge_labels(per_asym_residue_index, labels, align): def pad_features(feature_tensor,nres_pad,pad_dim):
"""Pad input feature tensor"""
pad_shape = list(feature_tensor.shape)
pad_shape[pad_dim] = nres_pad
padding_tensor = feature_tensor.new_zeros(pad_shape,device=feature_tensor.device)
return torch.concat((feature_tensor,padding_tensor),dim=pad_dim)
def merge_labels(per_asym_residue_index, labels, align,original_nres):
""" """
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex. per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
labels: list of original ground truth feats labels: list of original ground truth feats
...@@ -1905,6 +1911,10 @@ def merge_labels(per_asym_residue_index, labels, align): ...@@ -1905,6 +1911,10 @@ def merge_labels(per_asym_residue_index, labels, align):
cur_out = [x[1] for x in sorted(cur_out.items())] cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0: if len(cur_out)>0:
new_v = torch.concat(cur_out, dim=dimension_to_merge) new_v = torch.concat(cur_out, dim=dimension_to_merge)
# below check whether padding is needed
if new_v.shape[dimension_to_merge]!=original_nres:
nres_pad = original_nres - new_v.shape[dimension_to_merge]
new_v = pad_features(new_v,nres_pad,pad_dim=dimension_to_merge)
outs[k] = new_v outs[k] = new_v
return outs return outs
...@@ -2027,9 +2037,15 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2027,9 +2037,15 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
def __init__(self, config): def __init__(self, config):
super(AlphaFoldMultimerLoss, self).__init__(config) super(AlphaFoldMultimerLoss, self).__init__(config)
self.config = config self.config = config
@staticmethod
def determine_split_dim(batch)->dict:
"""A method to determine which dimension to split in split_ground_truth_labels"""
padded_dim = batch['aatype'].shape[-1]
dim_dict = {k:list(v.shape).index(padded_dim) for k,v in batch.items() if padded_dim in v.shape}
return dim_dict
@staticmethod @staticmethod
def split_ground_truth_labels(batch,REQUIRED_FEATURES): def split_ground_truth_labels(batch,REQUIRED_FEATURES,dim_dict):
""" """
Splits ground truth features according to chains Splits ground truth features according to chains
...@@ -2044,11 +2060,12 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2044,11 +2060,12 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
padding_asym_counts = asym_id_counts.pop(pop_idx) padding_asym_counts = asym_id_counts.pop(pop_idx)
unique_asym_ids.append(padding_asym_id) unique_asym_ids.append(padding_asym_id)
asym_id_counts.append(padding_asym_counts) asym_id_counts.append(padding_asym_counts)
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=1)] for k, value in batch.items() if k in REQUIRED_FEATURES])))
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=dim_dict[k])] for k, value in batch.items() if k in REQUIRED_FEATURES])))
return labels return labels
@staticmethod @staticmethod
def multi_chain_perm_align(out, batch, permutate_chains=True): def multi_chain_perm_align(out, batch, dim_dict,permutate_chains=True):
""" """
A class method that first permutate chains in ground truth first A class method that first permutate chains in ground truth first
before calculating the loss. before calculating the loss.
...@@ -2056,7 +2073,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2056,7 +2073,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper: Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2 https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
""" """
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch, labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"]) REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list) assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"] ca_idx = rc.atom_order["CA"]
...@@ -2070,7 +2087,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2070,7 +2087,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks = [ true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,]) ] # list([nres,])
unique_asym_ids = torch.unique(batch["asym_id"]) unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
per_asym_residue_index = {} per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
...@@ -2129,7 +2146,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2129,7 +2146,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return align, per_asym_residue_index return align, per_asym_residue_index
def forward(self, out, features, _return_breakdown=False, permutate_chains=True): def forward(self, out, features, _return_breakdown=False):
""" """
Overwrite AlphaFoldLoss forward function so that Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation it first compute multi-chain permutation
...@@ -2138,22 +2155,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2138,22 +2155,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward() out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure batch: a pair of input features and its corresponding ground truth structure
""" """
# permutate ground truth chains before calculating the loss # first determin which dimension in the tensor to split into individual ground truth labels
# align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features, labels, dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
# permutate_chains=permutate_chains)
# permutated_labels = merge_labels(per_asym_residue_index, labels, align) # Then permutate ground truth chains before calculating the loss
# permutated_labels.pop('aatype') align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
# features.update(permutated_labels)
print(f"########## line 2154 loss.py features is {type(features)}")
for k,v in features.items():
print(f"{k}:{v.shape}")
# permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,
permutate_chains=permutate_chains) permutate_chains=permutate_chains)
import sys
sys.exit() labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict,
# permutated_labels = merge_labels(per_asym_residue_index, labels, align) REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
if (not _return_breakdown): if (not _return_breakdown):
cum_loss = self.loss(out, features, _return_breakdown) cum_loss = self.loss(out, features, _return_breakdown)
print(f"cum_loss: {cum_loss}") print(f"cum_loss: {cum_loss}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment