Unverified Commit 7d227395 authored by Jannik Gut's avatar Jannik Gut Committed by GitHub
Browse files

Merge branch 'main' into main

parents b38b6078 f37d0d96
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
>6KWC_1
GSTIQPGTGYNNGYFYSYWNDGHGGVTYTNGPGGQFSVNWSNSGEFVGGKGWQPGTKNKVINFSGSYNPNGNSYLSVYGWSRNPLIEYYIVENFGTYNPSTGATKLGEVTSDGSVYDIYRTQRVNQPSIIGTATFYQYWSVRRNHRSSGSVNTANHFNAWAQQGLTLGTMDYQIVAVQGYFSSGSASITVS
#!/bin/bash
# Run OpenFold inference (model_1_ptm) on every FASTA in $FASTA_DIR,
# using precomputed alignments and template mmCIFs.
#
# Fixes over the original: fail fast on errors / unset variables
# (set -euo pipefail), and quote every expansion so paths containing
# spaces do not word-split.
set -euo pipefail

# CONDA_PREFIX must point at the active conda environment; fail with a
# clear message if the script is run outside one.
export LD_LIBRARY_PATH="${CONDA_PREFIX:?activate a conda env first}/lib:${LD_LIBRARY_PATH:-}"
export LIBRARY_PATH="$CONDA_PREFIX/lib:${LIBRARY_PATH:-}"

export FASTA_DIR=./fasta_dir
export OUTPUT_DIR=./
export PRECOMPUTED_ALIGNMENT_DIR=./alignments
export MMCIF_DIR=/mmcifs # UPDATE with path to your mmcifs directory

python3 run_pretrained_openfold.py "$FASTA_DIR" \
    "$MMCIF_DIR" \
    --output_dir "$OUTPUT_DIR" \
    --config_preset model_1_ptm \
    --model_device "cuda:0" \
    --data_random_seed 42 \
    --use_precomputed_alignments "$PRECOMPUTED_ALIGNMENT_DIR"
This diff is collapsed.
{ {
"cells": [ "cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
...@@ -120,7 +111,7 @@ ...@@ -120,7 +111,7 @@
"os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n", "os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
"os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n", "os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
"os.system(\"mamba config --set auto_update_conda false\")\n", "os.system(\"mamba config --set auto_update_conda false\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.79\")\n", "os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.83\")\n",
"os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n", "os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n",
"\n", "\n",
"try:\n", "try:\n",
...@@ -136,7 +127,7 @@ ...@@ -136,7 +127,7 @@
"\n", "\n",
" %shell mkdir -p /content/openfold/openfold/resources\n", " %shell mkdir -p /content/openfold/openfold/resources\n",
"\n", "\n",
" commit = \"e2e19f16676b1a409f9ba3a6f69b11ee7f5887c2\"\n", " commit = \"3bec3e9b2d1e8bdb83887899102eff7d42dc2ba9\"\n",
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n", " os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
"\n", "\n",
" os.system(f\"cp -f -p /content/stereo_chemical_props.txt /usr/local/lib/python{python_version}/site-packages/openfold/resources/\")\n", " os.system(f\"cp -f -p /content/stereo_chemical_props.txt /usr/local/lib/python{python_version}/site-packages/openfold/resources/\")\n",
...@@ -259,7 +250,7 @@ ...@@ -259,7 +250,7 @@
"from openfold.np import protein\n", "from openfold.np import protein\n",
"from openfold.np.relax import relax\n", "from openfold.np.relax import relax\n",
"from openfold.np.relax.utils import overwrite_b_factors\n", "from openfold.np.relax.utils import overwrite_b_factors\n",
"from openfold.utils.import_weights import import_jax_weights_\n", "from openfold.utils.import_weights import import_jax_weights_, import_openfold_weights_\n",
"from openfold.utils.tensor_utils import tensor_tree_map\n", "from openfold.utils.tensor_utils import tensor_tree_map\n",
"\n", "\n",
"from IPython import display\n", "from IPython import display\n",
...@@ -582,7 +573,7 @@ ...@@ -582,7 +573,7 @@
" model_name,\n", " model_name,\n",
" )\n", " )\n",
" d = torch.load(params_name)\n", " d = torch.load(params_name)\n",
" openfold_model.load_state_dict(d)\n", " import_openfold_weights_(model=openfold_model, state_dict=d)\n",
" else:\n", " else:\n",
" raise ValueError(f\"Invalid weight set: {weight_set}\")\n", " raise ValueError(f\"Invalid weight set: {weight_set}\")\n",
"\n", "\n",
......
...@@ -62,7 +62,8 @@ def model_config( ...@@ -62,7 +62,8 @@ def model_config(
name, name,
train=False, train=False,
low_prec=False, low_prec=False,
long_sequence_inference=False long_sequence_inference=False,
use_deepspeed_evoformer_attention=False,
): ):
c = copy.deepcopy(config) c = copy.deepcopy(config)
# TRAINING PRESETS # TRAINING PRESETS
...@@ -237,6 +238,9 @@ def model_config( ...@@ -237,6 +238,9 @@ def model_config(
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
c.model.evoformer_stack.tune_chunk_size = False c.model.evoformer_stack.tune_chunk_size = False
if use_deepspeed_evoformer_attention:
c.globals.use_deepspeed_evo_attention = True
if train: if train:
c.globals.blocks_per_ckpt = 1 c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None c.globals.chunk_size = None
......
...@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
with open(distillation_alignment_index_path, "r") as fp: with open(distillation_alignment_index_path, "r") as fp:
self.distillation_alignment_index = json.load(fp) self.distillation_alignment_index = json.load(fp)
def setup(self): def setup(self, stage=None):
# Most of the arguments are the same for the three datasets # Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset, dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir, template_mmcif_dir=self.template_mmcif_dir,
...@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
mode="predict", mode="predict",
) )
def _gen_dataloader(self, stage): def _gen_dataloader(self, stage=None):
generator = None generator = None
if self.batch_seed is not None: if self.batch_seed is not None:
generator = torch.Generator() generator = torch.Generator()
...@@ -1053,7 +1053,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -1053,7 +1053,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
def val_dataloader(self): def val_dataloader(self):
if self.eval_dataset is not None: if self.eval_dataset is not None:
return self._gen_dataloader("eval") return self._gen_dataloader("eval")
return None return []
def predict_dataloader(self): def predict_dataloader(self):
return self._gen_dataloader("predict") return self._gen_dataloader("predict")
...@@ -1085,7 +1085,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule): ...@@ -1085,7 +1085,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
self.training_mode = self.train_data_dir is not None self.training_mode = self.train_data_dir is not None
self.val_mmcif_data_cache_path = val_mmcif_data_cache_path self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
def setup(self): def setup(self, setup=None):
# Most of the arguments are the same for the three datasets # Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleMultimerDataset, dataset_gen = partial(OpenFoldSingleMultimerDataset,
template_mmcif_dir=self.template_mmcif_dir, template_mmcif_dir=self.template_mmcif_dir,
......
...@@ -257,7 +257,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: ...@@ -257,7 +257,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
features["num_alignments"] = np.array( features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32 [num_alignments] * num_res, dtype=np.int32
) )
features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_) features["msa_species_identifiers"] = np.array(species_ids, dtype=object)
return features return features
...@@ -603,7 +603,7 @@ def convert_monomer_features( ...@@ -603,7 +603,7 @@ def convert_monomer_features(
) -> FeatureDict: ) -> FeatureDict:
"""Reshapes and modifies monomer features for multimer models.""" """Reshapes and modifies monomer features for multimer models."""
converted = {} converted = {}
converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_) converted['auth_chain_id'] = np.asarray(chain_id, dtype=object)
unnecessary_leading_dim_feats = { unnecessary_leading_dim_feats = {
'sequence', 'domain_name', 'num_alignments', 'seq_length' 'sequence', 'domain_name', 'num_alignments', 'seq_length'
} }
...@@ -1309,7 +1309,7 @@ class DataPipelineMultimer: ...@@ -1309,7 +1309,7 @@ class DataPipelineMultimer:
) )
mmcif_feats["release_date"] = np.array( mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ [mmcif_object.header["release_date"].encode("utf-8")], dtype=object
) )
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
......
...@@ -24,7 +24,7 @@ import os ...@@ -24,7 +24,7 @@ import os
from typing import Any, Mapping, Optional, Sequence, Tuple from typing import Any, Mapping, Optional, Sequence, Tuple
from Bio import PDB from Bio import PDB
from Bio.Data import SCOPData from Bio.Data import PDBData
import numpy as np import numpy as np
from openfold.data.errors import MultipleChainsError from openfold.data.errors import MultipleChainsError
...@@ -283,7 +283,7 @@ def parse( ...@@ -283,7 +283,7 @@ def parse(
author_chain = mmcif_to_author_chain_id[chain_id] author_chain = mmcif_to_author_chain_id[chain_id]
seq = [] seq = []
for monomer in seq_info: for monomer in seq_info:
code = SCOPData.protein_letters_3to1.get(monomer.id, "X") code = PDBData.protein_letters_3to1_extended.get(monomer.id, "X")
seq.append(code if len(code) == 1 else "X") seq.append(code if len(code) == 1 else "X")
seq = "".join(seq) seq = "".join(seq)
author_chain_to_sequence[author_chain] = seq author_chain_to_sequence[author_chain] = seq
...@@ -347,6 +347,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader: ...@@ -347,6 +347,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
try: try:
raw_resolution = parsed_info[res_key][0] raw_resolution = parsed_info[res_key][0]
header["resolution"] = float(raw_resolution) header["resolution"] = float(raw_resolution)
break
except ValueError: except ValueError:
logging.debug( logging.debug(
"Invalid resolution format: %s", parsed_info[res_key] "Invalid resolution format: %s", parsed_info[res_key]
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import re import re
import logging
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
...@@ -681,15 +682,18 @@ def convert_deprecated_v1_keys(state_dict): ...@@ -681,15 +682,18 @@ def convert_deprecated_v1_keys(state_dict):
} }
convert_key_re = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys()))) convert_key_re = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))
template_emb_re = re.compile(r"^((module\.)?(model\.)?)(template(?!_embedder).*)")
converted_state_dict = {} converted_state_dict = {}
for key, value in state_dict.items(): for key, value in state_dict.items():
# For each match, look-up replacement value in the dictionary # For each match, look-up replacement value in the dictionary
new_key = convert_key_re.sub(lambda m: replacements[m.group()], key) new_key = convert_key_re.sub(lambda m: replacements[m.group(1)], key)
# Add prefix for template modules # Add prefix for template layers
if new_key.startswith('template'): template_match = re.match(template_emb_re, new_key)
new_key = f'template_embedder.{new_key}' if template_match:
prefix = template_match.group(1)
new_key = f'{prefix if prefix else ""}template_embedder.{template_match.group(4)}'
converted_state_dict[new_key] = value converted_state_dict[new_key] = value
......
import logging import logging
import random import random
import torch import torch
from typing import Tuple, List, Dict
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -13,6 +13,17 @@ def compute_rmsd( ...@@ -13,6 +13,17 @@ def compute_rmsd(
atom_mask: torch.Tensor = None, atom_mask: torch.Tensor = None,
eps: float = 1e-6, eps: float = 1e-6,
) -> torch.Tensor: ) -> torch.Tensor:
"""
Function to calculate RMSD between predicted and ground truth atom position
Args:
true_atom_pos: a [nres*3] tensor
pred_atom_pos: a [nres*3] tensor
atom_mask: a [1*nres] tensor
Return:
RMSD value between true and predicted atom positions
"""
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False) sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
if atom_mask is not None: if atom_mask is not None:
sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device)) sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device))
...@@ -21,7 +32,7 @@ def compute_rmsd( ...@@ -21,7 +32,7 @@ def compute_rmsd(
return torch.sqrt(msd + eps) # prevent sqrt 0 return torch.sqrt(msd + eps) # prevent sqrt 0
def kabsch_rotation(P, Q): def kabsch_rotation(P: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
""" """
Calculate the best rotation that minimises the RMSD between P and Q. Calculate the best rotation that minimises the RMSD between P and Q.
...@@ -33,7 +44,7 @@ def kabsch_rotation(P, Q): ...@@ -33,7 +44,7 @@ def kabsch_rotation(P, Q):
Q: [N * 3] the same dimension as P Q: [N * 3] the same dimension as P
return: return:
A 3*3 rotation matrix one 3*3 rotation matrix that best aligns the sorce and target atoms
""" """
assert P.shape == torch.Size([Q.shape[0], Q.shape[1]]) assert P.shape == torch.Size([Q.shape[0], Q.shape[1]])
...@@ -54,11 +65,20 @@ def get_optimal_transform( ...@@ -54,11 +65,20 @@ def get_optimal_transform(
src_atoms: torch.Tensor, src_atoms: torch.Tensor,
tgt_atoms: torch.Tensor, tgt_atoms: torch.Tensor,
mask: torch.Tensor = None, mask: torch.Tensor = None,
): ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
A function that obtain the transformation that optimally align
src_atoms with tgt_atoms
Args:
src_atoms: predicted CA positions, shape:[num_res,3] src_atoms: predicted CA positions, shape:[num_res,3]
tgt_atoms: ground-truth CA positions, shape:[num_res,3] tgt_atoms: ground-truth CA positions, shape:[num_res,3]
mask: a vector of boolean values, shape:[num_res] mask: a vector of boolean values, shape:[num_res]
Returns:
a rotation matrix that record the optimal rotation
that will best align selected anchor prediction to selected anchor truth
a matrix records how the atoms should be shifted after applying r i.e. optimal alignment requires 1) rotate 2) shift the positions
""" """
assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape) assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
assert src_atoms.shape[-1] == 3 assert src_atoms.shape[-1] == 3
...@@ -88,7 +108,7 @@ def get_optimal_transform( ...@@ -88,7 +108,7 @@ def get_optimal_transform(
return r, x return r, x
def get_least_asym_entity_or_longest_length(batch, input_asym_id): def get_least_asym_entity_or_longest_length(batch: dict, input_asym_id: list) -> Tuple[torch.Tensor, List[torch.Tensor]]:
""" """
First check how many subunit(s) one sequence has. Select the subunit that is less First check how many subunit(s) one sequence has. Select the subunit that is less
common, e.g. if the protein was AABBB then select one of the A as anchor common, e.g. if the protein was AABBB then select one of the A as anchor
...@@ -105,7 +125,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): ...@@ -105,7 +125,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
""" """
entity_2_asym_list = get_entity_2_asym_list(batch) entity_2_asym_list = get_entity_2_asym_list(batch)
unique_entity_ids = torch.unique(batch["entity_id"]) unique_entity_ids = [i for i in torch.unique(batch["entity_id"]) if i !=0]# if entity_id is 0, that means this entity_id comes from padding
entity_asym_count = {} entity_asym_count = {}
entity_length = {} entity_length = {}
...@@ -145,19 +165,38 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): ...@@ -145,19 +165,38 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
def greedy_align( def greedy_align(
batch, batch: dict,
per_asym_residue_index, per_asym_residue_index: dict,
entity_2_asym_list, entity_2_asym_list: dict,
pred_ca_pos, pred_ca_pos: torch.Tensor,
pred_ca_mask, pred_ca_mask: torch.Tensor,
true_ca_poses, true_ca_poses: list,
true_ca_masks, true_ca_masks: list
): ) -> List[Tuple[int, int]]:
""" """
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper: Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034 Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
Args:
batch: a dictionary of ground truth features
per_asym_residue_index: a dictionary recording which residues belong to which aysm_id
entity_2_asym_list: a dictionary recording which asym_id(s) belong to which entity_id
pred_ca_pos: predicted positions of c-alpha atoms from the results of model.forward()
pred_ca_mask: a boolean tensor that masks pred_ca_pos
true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5
true_ca_masks: a list of tensors, corresponding to the masks of c-alpha positions of the ground truth structure. If there are 5 chains, this list will have a length of 5
Return:
A list of tuple(int,int) that provides instructions of how the ground truth chains should be permuated
e.g. if 3 chains in the imput model have the same sequences, an example return would be:
[(0,2),(1,1),(2,0)], meaning the 1st chain in the predicted structure should be aligned to the 3rd chain in the ground truth,
and the 2nd chain in the predicted structure is ok to stay with the 2nd chain in the ground truth.
Note: the tuples in the returned list begin with 0 indexing but aym_id begins with 1. The reason why tuples in the return are 0-indexing
is that at the stage of loss calculation, the ground truth atom positions: true_ca_poses, are already split up into a list of matrices.
Hence, now this function needs to return tuples that provide the index to select from the list: true_ca_poses, and list index starts from 0.
""" """
used = [False for _ in range(len(true_ca_poses))] used = [False for _ in range(len(true_ca_poses))] # a list the keeps recording whether a ground truth chain has been used or not
align = [] align = []
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0] unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
...@@ -189,21 +228,38 @@ def greedy_align( ...@@ -189,21 +228,38 @@ def greedy_align(
return align return align
def pad_features(feature_tensor, nres_pad, pad_dim): def pad_features(feature_tensor: torch.Tensor, nres_pad: int, pad_dim: int) -> torch.Tensor:
"""Pad input feature tensor""" """
Pad input feature tensor. Padding values will be 0 and put behind the true feature values
Args:
feature_tensor: A feature tensor
nres_pad: number of residues to add
pad_dim: along which dimension of the feature_tensor to pad
Returns:
a padded feature tensor
"""
pad_shape = list(feature_tensor.shape) pad_shape = list(feature_tensor.shape)
pad_shape[pad_dim] = nres_pad pad_shape[pad_dim] = nres_pad
padding_tensor = feature_tensor.new_zeros(pad_shape, device=feature_tensor.device) padding_tensor = feature_tensor.new_zeros(pad_shape, device=feature_tensor.device)
return torch.concat((feature_tensor, padding_tensor), dim=pad_dim) return torch.concat((feature_tensor, padding_tensor), dim=pad_dim)
def merge_labels(per_asym_residue_index, labels, align, original_nres): def merge_labels(per_asym_residue_index: Dict[int,List[int]],
labels: List[Dict], align: List[Tuple[int, int]],
original_nres: int) -> Dict[str, torch.Tensor]:
""" """
Merge ground truth labels according to the permutation results Merge ground truth labels according to the permutation results
labels: list of original ground truth feats Args:
per_asym_residue_index: a dictionary recording which residues belong to which aysm_id
labels: list of original ground truth feats e.g. if there're 5 chains, labels will have a length of 5
align: list of tuples, each entry specify the corresponding label of the asym. align: list of tuples, each entry specify the corresponding label of the asym.
original_nres: int, corresponding to the number of residues specified by crop_size in config.py
Returns:
A new dictionary of permuated ground truth features
modified based on UniFold: modified based on UniFold:
https://github.com/dptech-corp/Uni-Fold/blob/b1c89a2cebd4e4ee4c47b4e443f92beeb9138fbb/unifold/losses/chain_align.py#L176C1-L176C1 https://github.com/dptech-corp/Uni-Fold/blob/b1c89a2cebd4e4ee4c47b4e443f92beeb9138fbb/unifold/losses/chain_align.py#L176C1-L176C1
""" """
...@@ -230,13 +286,20 @@ def merge_labels(per_asym_residue_index, labels, align, original_nres): ...@@ -230,13 +286,20 @@ def merge_labels(per_asym_residue_index, labels, align, original_nres):
return outs return outs
def split_ground_truth_labels(gt_features): def split_ground_truth_labels(gt_features: dict) -> List[Dict]:
""" """
Splits ground truth features according to chains Splits ground truth features according to chains
Args:
gt_features: A dictionary within a the PyTorch DataSet iteration, which returns by the upstream DataLoader.iter() method
In the DataLoader pipeline, all tensors belonging to all the ground truth changes are concatenated so it stays the same as monomer data input format/pipeline,
thus, this function is needed to 1) detect the number of chains i.e. unique(asym_id)
2) split the concatenated tensors back to individual ones that correspond to individual asym_ids
Returns: Returns:
a list of feature dictionaries with only necessary ground truth features a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation required to finish multi-chain permutation, e.g. it will be a list of 5 elements if there
are 5 chains in total.
""" """
unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True) unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True)
n_res = gt_features["asym_id"].shape[-1] n_res = gt_features["asym_id"].shape[-1]
...@@ -251,7 +314,16 @@ def split_ground_truth_labels(gt_features): ...@@ -251,7 +314,16 @@ def split_ground_truth_labels(gt_features):
return labels return labels
def get_per_asym_residue_index(features): def get_per_asym_residue_index(features: dict) -> Dict[int, torch.Tensor]:
"""
A function that retrieve which residues belong to which asym_id
Args:
features: a dictionary that contains input features after cropping
Returns:
A dictionary that records which region of the sequence belongs to which asym_id
"""
unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i != 0] unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i != 0]
per_asym_residue_index = {} per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
...@@ -261,34 +333,36 @@ def get_per_asym_residue_index(features): ...@@ -261,34 +333,36 @@ def get_per_asym_residue_index(features):
return per_asym_residue_index return per_asym_residue_index
def get_entity_2_asym_list(batch): def get_entity_2_asym_list(features: dict) -> Dict[int, list]:
""" """
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity. Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
Args: Args:
batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors. features (dict): A dictionary containing data features, including "entity_id" and "asym_id" tensors.
Returns: Returns:
entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs
associated with each entity. associated with each entity.
""" """
entity_2_asym_list = {} entity_2_asym_list = {}
unique_entity_ids = torch.unique(batch["entity_id"]) unique_entity_ids = torch.unique(features["entity_id"])
for cur_ent_id in unique_entity_ids: for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id ent_mask = features["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask]) cur_asym_id = torch.unique(features["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
return entity_2_asym_list return entity_2_asym_list
def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue, def calculate_input_mask(true_ca_masks: List[torch.Tensor], anchor_gt_idx: torch.Tensor,
asym_mask, pred_ca_mask): anchor_gt_residue: torch.Tensor,
asym_mask: torch.Tensor, pred_ca_mask: torch.Tensor) -> torch.Tensor:
""" """
Calculate an input mask for downstream optimal transformation computation Calculate an input mask for downstream optimal transformation computation
Args: Args:
true_ca_masks (Tensor): ca mask from ground truth. true_ca_masks: list of masks from ground truth chains.
anchor_gt_idx (Tensor): The index of selected ground truth anchor. anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor.
anchor_gt_residue:a 1D vector tensor of residue indexes that belongs to the selected ground truth anchor
asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor. asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor.
pred_ca_mask (Tensor): ca mask from predicted structure. pred_ca_mask (Tensor): ca mask from predicted structure.
...@@ -303,11 +377,38 @@ def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue, ...@@ -303,11 +377,38 @@ def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue,
return input_mask return input_mask
def calculate_optimal_transform(true_ca_poses, def calculate_optimal_transform(true_ca_poses: List[torch.Tensor],
anchor_gt_idx, anchor_gt_residue, anchor_gt_idx: int, anchor_gt_residue: torch.Tensor,
true_ca_masks, pred_ca_mask, true_ca_masks: List[torch.Tensor], pred_ca_mask: torch.Tensor,
asym_mask, asym_mask: torch.Tensor,
pred_ca_pos): pred_ca_pos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Takes selected anchor ground truth c-alpha positions and
selected predicted anchor c-alpha position then calculate the optimal rotation matrix
to align ground-truth anchor and predicted anchor
Args:
true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5
anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor.
anchor_gt_residue:a 1D vector tensor of residue indexes that belongs to the selected ground truth anchor
true_ca_masks: list of masks from ground truth chains e.g. it will be length=5 if there are 5 chains in ground truth structure
pred_ca_mask: A boolean tensor corresponds to the mask to mask the predicted features
asym_mask: A boolean tensor that mask out other elements in a tensor if they do not belong to a this asym_id
pred_ca_pos: a [nres*3] tensor of predicted c-alpha atom positions
Process:
1) select an achor chain from ground truth, denoted by anchor_gt_idx, and
an chor chain from the predicted structure. Both anchor_gt and anchor_pred have exactly the same sequence
2) obtain the C-alpha positions corresponding to the selected anchor_gt, done be slicing the true_ca_pose according to anchor_gt_residue
3) calculate the optimal transformation that can best align the C-alpha atoms of anchor_pred to those of anchor_gt,
done by Kabsch algorithm: source https://en.wikipedia.org/wiki/Kabsch_algorithm
Returns:
a rotation matrix that record the optimal rotation
that will best align selected anchor prediction to selected anchor truth
a matrix records how the atoms should be shifted after applying r i.e. optimal alignment requires 1) rotate 2) shift the positions
"""
input_mask = calculate_input_mask(true_ca_masks, input_mask = calculate_input_mask(true_ca_masks,
anchor_gt_idx, anchor_gt_idx,
anchor_gt_residue, anchor_gt_residue,
...@@ -326,13 +427,27 @@ def calculate_optimal_transform(true_ca_poses, ...@@ -326,13 +427,27 @@ def calculate_optimal_transform(true_ca_poses,
return r, x return r, x
def compute_permutation_alignment(out, features, ground_truth): def compute_permutation_alignment(out: Dict[str,torch.Tensor],
features: Dict[str,torch.Tensor],
ground_truth: List[Dict[str, torch.Tensor]]) -> Tuple[List[Tuple[int, int]], Dict[int, List[int]]]:
""" """
A class method that first permutate chains in ground truth first A method that permutes chains in ground truth before calculating the loss
before calculating the loss. because the mapping between the predicted and ground-truth will become arbitrary.
The model cannot be assumed to predict chains in the same order as the ground truth.
Thus, this function pick the optimal permutaion of predicted chains that best matches the ground truth,
by minimising the RMSD i.e. the best permutation of ground truth chains is selected based on which permutation has the lowest RMSD calculation
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper: Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2 https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
Args:
out: a dictionary of output tensors from model.forward()
features: a dictionary of feature tensors that are used as input for model.forward()
ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure
Returns:
a list of tuple(int,int) that instructs how ground truth chains should be permutated
a dictionary recording which residues belong to which aysm_id
""" """
unique_asym_ids = set(torch.unique(features['asym_id']).tolist()) unique_asym_ids = set(torch.unique(features['asym_id']).tolist())
unique_asym_ids.discard(0) # Remove padding asym_id unique_asym_ids.discard(0) # Remove padding asym_id
...@@ -397,13 +512,19 @@ def compute_permutation_alignment(out, features, ground_truth): ...@@ -397,13 +512,19 @@ def compute_permutation_alignment(out, features, ground_truth):
return best_align, per_asym_residue_index return best_align, per_asym_residue_index
def multi_chain_permutation_align(out, features, ground_truth): def multi_chain_permutation_align(out: Dict[str, torch.Tensor],
"""Compute multi-chain permutation alignment. features: Dict[str, torch.Tensor],
ground_truth: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
"""
Compute multi-chain permutation alignment.
Args: Args:
out: The output of model.forward() out: a dictionary of output tensors from model.forward()
features: Input features features: a dictionary of feature tensors that are used as input for model.forward()
ground_truth: Ground truth features ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure
Returns:
features: a dictionary with updated ground truth feature tensors, ready for downstream loss calculations.
""" """
labels = split_ground_truth_labels(ground_truth) labels = split_ground_truth_labels(ground_truth)
......
import os
import logging
import random
import numpy as np
from pytorch_lightning.utilities.seed import seed_everything
from openfold.utils.suppress_output import SuppressLogging
def seed_globally(seed=None):
    """Seed every RNG process-wide via Lightning's ``seed_everything``.

    When ``PL_GLOBAL_SEED`` is not already present in the environment, it
    is populated with ``seed`` (or a freshly drawn random 32-bit value if
    ``seed`` is None); ``seed_everything`` then reads the seed from that
    environment variable.

    Args:
        seed: optional integer seed; ignored if PL_GLOBAL_SEED is set.
    """
    if "PL_GLOBAL_SEED" not in os.environ:
        chosen = random.randint(0, np.iinfo(np.uint32).max) if seed is None else seed
        os.environ["PL_GLOBAL_SEED"] = str(chosen)
        logging.info(f'os.environ["PL_GLOBAL_SEED"] set to {chosen}')

    # seed_everything logs noisily; silence INFO-level output while seeding
    with SuppressLogging(logging.INFO):
        seed_everything(seed=None)
...@@ -35,8 +35,8 @@ def _superimpose_np(reference, coords): ...@@ -35,8 +35,8 @@ def _superimpose_np(reference, coords):
def _superimpose_single(reference, coords): def _superimpose_single(reference, coords):
reference_np = reference.detach().cpu().numpy() reference_np = reference.detach().to(torch.float).cpu().numpy()
coords_np = coords.detach().cpu().numpy() coords_np = coords.detach().to(torch.float).cpu().numpy()
superimposed, rmsd = _superimpose_np(reference_np, coords_np) superimposed, rmsd = _superimpose_np(reference_np, coords_np)
return coords.new_tensor(superimposed), coords.new_tensor(rmsd) return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
......
import logging
import os
import sys
class SuppressStdout:
    """Context manager that discards everything written to stdout.

    On entry, ``sys.stdout`` is redirected to the platform null device;
    on exit the original stream is restored and the null handle closed.
    """

    def __enter__(self):
        self.stdout = sys.stdout
        # os.devnull instead of a hard-coded "/dev/null" keeps this
        # portable (the null device is "nul" on Windows).
        self._dev_null = open(os.devnull, "w")
        sys.stdout = self._dev_null
        # Return self so callers can use "with SuppressStdout() as s:".
        return self

    def __exit__(self, typ, value, traceback):
        sys.stdout = self.stdout
        self._dev_null.close()
class SuppressLogging:
    """Context manager that mutes logging at or below a given level.

    Entering calls ``logging.disable(level)``; exiting resets the global
    disable threshold back to ``logging.NOTSET``.
    """

    def __init__(self, level):
        # Threshold below which (inclusive) log records are dropped.
        self.level = level

    def __enter__(self):
        logging.disable(self.level)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Re-enable all logging regardless of how the block exited.
        logging.disable(logging.NOTSET)
...@@ -114,8 +114,7 @@ def tree_map(fn, tree, leaf_type): ...@@ -114,8 +114,7 @@ def tree_map(fn, tree, leaf_type):
elif isinstance(tree, leaf_type): elif isinstance(tree, leaf_type):
return fn(tree) return fn(tree)
else: else:
print(type(tree)) raise ValueError(f"Tree of type {type(tree)} not supported")
raise ValueError("Not supported")
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment