Unverified Commit 0ca66146 authored by Christina Floristean's avatar Christina Floristean Committed by GitHub
Browse files

Merge pull request #353 from dingquanyu/permutation

Update multi-chain permutation and training codes
parents 8820875b a9d65037
...@@ -24,7 +24,7 @@ from openfold.utils.tensor_utils import ( ...@@ -24,7 +24,7 @@ from openfold.utils.tensor_utils import (
tensor_tree_map, tensor_tree_map,
) )
import random import random
logging.basicConfig(level=logging.INFO)
@contextlib.contextmanager @contextlib.contextmanager
def temp_fasta_file(sequence_str): def temp_fasta_file(sequence_str):
...@@ -443,7 +443,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -443,7 +443,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
) )
self.feature_pipeline = feature_pipeline.FeaturePipeline(config) self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index): def _parse_mmcif(self, path, file_id,alignment_dir, alignment_index):
with open(path, 'r') as f: with open(path, 'r') as f:
mmcif_string = f.read() mmcif_string = f.read()
...@@ -462,7 +462,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -462,7 +462,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
mmcif=mmcif_object, mmcif=mmcif_object,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
alignment_index=alignment_index alignment_index=alignment_index
) )
return data return data
...@@ -471,10 +471,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -471,10 +471,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def idx_to_mmcif_id(self, idx): def idx_to_mmcif_id(self, idx):
return self._mmcifs[idx] return self._mmcifs[idx]
def __getitem__(self, idx): def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx) mmcif_id = self.idx_to_mmcif_id(idx)
alignment_index = None alignment_index = None
if(self.mode == 'train' or self.mode == 'eval'): if(self.mode == 'train' or self.mode == 'eval'):
path = os.path.join(self.data_dir, f"{mmcif_id}") path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None ext = None
...@@ -503,19 +504,19 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -503,19 +504,19 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
if (self._output_raw): if (self._output_raw):
return data return data
# process all_chain_features # process all_chain_features
data = self.feature_pipeline.process_features(data, data,ground_truth = self.feature_pipeline.process_features(data,
mode=self.mode, mode=self.mode,
is_multimer=True) is_multimer=True)
# if it's inference mode, only need all_chain_features # if it's inference mode, only need all_chain_features
data["batch_idx"] = torch.tensor( data["batch_idx"] = torch.tensor(
[idx for _ in range(data["aatype"].shape[-1])], [idx for _ in range(data["aatype"].shape[-1])],
dtype=torch.int64, dtype=torch.int64,
device=data["aatype"].device) device=data["aatype"].device)
return data return data, ground_truth
def __len__(self): def __len__(self):
return len(self._chain_ids) return len(self._chain_ids)
...@@ -723,9 +724,9 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset): ...@@ -723,9 +724,9 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
mmcif_id = dataset.idx_to_mmcif_id(i) mmcif_id = dataset.idx_to_mmcif_id(i)
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id] mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if deterministic_multimer_train_filter(mmcif_data_cache_entry, if deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9, max_resolution=9):
minimum_number_of_residues=5):
selected_idx.append(i) selected_idx.append(i)
logging.info(f"Originally {len(mmcif_data_cache)} mmcifs. After filtering: {len(selected_idx)}")
else: else:
selected_idx = list(range(len(dataset._mmcif_id_to_idx_dict))) selected_idx = list(range(len(dataset._mmcif_id_to_idx_dict)))
return selected_idx return selected_idx
......
...@@ -81,7 +81,7 @@ def np_example_to_features( ...@@ -81,7 +81,7 @@ def np_example_to_features(
seq_length = np_example["seq_length"] seq_length = np_example["seq_length"]
num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length) num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length)
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
if "deletion_matrix_int" in np_example: if "deletion_matrix_int" in np_example:
np_example["deletion_matrix"] = np_example.pop( np_example["deletion_matrix"] = np_example.pop(
"deletion_matrix_int" "deletion_matrix_int"
...@@ -90,15 +90,29 @@ def np_example_to_features( ...@@ -90,15 +90,29 @@ def np_example_to_features(
tensor_dict = np_to_tensor_dict( tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names np_example=np_example, features=feature_names
) )
with torch.no_grad(): with torch.no_grad():
if(not is_multimer): if is_multimer:
features = input_pipeline.process_tensors_from_config( if mode == 'train':
tensor_dict, features,gt_features = input_pipeline_multimer.process_tensors_from_config(
cfg.common, tensor_dict,
cfg[mode], cfg.common,
) cfg[mode],
is_training=True
)
return {k: v for k, v in features.items()}, gt_features
else:
features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
is_training=False
)
return {k: v for k, v in features.items()}
else: else:
features = input_pipeline_multimer.process_tensors_from_config( features = input_pipeline.process_tensors_from_config(
tensor_dict, tensor_dict,
cfg.common, cfg.common,
cfg[mode], cfg[mode],
......
...@@ -21,19 +21,8 @@ from openfold.data import ( ...@@ -21,19 +21,8 @@ from openfold.data import (
data_transforms_multimer, data_transforms_multimer,
) )
def grountruth_transforms_fns():
def nonensembled_transform_fns(common_cfg, mode_cfg): transforms = [data_transforms.make_atom14_masks,
"""Input pipeline data transformers that are not ensembled."""
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat,
data_transforms.make_atom14_masks,
]
if mode_cfg.supervised:
transforms.extend(
[
data_transforms.make_atom14_positions, data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames, data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""), data_transforms.atom37_to_torsion_angles(""),
...@@ -41,7 +30,16 @@ def nonensembled_transform_fns(common_cfg, mode_cfg): ...@@ -41,7 +30,16 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.get_backbone_frames, data_transforms.get_backbone_frames,
data_transforms.get_chi_angles, data_transforms.get_chi_angles,
] ]
) return transforms
def nonensembled_transform_fns():
"""Input pipeline data transformers that are not ensembled."""
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat,
data_transforms.make_atom14_masks
]
return transforms return transforms
...@@ -114,11 +112,29 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -114,11 +112,29 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
return transforms return transforms
def prepare_ground_truth_features(tensors):
"""Prepare ground truth features that are only needed for loss calculation during training"""
def process_tensors_from_config(tensors, common_cfg, mode_cfg): GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions','asym_id','sym_id','entity_id']
gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES}
gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
return gt_tensors
def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False):
"""Based on the config, apply filters and transformations to the data.""" """Based on the config, apply filters and transformations to the data."""
if is_training:
gt_tensors= prepare_ground_truth_features(tensors)
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max) ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
tensors['aatype'] = tensors['aatype'].to(torch.long)
nonensembled = nonensembled_transform_fns()
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
def wrap_ensemble_fn(data, i): def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension.""" """Function to be mapped over the ensemble dimension."""
...@@ -132,28 +148,14 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg): ...@@ -132,28 +148,14 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
d["ensemble_index"] = i d["ensemble_index"] = i
return fn(d) return fn(d)
no_templates = True
if("template_aatype" in tensors):
no_templates = tensors["template_aatype"].shape[0] == 0
nonensembled = nonensembled_transform_fns(
common_cfg,
mode_cfg,
)
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
tensors = map_fn( tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1) lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
) )
return tensors if is_training:
return tensors,gt_tensors
else:
return tensors
@data_transforms.curry1 @data_transforms.curry1
def compose(x, fs): def compose(x, fs):
......
...@@ -1700,9 +1700,6 @@ def compute_rmsd( ...@@ -1700,9 +1700,6 @@ def compute_rmsd(
atom_mask: torch.Tensor = None, atom_mask: torch.Tensor = None,
eps: float = 1e-6, eps: float = 1e-6,
) -> torch.Tensor: ) -> torch.Tensor:
# shape check
true_atom_pos = true_atom_pos
pred_atom_pos = pred_atom_pos
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False) sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
del true_atom_pos del true_atom_pos
del pred_atom_pos del pred_atom_pos
...@@ -1784,20 +1781,23 @@ def get_optimal_transform( ...@@ -1784,20 +1781,23 @@ def get_optimal_transform(
return r, x return r, x
def get_least_asym_entity_or_longest_length(batch): def get_least_asym_entity_or_longest_length(batch,input_asym_id):
""" """
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
one of the A as anchor one of the A as anchor
If there is a tie, e.g. AABB, first check which sequence is the longer/longest, If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor then choose one of the corresponding subunits as anchor
"""
REQUIRED_FEATURES = ['entity_id','asym_id']
seq_length = batch['seq_length'].item()
# remove padding part before selecting candidate Args:
remove_padding = lambda t: torch.index_select(t,dim=1,index=torch.arange(seq_length,device=t.device)) batch: in this funtion batch is the full ground truth features
batch = {k:tensor_tree_map(remove_padding,batch[k]) for k in REQUIRED_FEATURES} input_asym_id: A list of aym_ids that are in the cropped input features
Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
"""
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(batch)
unique_entity_ids = torch.unique(batch["entity_id"]) unique_entity_ids = torch.unique(batch["entity_id"])
entity_asym_count = {} entity_asym_count = {}
entity_length = {} entity_length = {}
...@@ -1822,19 +1822,15 @@ def get_least_asym_entity_or_longest_length(batch): ...@@ -1822,19 +1822,15 @@ def get_least_asym_entity_or_longest_length(batch):
if len(least_asym_entities) > 1: if len(least_asym_entities) > 1:
least_asym_entities = random.choice(least_asym_entities) least_asym_entities = random.choice(least_asym_entities)
assert len(least_asym_entities)==1 assert len(least_asym_entities)==1
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]]) least_asym_entities = least_asym_entities[0]
# If there is more than one chain in the predicted output that has the same sequence
# as the chosen ground truth anchor, then randomly picke one
if len(best_pred_asym) > 1:
best_pred_asym = random.choice(best_pred_asym)
return least_asym_entities[0], best_pred_asym
anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities])
anchor_pred_asym_ids = [id for id in entity_2_asym_list[least_asym_entities] if id in input_asym_id]
return anchor_gt_asym_id, anchor_pred_asym_ids
def greedy_align( def greedy_align(
batch, batch,
per_asym_residue_index, per_asym_residue_index,
unique_asym_ids,
entity_2_asym_list, entity_2_asym_list,
pred_ca_pos, pred_ca_pos,
pred_ca_mask, pred_ca_mask,
...@@ -1847,6 +1843,7 @@ def greedy_align( ...@@ -1847,6 +1843,7 @@ def greedy_align(
""" """
used = [False for _ in range(len(true_ca_poses))] used = [False for _ in range(len(true_ca_poses))]
align = [] align = []
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
i = int(cur_asym_id - 1) i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id asym_mask = batch["asym_id"] == cur_asym_id
...@@ -1884,9 +1881,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim): ...@@ -1884,9 +1881,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
return torch.concat((feature_tensor,padding_tensor),dim=pad_dim) return torch.concat((feature_tensor,padding_tensor),dim=pad_dim)
def merge_labels(per_asym_residue_index, labels, align,original_nres): def merge_labels(per_asym_residue_index,labels, align,original_nres):
""" """
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex. Merge ground truth labels according to the permutation results
labels: list of original ground truth feats labels: list of original ground truth feats
align: list of tuples, each entry specify the corresponding label of the asym. align: list of tuples, each entry specify the corresponding label of the asym.
...@@ -1898,15 +1896,12 @@ def merge_labels(per_asym_residue_index, labels, align,original_nres): ...@@ -1898,15 +1896,12 @@ def merge_labels(per_asym_residue_index, labels, align,original_nres):
cur_out = {} cur_out = {}
for i, j in align: for i, j in align:
label = labels[j][k] label = labels[j][k]
cur_num_res = labels[j]['aatype'].shape[-1]
# to 1-based # to 1-based
cur_residue_index = per_asym_residue_index[i + 1] cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape)<=1 or "template" in k or "row_mask" in k : if len(v.shape)<=1 or "template" in k or "row_mask" in k :
continue continue
else: else:
dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0 dimension_to_merge = 1
if k =='all_atom_positions':
dimension_to_merge=1
cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index) cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index)
cur_out = [x[1] for x in sorted(cur_out.items())] cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0: if len(cur_out)>0:
...@@ -2037,19 +2032,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2037,19 +2032,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
def __init__(self, config): def __init__(self, config):
super(AlphaFoldMultimerLoss, self).__init__(config) super(AlphaFoldMultimerLoss, self).__init__(config)
self.config = config self.config = config
@staticmethod
def determine_split_dim(batch)->dict:
"""A method to determine which dimension to split in split_ground_truth_labels"""
padded_dim = batch['aatype'].shape[-1]
dim_dict = {k:list(v.shape).index(padded_dim) for k,v in batch.items() if padded_dim in v.shape}
return dim_dict
@staticmethod @staticmethod
def split_ground_truth_labels(batch,REQUIRED_FEATURES,dim_dict): def split_ground_truth_labels(batch,REQUIRED_FEATURES,split_dim=1):
""" """
Splits ground truth features according to chains Splits ground truth features according to chains
Returns a list of feature dictionaries with only necessary ground truth features Returns:
a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation required to finish multi-chain permutation
""" """
unique_asym_ids, asym_id_counts = torch.unique(batch["asym_id"], sorted=False,return_counts=True) unique_asym_ids, asym_id_counts = torch.unique(batch["asym_id"], sorted=False,return_counts=True)
...@@ -2061,11 +2051,85 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2061,11 +2051,85 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
unique_asym_ids.append(padding_asym_id) unique_asym_ids.append(padding_asym_id)
asym_id_counts.append(padding_asym_counts) asym_id_counts.append(padding_asym_counts)
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=dim_dict[k])] for k, value in batch.items() if k in REQUIRED_FEATURES]))) labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=split_dim)] for k, value in batch.items() if k in REQUIRED_FEATURES])))
return labels return labels
@staticmethod @staticmethod
def multi_chain_perm_align(out, batch, dim_dict,permutate_chains=True): def get_per_asym_residue_index(features):
unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i!=0]
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (features["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(features["residue_index"], asym_mask)
return per_asym_residue_index
@staticmethod
def get_entity_2_asym_list(batch):
"""
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
Args:
batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors.
Returns:
entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs
associated with each entity.
"""
entity_2_asym_list = {}
unique_entity_ids = torch.unique(batch["entity_id"])
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
return entity_2_asym_list
@staticmethod
def calculate_input_mask(true_ca_masks,anchor_gt_idx,anchor_gt_residue,
asym_mask,pred_ca_mask):
"""
Calculate an input mask for downstream optimal transformation computation
Args:
true_ca_masks (Tensor): ca mask from ground truth.
anchor_gt_idx (Tensor): The index of selected ground truth anchor.
asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor.
pred_ca_mask (Tensor): ca mask from predicted structure.
Returns:
input_mask (Tensor): A boolean mask
"""
pred_ca_mask = torch.squeeze(pred_ca_mask,0)
asym_mask = torch.squeeze(asym_mask,0)
anchor_pred_mask = pred_ca_mask[asym_mask]
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_gt_residue)
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
return input_mask
@staticmethod
def calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,anchor_gt_residue,
true_ca_masks,pred_ca_mask,
asym_mask,
pred_ca_pos):
input_mask = AlphaFoldMultimerLoss.calculate_input_mask(true_ca_masks,
anchor_gt_idx,anchor_gt_residue,
asym_mask,
pred_ca_mask)
input_mask = torch.squeeze(input_mask,0)
pred_ca_pos = torch.squeeze(pred_ca_pos,0)
asym_mask = torch.squeeze(asym_mask,0)
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx],1,anchor_gt_residue)
anchor_pred_pos = pred_ca_pos[asym_mask]
r, x = get_optimal_transform(
anchor_pred_pos, torch.squeeze(anchor_true_pos,0),
mask=input_mask
)
return r, x
@staticmethod
def multi_chain_perm_align(out, batch,permutate_chains=False):
""" """
A class method that first permutate chains in ground truth first A class method that first permutate chains in ground truth first
before calculating the loss. before calculating the loss.
...@@ -2073,80 +2137,73 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2073,80 +2137,73 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper: Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2 https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
""" """
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict, feature, ground_truth = batch
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"]) del batch
assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
true_ca_poses = [
l["all_atom_positions"][..., ca_idx, :] for l in labels
] # list([nres, 3])
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,])
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
if permutate_chains: if permutate_chains:
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch) best_rmsd = float('inf')
print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}") best_align = None
# First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth,feature['asym_id'])
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(ground_truth)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list)
del ground_truth
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
# Then calculate optimal transform by aligning anchors
unique_entity_ids = torch.unique(batch["entity_id"]) ca_idx = rc.atom_order["CA"]
entity_2_asym_list = {} pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
for cur_ent_id in unique_entity_ids: pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask]) true_ca_poses = [
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id l["all_atom_positions"][..., ca_idx, :] for l in labels
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool() ] # list([nres, 3])
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)] true_ca_masks = [
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_residue_idx) l["all_atom_mask"][..., ca_idx].long() for l in labels
anchor_pred_pos = pred_ca_pos[0][asym_mask[0]] ] # list([nres,])
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_residue_idx) for candidate_pred_anchor in anchor_pred_asym_ids:
anchor_pred_mask = pred_ca_mask[0][asym_mask[0]] asym_mask = (feature["asym_id"] == candidate_pred_anchor).bool()
anchor_gt_residue = per_asym_residue_index[int(candidate_pred_anchor)]
input_mask = (anchor_true_mask * anchor_pred_mask).bool() r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
r, x = get_optimal_transform( anchor_gt_idx,anchor_gt_residue,
anchor_pred_pos, anchor_true_pos[0], true_ca_masks,pred_ca_mask,
mask=input_mask[0] asym_mask,
) pred_ca_pos
del input_mask # just to save memory )
del anchor_pred_mask aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
del anchor_true_mask align = greedy_align(
gc.collect() feature,
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms per_asym_residue_index,
del true_ca_poses entity_2_asym_list,
gc.collect() pred_ca_pos,
align = greedy_align( pred_ca_mask,
batch, aligned_true_ca_poses,
per_asym_residue_index, true_ca_masks,
unique_asym_ids, )
entity_2_asym_list, merged_labels = merge_labels(per_asym_residue_index,labels,align,
pred_ca_pos, original_nres=feature['aatype'].shape[-1])
pred_ca_mask, rmsd = compute_rmsd(true_atom_pos = merged_labels['all_atom_positions'][..., ca_idx, :].to(r.dtype) @ r + x,
aligned_true_ca_poses, pred_atom_pos = pred_ca_pos,
true_ca_masks, atom_mask = (pred_ca_mask * merged_labels['all_atom_mask'][..., ca_idx].long()).bool())
) if rmsd < best_rmsd:
best_rmsd = rmsd
del aligned_true_ca_poses, true_ca_masks best_align = align
del r, x del r,x
del true_ca_masks,aligned_true_ca_poses
del pred_ca_pos, pred_ca_mask del pred_ca_pos, pred_ca_mask
del anchor_pred_pos, anchor_true_pos
gc.collect() gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
else: else:
align = list(enumerate(range(len(labels)))) per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
return align, per_asym_residue_index REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
best_align = list(enumerate(range(len(labels))))
return best_align, per_asym_residue_index
def forward(self, out, features, _return_breakdown=False,permutate_chains=True): def forward(self, out, batch, _return_breakdown=False,permutate_chains=True):
""" """
Overwrite AlphaFoldLoss forward function so that Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation it first compute multi-chain permutation
...@@ -2156,22 +2213,25 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2156,22 +2213,25 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
batch: a pair of input features and its corresponding ground truth structure batch: a pair of input features and its corresponding ground truth structure
""" """
# first check if it is a monomer # first check if it is a monomer
features, ground_truth = batch
del batch
is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1] is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
if not is_monomer: if not is_monomer:
permutate_chains = True permutate_chains = True
# first determin which dimension in the tensor to split into individual ground truth labels
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
# Then permutate ground truth chains before calculating the loss # Then permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict, align,per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
permutate_chains=permutate_chains) (features,ground_truth),
permutate_chains=permutate_chains)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict, labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict]) REQUIRED_FEATURES=[i for i in ground_truth.keys()])
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align, # reorder ground truth labels according to permutation results
original_nres=features['aatype'].shape[-1]) labels = merge_labels(per_asym_residue_index,labels,align,
features.update(labels) original_nres=features['aatype'].shape[-1])
features.update(labels)
if (not _return_breakdown): if (not _return_breakdown):
cum_loss = self.loss(out, features, _return_breakdown) cum_loss = self.loss(out, features, _return_breakdown)
......
...@@ -102,9 +102,10 @@ class TestPermutation(unittest.TestCase): ...@@ -102,9 +102,10 @@ class TestPermutation(unittest.TestCase):
batch['all_atom_mask'] = true_atom_mask batch['all_atom_mask'] = true_atom_mask
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch) dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
aligns,_ = AlphaFoldMultimerLoss.multi_chain_perm_align(out,batch, aligns = AlphaFoldMultimerLoss.multi_chain_perm_align(out,batch,
dim_dict, dim_dict,
permutate_chains=True) permutate_chains=True)
print(f"##### aligns is {aligns}")
possible_outcome = [[(0,1),(1,0),(2,3),(3,4),(4,2)],[(0,0),(1,1),(2,3),(3,4),(4,2)]] possible_outcome = [[(0,1),(1,0),(2,3),(3,4),(4,2)],[(0,0),(1,1),(2,3),(3,4),(4,2)]]
wrong_outcome = [[(0,1),(1,0),(2,4),(3,2),(4,3)],[(0,0),(1,1),(2,2),(3,3),(4,4)]] wrong_outcome = [[(0,1),(1,0),(2,4),(3,2),(4,3)],[(0,0),(1,1),(2,2),(3,3),(4,4)]]
self.assertIn(aligns,possible_outcome) self.assertIn(aligns,possible_outcome)
...@@ -151,15 +152,15 @@ class TestPermutation(unittest.TestCase): ...@@ -151,15 +152,15 @@ class TestPermutation(unittest.TestCase):
tensor_to_cuda = lambda t: t.to('cuda') tensor_to_cuda = lambda t: t.to('cuda')
batch = tensor_tree_map(tensor_to_cuda,batch) batch = tensor_tree_map(tensor_to_cuda,batch)
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch) dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
aligns,per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, aligns = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
batch, batch,
dim_dict, dim_dict,
permutate_chains=True) permutate_chains=True)
print(f"##### aligns is {aligns}")
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict, labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in batch.keys() if i in dim_dict]) REQUIRED_FEATURES=[i for i in batch.keys() if i in dim_dict])
labels = merge_labels(per_asym_residue_index,labels,aligns, labels = merge_labels(labels,aligns,
original_nres=batch['aatype'].shape[-1]) original_nres=batch['aatype'].shape[-1])
self.assertTrue(torch.equal(labels['residue_index'],batch['residue_index'])) self.assertTrue(torch.equal(labels['residue_index'],batch['residue_index']))
......
...@@ -273,27 +273,29 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper): ...@@ -273,27 +273,29 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
return self.model(batch) return self.model(batch)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
features,gt_features = batch
# Log it # Log it
if(self.ema.device != batch["aatype"].device): if(self.ema.device != features["aatype"].device):
self.ema.to(batch["aatype"].device) self.ema.to(features["aatype"].device)
# Run the model # Run the model
outputs = self(batch) outputs = self(features)
# Remove the recycling dimension # Remove the recycling dimension
batch = tensor_tree_map(lambda t: t[..., -1], batch) features = tensor_tree_map(lambda t: t[..., -1], features)
# Compute loss # Compute loss
loss, loss_breakdown = self.loss( loss, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True outputs, (features,gt_features), _return_breakdown=True
) )
# Log it # Log it
self._log(loss_breakdown, batch, outputs) self._log(loss_breakdown, features, outputs)
return loss return loss
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
features,gt_features = batch
# At the start of validation, load the EMA weights # At the start of validation, load the EMA weights
if(self.cached_weights is None): if(self.cached_weights is None):
# model.state_dict() contains references to model weights rather # model.state_dict() contains references to model weights rather
...@@ -304,15 +306,15 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper): ...@@ -304,15 +306,15 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
self.model.load_state_dict(self.ema.state_dict()["params"]) self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model # Run the model
outputs = self(batch) outputs = self(features)
# Compute loss and other metrics # Compute loss and other metrics
batch["use_clamped_fape"] = 0. features["use_clamped_fape"] = 0.
_, loss_breakdown = self.loss( _, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True outputs, (features,gt_features), _return_breakdown=True
) )
self._log(loss_breakdown, batch, outputs, train=False) self._log(loss_breakdown, features, outputs, train=False)
def validation_epoch_end(self, _): def validation_epoch_end(self, _):
# Restore the model weights to normal # Restore the model weights to normal
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment