Commit 0be2b30b authored by Augustin-Zidek

Add code for AlphaFold-Multimer.

PiperOrigin-RevId: 407076987
parent 1d43aaff
......@@ -7,10 +7,17 @@ v2.0. This is a completely new model that was entered in CASP14 and published in
Nature. For simplicity, we refer to this model as AlphaFold throughout the rest
of this document.
Any publication that discloses findings arising from using this source code or
the model parameters should [cite](#citing-this-work) the
[AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2). Please also refer
to the
We also provide an implementation of AlphaFold-Multimer. This represents a work
in progress and AlphaFold-Multimer isn't expected to be as stable as our monomer
AlphaFold system.
[Read the guide](#updating-existing-alphafold-installation-to-include-alphafold-multimers)
for how to upgrade an existing installation.
Any publication that discloses findings arising from using this source code or the model parameters should [cite](#citing-this-work) the
[AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2) and, if
applicable, the [AlphaFold-Multimer paper](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1).
Please also refer to the
[Supplementary Information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf)
for a detailed description of the method.
......@@ -45,18 +52,25 @@ The following steps are required in order to run AlphaFold:
or take a look at the following
[NVIDIA Docker issue](https://github.com/NVIDIA/nvidia-docker/issues/1447#issuecomment-801479573).
If you wish to run AlphaFold using Singularity (a common containerization
platform on HPC systems), we recommend using one of the third-party Singularity
setups linked in https://github.com/deepmind/alphafold/issues/10 or
https://github.com/deepmind/alphafold/issues/24.
### Genetic databases
This step requires `aria2c` to be installed on your machine.
AlphaFold needs multiple genetic (sequence) databases to run:
* [UniRef90](https://www.uniprot.org/help/uniref),
* [MGnify](https://www.ebi.ac.uk/metagenomics/),
* [BFD](https://bfd.mmseqs.com/),
* [Uniclust30](https://uniclust.mmseqs.com/),
* [MGnify](https://www.ebi.ac.uk/metagenomics/),
* [PDB70](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/),
* [PDB](https://www.rcsb.org/) (structures in the mmCIF format).
* [PDB](https://www.rcsb.org/) (structures in the mmCIF format),
* [PDB seqres](https://www.rcsb.org/) – only for AlphaFold-Multimer,
* [Uniclust30](https://uniclust.mmseqs.com/),
* [UniProt](https://www.uniprot.org/uniprot/) – only for AlphaFold-Multimer,
* [UniRef90](https://www.uniprot.org/help/uniref).
We provide a script `scripts/download_all_data.sh` that can be used to download
and set up all of these databases:
......@@ -76,9 +90,13 @@ and set up all of these databases:
```
will download a reduced version of the databases to be used with the
`reduced_dbs` preset.
`reduced_dbs` database preset.
We don't provide exactly the versions used in CASP14 -- see the [note on
:ledger: **Note: The download directory `<DOWNLOAD_DIR>` should _not_ be a
subdirectory in the AlphaFold repository directory.** If it is, the Docker build
will be slow as the large databases will be copied during the image creation.
We don't provide exactly the database versions used in CASP14 – see the [note on
reproducibility](#note-on-reproducibility). Some of the databases are mirrored
for speed, see [mirrored databases](#mirrored-databases).
......@@ -87,8 +105,8 @@ and the total size when unzipped is 2.2 TB. Please make sure you have a large
enough hard drive space, bandwidth and time to download. We recommend using an
SSD for better genetic search performance.**
This script will also download the model parameter files. Once the script has
finished, you should have the following directory structure:
The `download_all_data.sh` script will also download the model parameter files.
Once the script has finished, you should have the following directory structure:
```
$DOWNLOAD_DIR/ # Total: ~ 2.2 TB (download: 438 GB)
......@@ -99,24 +117,29 @@ $DOWNLOAD_DIR/ # Total: ~ 2.2 TB (download: 438 GB)
params/ # ~ 3.5 GB (download: 3.5 GB)
# 5 CASP14 models,
# 5 pTM models,
# 5 AlphaFold-Multimer models,
# LICENSE,
# = 11 files.
# = 16 files.
pdb70/ # ~ 56 GB (download: 19.5 GB)
# 9 files.
pdb_mmcif/ # ~ 206 GB (download: 46 GB)
mmcif_files/
# About 180,000 .cif files.
obsolete.dat
pdb_seqres/ # ~ 0.2 GB (download: 0.2 GB)
pdb_seqres.txt
small_bfd/ # ~ 17 GB (download: 9.6 GB)
bfd-first_non_consensus_sequences.fasta
uniclust30/ # ~ 86 GB (download: 24.9 GB)
uniclust30_2018_08/
# 13 files.
uniprot/ # ~ 98.3 GB (download: 49 GB)
uniprot.fasta
uniref90/ # ~ 58 GB (download: 29.7 GB)
uniref90.fasta
```
`bfd/` is only downloaded if you download the full databasees, and `small_bfd/`
`bfd/` is only downloaded if you download the full databases, and `small_bfd/`
is only downloaded if you download the reduced databases.
### Model parameters
......@@ -127,7 +150,7 @@ CC BY-NC 4.0 license. Please see the [Disclaimer](#license-and-disclaimer) below
for more detail.
The AlphaFold parameters are available from
https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar, and
https://storage.googleapis.com/alphafold/alphafold_params_2021-10-27.tar, and
are downloaded as part of the `scripts/download_all_data.sh` script. This script
will download parameters for:
......@@ -135,8 +158,46 @@ will download parameters for:
structure prediction quality (see Jumper et al. 2021, Suppl. Methods 1.12
for details).
* 5 pTM models, which were fine-tuned to produce pTM (predicted TM-score) and
predicted aligned error values alongside their structure predictions (see
Jumper et al. 2021, Suppl. Methods 1.9.7 for details).
(PAE) predicted aligned error values alongside their structure predictions
(see Jumper et al. 2021, Suppl. Methods 1.9.7 for details).
* 5 AlphaFold-Multimer models that produce pTM and PAE values alongside their
structure predictions.
### Updating existing AlphaFold installation to include AlphaFold-Multimers
If you have AlphaFold v2.0.0 or v2.0.1, you can either reinstall AlphaFold from
scratch (remove everything and run the full setup again) or do an incremental
update, which is significantly faster but requires a bit more work. Make sure
you follow these steps in the exact order they are listed below (a consolidated
shell sketch follows the list):
1. **Update the code.**
* Go to the directory with the cloned AlphaFold repository and run
`git fetch origin main` to get all code updates.
1. **Download the UniProt and PDB seqres databases.**
* Run `scripts/download_uniprot.sh <DOWNLOAD_DIR>`.
* Remove `<DOWNLOAD_DIR>/pdb_mmcif`. This is needed because the PDB seqres
and PDB mmCIF databases have to be from exactly the same date. Skipping this
step can cause errors when searching for templates when running
AlphaFold-Multimer.
* Run `scripts/download_pdb_mmcif.sh <DOWNLOAD_DIR>`.
* Run `scripts/download_pdb_seqres.sh <DOWNLOAD_DIR>`.
1. **Update the model parameters.**
* Remove the old model parameters in `<DOWNLOAD_DIR>/params`.
* Download new model parameters using
`scripts/download_alphafold_params.sh <DOWNLOAD_DIR>`.
1. **Follow [Running AlphaFold](#running-alphafold).**
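For reference, the steps above can be combined into a single shell sketch. This
is illustrative only (not an official script); the two paths at the top are
placeholders you need to adjust:

```bash
# Illustrative sketch of the incremental update described above.
ALPHAFOLD_REPO=/path/to/alphafold   # your cloned AlphaFold repository
DOWNLOAD_DIR=/path/to/databases     # the <DOWNLOAD_DIR> used during setup

# 1. Update the code.
cd "${ALPHAFOLD_REPO}"
git fetch origin main

# 2. Download UniProt and PDB seqres; re-download PDB mmCIF so that PDB seqres
#    and PDB come from exactly the same date.
scripts/download_uniprot.sh "${DOWNLOAD_DIR}"
rm -rf "${DOWNLOAD_DIR}/pdb_mmcif"
scripts/download_pdb_mmcif.sh "${DOWNLOAD_DIR}"
scripts/download_pdb_seqres.sh "${DOWNLOAD_DIR}"

# 3. Update the model parameters.
rm -rf "${DOWNLOAD_DIR}/params"
scripts/download_alphafold_params.sh "${DOWNLOAD_DIR}"
```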
#### API changes between v2.0.0 and v2.1.0
We tried to keep the API as backwards compatible as possible, but we had to
change the following (an illustrative sketch follows this list):
* The `RunModel.predict()` method now needs a `random_seed` argument, as MSA
sampling happens inside the Multimer model.
* The `preset` flag in `run_alphafold.py` and `run_docker.py` was split into
`db_preset` and `model_preset`.
* The `data_dir` flag must now be set when using `run_docker.py`.
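As an illustration, here is a minimal sketch of the updated call. Everything
except the `random_seed` keyword (model name, helper calls, paths) is an
assumption about the surrounding code, not something introduced by this change:

```python
# Illustrative sketch only; model name and helper calls are assumptions.
from alphafold.model import config, data, model

model_config = config.model_config('model_1_multimer')   # assumed model name
model_params = data.get_model_haiku_params(
    model_name='model_1_multimer', data_dir='/path/to/params')
model_runner = model.RunModel(model_config, model_params)

# `raw_features` is assumed to come from the data pipeline (not shown here).
processed_features = model_runner.process_features(raw_features, random_seed=0)

# v2.0.0: model_runner.predict(processed_features)
# v2.1.0: a random seed is required because MSA sampling now happens inside
# the Multimer model.
prediction = model_runner.predict(processed_features, random_seed=0)
```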
## Running AlphaFold
......@@ -151,8 +212,6 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
git clone https://github.com/deepmind/alphafold.git
```
1. Modify `DOWNLOAD_DIR` in `docker/run_docker.py` to be the path to the
directory containing the downloaded databases.
1. Build the Docker image:
```bash
......@@ -168,14 +227,19 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
pip3 install -r docker/requirements.txt
```
1. Run `run_docker.py` pointing to a FASTA file containing the protein sequence
for which you wish to predict the structure. If you are predicting the
structure of a protein that is already in PDB and you wish to avoid using it
as a template, then `max_template_date` must be set to be before the release
date of the structure. For example, for the T1050 CASP14 target:
1. Run `run_docker.py` pointing to a FASTA file containing the protein
sequence(s) for which you wish to predict the structure. If you are
predicting the structure of a protein that is already in PDB and you wish to
avoid using it as a template, then `max_template_date` must be set to be
before the release date of the structure. You must also provide the path to
the directory containing the downloaded databases. For example, for the
T1050 CASP14 target:
```bash
python3 docker/run_docker.py --fasta_paths=T1050.fasta --max_template_date=2020-05-14
python3 docker/run_docker.py \
--fasta_paths=T1050.fasta \
--max_template_date=2020-05-14 \
--data_dir=$DOWNLOAD_DIR
```
By default, AlphaFold will attempt to use all visible GPU devices. To use a
......@@ -184,33 +248,76 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
[GPU enumeration](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#gpu-enumeration)
for more details.
1. You can control AlphaFold speed / quality tradeoff by adding
`--preset=reduced_dbs`, `--preset=full_dbs` or `--preset=casp14` to the run
command. We provide the following presets:
1. You can control which AlphaFold model to run by adding the
`--model_preset=` flag. We provide the following models:
* **monomer**: This is the original model used at CASP14 with no ensembling.
* **monomer\_casp14**: This is the original model used at CASP14 with
`num_ensemble=8`, matching our CASP14 configuration. This is largely
provided for reproducibility as it is 8x more computationally
expensive for limited accuracy gain (+0.1 average GDT gain on CASP14
domains).
* **monomer\_ptm**: This is the original CASP14 model fine-tuned with the
pTM head, providing a pairwise confidence measure. It is slightly less
accurate than the normal monomer model.
* **multimer**: This is the [AlphaFold-Multimer](#citing-this-work) model.
To use this model, provide a multi-sequence FASTA file. In addition, the
UniProt database should have been downloaded.
* **reduced_dbs**: This preset is optimized for speed and lower hardware
requirements. It runs with a reduced version of the BFD database and
with no ensembling. It requires 8 CPU cores (vCPUs), 8 GB of RAM, and
600 GB of disk space.
* **full_dbs**: The model in this preset is 8 times faster than the
`casp14` preset with a very minor quality drop (-0.1 average GDT drop on
CASP14 domains). It runs with all genetic databases and with no
ensembling.
* **casp14**: This preset uses the same settings as were used in CASP14.
It runs with all genetic databases and with 8 ensemblings.
1. You can control MSA speed/quality tradeoff by adding
`--db_preset=reduced_dbs` or `--db_preset=full_dbs` to the run command. We
provide the following presets:
Running the command above with the `casp14` preset would look like this:
* **reduced\_dbs**: This preset is optimized for speed and lower hardware
requirements. It runs with a reduced version of the BFD database.
It requires 8 CPU cores (vCPUs), 8 GB of RAM, and 600 GB of disk space.
* **full\_dbs**: This runs with all genetic databases used at CASP14.
Running the command above with the `monomer` model preset and the
`reduced_dbs` data preset would look like this:
```bash
python3 docker/run_docker.py --fasta_paths=T1050.fasta --max_template_date=2020-05-14 --preset=casp14
python3 docker/run_docker.py \
--fasta_paths=T1050.fasta \
--max_template_date=2020-05-14 \
--model_preset=monomer \
--db_preset=reduced_dbs \
--data_dir=$DOWNLOAD_DIR
```
### Running AlphaFold-Multimer
All steps are the same as when running the monomer system, but you will have to
* provide an input FASTA file with multiple sequences,
* set `--model_preset=multimer`,
* optionally set the `--is_prokaryote_list` flag with booleans that determine
whether all input sequences in the given FASTA file are prokaryotic. If that
is not the case or the origin is unknown, set it to `false` for that FASTA
file.
An example that folds two protein complexes `multimer1` and `multimer2` where
the first is prokaryotic and the second isn't:
```bash
python3 docker/run_docker.py \
--fasta_paths=multimer1.fasta,multimer2.fasta \
--is_prokaryote_list=true,false \
--max_template_date=2020-05-14 \
--model_preset=multimer \
--data_dir=$DOWNLOAD_DIR
```
### AlphaFold output
The outputs will be in a subfolder of `output_dir` in `run_docker.py`. They
include the computed MSAs, unrelaxed structures, relaxed structures, ranked
structures, raw model outputs, prediction metadata, and section timings. The
`output_dir` directory will have the following structure:
The outputs will be saved in a subdirectory of the directory provided via the
`--output_dir` flag of `run_docker.py` (defaults to `/tmp/alphafold/`). The
outputs include the computed MSAs, unrelaxed structures, relaxed structures,
ranked structures, raw model outputs, prediction metadata, and section timings.
The `--output_dir` directory will have the following structure:
```
<target_name>/
......@@ -299,7 +406,7 @@ develop on top of the `RunModel.predict` method with a parallel system for
precomputing multi-sequence alignments. Alternatively, this script can be run
repeatedly with only moderate overhead.
## Note on reproducibility
## Note on CASP14 reproducibility
AlphaFold's output for a small number of proteins has high inter-run variance,
and may be affected by changes in the input data. The CASP14 target T1064 is a
......@@ -346,6 +453,21 @@ If you use the code or data in this package, please cite:
}
```
In addition, if you use the AlphaFold-Multimer mode, please cite:
```bibtex
@article {AlphaFold-Multimer2021,
author = {Evans, Richard and O{\textquoteright}Neill, Michael and Pritzel, Alexander and Antropova, Natasha and Senior, Andrew and Green, Tim and {\v{Z}}{\'\i}dek, Augustin and Bates, Russ and Blackwell, Sam and Yim, Jason and Ronneberger, Olaf and Bodenstein, Sebastian and Zielinski, Michal and Bridgland, Alex and Potapenko, Anna and Cowie, Andrew and Tunyasuvunakool, Kathryn and Jain, Rishub and Clancy, Ellen and Kohli, Pushmeet and Jumper, John and Hassabis, Demis},
journal = {bioRxiv},
title = {Protein complex prediction with AlphaFold-Multimer},
year = {2021},
elocation-id = {2021.10.04.463034},
doi = {10.1101/2021.10.04.463034},
URL = {https://www.biorxiv.org/content/early/2021/10/04/2021.10.04.463034},
eprint = {https://www.biorxiv.org/content/early/2021/10/04/2021.10.04.463034.full.pdf},
}
```
## Community contributions
Colab notebooks provided by the community (please note that these notebooks may
......@@ -378,6 +500,7 @@ and packages:
* [NumPy](https://numpy.org)
* [OpenMM](https://github.com/openmm/openmm)
* [OpenStructure](https://openstructure.org)
* [pandas](https://pandas.pydata.org/)
* [pymol3d](https://github.com/avirshup/py3dmol)
* [SciPy](https://scipy.org)
* [Sonnet](https://github.com/deepmind/sonnet)
......@@ -111,8 +111,10 @@ def compute_predicted_aligned_error(
def predicted_tm_score(
logits: np.ndarray,
breaks: np.ndarray,
residue_weights: Optional[np.ndarray] = None) -> np.ndarray:
"""Computes predicted TM alignment score.
residue_weights: Optional[np.ndarray] = None,
asym_id: Optional[np.ndarray] = None,
interface: bool = False) -> np.ndarray:
"""Computes predicted TM alignment or predicted interface TM alignment score.
Args:
logits: [num_res, num_res, num_bins] the logits output from
......@@ -120,9 +122,12 @@ def predicted_tm_score(
breaks: [num_bins] the error bins.
residue_weights: [num_res] the per residue weights to use for the
expectation.
asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
ipTM calculation, i.e. when interface=True.
interface: If True, interface predicted TM score is computed.
Returns:
ptm_score: the predicted TM alignment score.
ptm_score: The predicted TM alignment or the predicted iTM score.
"""
# residue_weights has to be in [0, 1], but can be floating-point, i.e. the
......@@ -132,24 +137,32 @@ def predicted_tm_score(
bin_centers = _calculate_bin_centers(breaks)
num_res = np.sum(residue_weights)
num_res = int(np.sum(residue_weights))
# Clip num_res to avoid negative/undefined d0.
clipped_num_res = max(num_res, 19)
# Compute d_0(num_res) as defined by TM-score, eqn. (5) in
# http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
# Yang & Skolnick "Scoring function for automated
# assessment of protein structure template quality" 2004
# Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
# "Scoring function for automated assessment of protein structure template
# quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8
# Convert logits to probs
# Convert logits to probs.
probs = scipy.special.softmax(logits, axis=-1)
# TM-Score term for every bin
# TM-Score term for every bin.
tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0))
# E_distances tm(distance)
# E_distances tm(distance).
predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1)
normed_residue_mask = residue_weights / (1e-8 + residue_weights.sum())
pair_mask = np.ones(shape=(num_res, num_res), dtype=bool)
if interface:
pair_mask *= asym_id[:, None] != asym_id[None, :]
predicted_tm_term *= pair_mask
pair_residue_weights = pair_mask * (
residue_weights[None, :] * residue_weights[:, None])
normed_residue_mask = pair_residue_weights / (1e-8 + np.sum(
pair_residue_weights, axis=-1, keepdims=True))
per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1)
return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()])
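# Usage sketch (editor's illustration, not part of this commit): computing
# ipTM for a toy two-chain example. Shapes follow the docstring above; the
# values are synthetic and for demonstration only.
#
#   num_bins = 64
#   breaks = np.linspace(0., 31.5, num_bins - 1)
#   logits = np.zeros([10, 10, num_bins])
#   asym_id = np.array([0] * 5 + [1] * 5)
#   iptm = predicted_tm_score(logits, breaks, asym_id=asym_id, interface=True)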
......@@ -23,6 +23,10 @@ import numpy as np
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
# Complete sequence of chain IDs supported by the PDB format.
PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62.
@dataclasses.dataclass(frozen=True)
class Protein:
......@@ -43,11 +47,21 @@ class Protein:
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# 0-indexed number corresponding to the chain in the protein that this residue
# belongs to.
chain_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
def __post_init__(self):
if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
raise ValueError(
f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
'because these cannot be written to PDB format.')
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.
......@@ -57,9 +71,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
Args:
pdb_str: The contents of the pdb file
chain_id: If None, then the pdb file must contain a single chain (which
will be parsed). If chain_id is specified (e.g. A), then only that chain
is parsed.
chain_id: If chain_id is specified (e.g. A), then only that chain
is parsed. Otherwise all chains are parsed.
Returns:
A new `Protein` parsed from the pdb contents.
......@@ -73,57 +86,63 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
f'Only single model PDBs are supported. Found {len(models)} models.')
model = models[0]
if chain_id is not None:
chain = model[chain_id]
else:
chains = list(model.get_chains())
if len(chains) != 1:
raise ValueError(
'Only single chain PDBs are supported when chain_id not specified. '
f'Found {len(chains)} chains.')
else:
chain = chains[0]
atom_positions = []
aatype = []
atom_mask = []
residue_index = []
chain_ids = []
b_factors = []
for res in chain:
if res.id[2] != ' ':
raise ValueError(
f'PDB contains an insertion code at chain {chain.id} and residue '
f'index {res.id[1]}. These are not supported.')
res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.
res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
for chain in model:
if chain_id is not None and chain.id != chain_id:
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
b_factors.append(res_b_factors)
for res in chain:
if res.id[2] != ' ':
raise ValueError(
f'PDB contains an insertion code at chain {chain.id} and residue '
f'index {res.id[1]}. These are not supported.')
res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.
res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
chain_ids.append(chain.id)
b_factors.append(res_b_factors)
# Chain IDs are usually characters so map these to ints.
unique_chain_ids = np.unique(chain_ids)
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
chain_index=chain_index,
b_factors=np.array(b_factors))
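# Usage sketch (editor's illustration, not part of this commit): parsing all
# chains of a PDB file into a single Protein. '2rbg.pdb' is the two-chain test
# file used in protein_test.py.
#
#   with open('2rbg.pdb') as f:
#     prot = from_pdb_string(f.read())  # chain_id=None parses every chain.
#   num_chains = len(np.unique(prot.chain_index))  # == 2 for 2rbg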
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
chain_end = 'TER'
return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
f'{chain_name:>1}{residue_index:>4}')
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
......@@ -143,16 +162,33 @@ def to_pdb(prot: Protein) -> str:
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
chain_index = prot.chain_index.astype(np.int32)
b_factors = prot.b_factors
if np.any(aatype > residue_constants.restype_num):
raise ValueError('Invalid aatypes.')
# Construct a mapping from chain integer indices to chain ID strings.
chain_ids = {}
for i in np.unique(chain_index): # np.unique gives sorted output.
if i >= PDB_MAX_CHAINS:
raise ValueError(
f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
chain_ids[i] = PDB_CHAIN_IDS[i]
pdb_lines.append('MODEL 1')
atom_index = 1
chain_id = 'A'
last_chain_index = chain_index[0]
# Add all atom sites.
for i in range(aatype.shape[0]):
# Close the previous chain if in a multichain PDB.
if last_chain_index != chain_index[i]:
pdb_lines.append(_chain_end(
atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]],
residue_index[i - 1]))
last_chain_index = chain_index[i]
atom_index += 1 # Atom index increases at the TER symbol.
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
......@@ -168,7 +204,7 @@ def to_pdb(prot: Protein) -> str:
charge = ''
# PDB is a columnar format, every space matters here!
atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
f'{res_name_3:>3} {chain_id:>1}'
f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
f'{residue_index[i]:>4}{insertion_code:>1} '
f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
f'{occupancy:>6.2f}{b_factor:>6.2f} '
......@@ -176,17 +212,15 @@ def to_pdb(prot: Protein) -> str:
pdb_lines.append(atom_line)
atom_index += 1
# Close the chain.
chain_end = 'TER'
chain_termination_line = (
f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} '
f'{chain_id:>1}{residue_index[-1]:>4}')
pdb_lines.append(chain_termination_line)
# Close the final chain.
pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
chain_ids[chain_index[-1]], residue_index[-1]))
pdb_lines.append('ENDMDL')
pdb_lines.append('END')
pdb_lines.append('')
return '\n'.join(pdb_lines)
# Pad all lines to 80 characters.
pdb_lines = [line.ljust(80) for line in pdb_lines]
return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
def ideal_atom_mask(prot: Protein) -> np.ndarray:
......@@ -205,25 +239,40 @@ def ideal_atom_mask(prot: Protein) -> np.ndarray:
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(features: FeatureDict, result: ModelOutput,
b_factors: Optional[np.ndarray] = None) -> Protein:
def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
remove_leading_feature_dimension: bool = True) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values.
Returns:
A protein instance.
"""
fold_output = result['structure_module']
def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
return arr[0] if remove_leading_feature_dimension else arr
if 'asym_id' in features:
chain_index = _maybe_remove_leading_dim(features['asym_id'])
else:
chain_index = np.zeros_like(_maybe_remove_leading_dim(features['aatype']))
if b_factors is None:
b_factors = np.zeros_like(fold_output['final_atom_mask'])
return Protein(
aatype=features['aatype'][0],
aatype=_maybe_remove_leading_dim(features['aatype']),
atom_positions=fold_output['final_atom_positions'],
atom_mask=fold_output['final_atom_mask'],
residue_index=features['residue_index'][0] + 1,
residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1,
chain_index=chain_index,
b_factors=b_factors)
......@@ -35,11 +35,17 @@ class ProteinTest(parameterized.TestCase):
self.assertEqual((num_res,), prot.aatype.shape)
self.assertEqual((num_res, num_atoms), prot.atom_mask.shape)
self.assertEqual((num_res,), prot.residue_index.shape)
self.assertEqual((num_res,), prot.chain_index.shape)
self.assertEqual((num_res, num_atoms), prot.b_factors.shape)
@parameterized.parameters(('2rbg.pdb', 'A', 282),
('2rbg.pdb', 'B', 282))
def test_from_pdb_str(self, pdb_file, chain_id, num_res):
@parameterized.named_parameters(
dict(testcase_name='chain_A',
pdb_file='2rbg.pdb', chain_id='A', num_res=282, num_chains=1),
dict(testcase_name='chain_B',
pdb_file='2rbg.pdb', chain_id='B', num_res=282, num_chains=1),
dict(testcase_name='multichain',
pdb_file='2rbg.pdb', chain_id=None, num_res=564, num_chains=2))
def test_from_pdb_str(self, pdb_file, chain_id, num_res, num_chains):
pdb_file = os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
pdb_file)
with open(pdb_file) as f:
......@@ -49,14 +55,19 @@ class ProteinTest(parameterized.TestCase):
self.assertGreaterEqual(prot.aatype.min(), 0)
# Allow equal since unknown restypes have index equal to restype_num.
self.assertLessEqual(prot.aatype.max(), residue_constants.restype_num)
self.assertLen(np.unique(prot.chain_index), num_chains)
def test_to_pdb(self):
with open(
os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
'2rbg.pdb')) as f:
pdb_string = f.read()
prot = protein.from_pdb_string(pdb_string, chain_id='A')
prot = protein.from_pdb_string(pdb_string)
pdb_string_reconstr = protein.to_pdb(prot)
for line in pdb_string_reconstr.splitlines():
self.assertLen(line, 80)
prot_reconstr = protein.from_pdb_string(pdb_string_reconstr)
np.testing.assert_array_equal(prot_reconstr.aatype, prot.aatype)
......@@ -66,6 +77,8 @@ class ProteinTest(parameterized.TestCase):
prot_reconstr.atom_mask, prot.atom_mask)
np.testing.assert_array_equal(
prot_reconstr.residue_index, prot.residue_index)
np.testing.assert_array_equal(
prot_reconstr.chain_index, prot.chain_index)
np.testing.assert_array_almost_equal(
prot_reconstr.b_factors, prot.b_factors)
......@@ -74,9 +87,9 @@ class ProteinTest(parameterized.TestCase):
os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
'2rbg.pdb')) as f:
pdb_string = f.read()
prot = protein.from_pdb_string(pdb_string, chain_id='A')
prot = protein.from_pdb_string(pdb_string)
ideal_mask = protein.ideal_atom_mask(prot)
non_ideal_residues = set([102] + list(range(127, 285)))
non_ideal_residues = set([102] + list(range(127, 286)))
for i, (res, atom_mask) in enumerate(
zip(prot.residue_index, prot.atom_mask)):
if res in non_ideal_residues:
......@@ -84,6 +97,18 @@ class ProteinTest(parameterized.TestCase):
else:
self.assertTrue(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
def test_too_many_chains(self):
num_res = protein.PDB_MAX_CHAINS + 1
num_atom_type = residue_constants.atom_type_num
with self.assertRaises(ValueError):
_ = protein.Protein(
atom_positions=np.random.random([num_res, num_atom_type, 3]),
aatype=np.random.randint(0, 21, [num_res]),
atom_mask=np.random.randint(0, 2, [num_res]).astype(np.float32),
residue_index=np.arange(1, num_res+1),
chain_index=np.arange(num_res),
b_factors=np.random.uniform(1, 100, [num_res]))
if __name__ == '__main__':
absltest.main()
......@@ -16,6 +16,7 @@
import collections
import functools
import os
from typing import List, Mapping, Tuple
import numpy as np
......@@ -398,12 +399,13 @@ def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]],
("residue_virtual_bonds").
Returns:
residue_bonds: dict that maps resname --> list of Bond tuples
residue_virtual_bonds: dict that maps resname --> list of Bond tuples
residue_bond_angles: dict that maps resname --> list of BondAngle tuples
residue_bonds: Dict that maps resname -> list of Bond tuples.
residue_virtual_bonds: Dict that maps resname -> list of Bond tuples.
residue_bond_angles: Dict that maps resname -> list of BondAngle tuples.
"""
stereo_chemical_props_path = (
'alphafold/common/stereo_chemical_props.txt')
stereo_chemical_props_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt'
)
with open(stereo_chemical_props_path, 'rt') as f:
stereo_chemical_props = f.read()
lines_iter = iter(stereo_chemical_props.splitlines())
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data pipeline."""
from typing import Iterable, MutableMapping, List
from alphafold.common import residue_constants
from alphafold.data import msa_pairing
from alphafold.data import pipeline
import numpy as np
REQUIRED_FEATURES = frozenset({
'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments',
'num_templates', 'queue_size', 'residue_index', 'resolution',
'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
'template_all_atom_mask', 'template_all_atom_positions'
})
MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048
def _is_homomer_or_monomer(chains: Iterable[pipeline.FeatureDict]) -> bool:
"""Checks if a list of chains represents a homomer/monomer example."""
# Note that an entity_id of 0 indicates padding.
num_unique_chains = len(np.unique(np.concatenate(
[np.unique(chain['entity_id'][chain['entity_id'] > 0]) for
chain in chains])))
return num_unique_chains == 1
def pair_and_merge(
all_chain_features: MutableMapping[str, pipeline.FeatureDict],
is_prokaryote: bool) -> pipeline.FeatureDict:
"""Runs processing on features to augment, pair and merge.
Args:
all_chain_features: A MutableMap of dictionaries of features for each chain.
is_prokaryote: Whether the target complex is from a prokaryotic or
eukaryotic organism.
Returns:
A dictionary of features.
"""
process_unmerged_features(all_chain_features)
np_chains_list = list(all_chain_features.values())
pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)
if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features(
chains=np_chains_list, prokaryotic=is_prokaryote)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
np_chains_list = crop_chains(
np_chains_list,
msa_crop_size=MSA_CROP_SIZE,
pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES)
np_example = msa_pairing.merge_chain_features(
np_chains_list=np_chains_list, pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES)
np_example = process_final(np_example)
return np_example
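# Usage sketch (editor's illustration, not part of this commit):
# `all_chain_features` is assumed to map chain ids (e.g. 'A', 'B') to per-chain
# FeatureDicts produced by the multimer data pipeline.
#
#   np_example = pair_and_merge(all_chain_features, is_prokaryote=False)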
def crop_chains(
chains_list: List[pipeline.FeatureDict],
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int) -> List[pipeline.FeatureDict]:
"""Crops the MSAs for a set of chains.
Args:
chains_list: A list of chains to be cropped.
msa_crop_size: The total number of sequences to crop from the MSA.
pair_msa_sequences: Whether we are operating in sequence-pairing mode.
max_templates: The maximum templates to use per chain.
Returns:
The chains cropped.
"""
# Apply the cropping.
cropped_chains = []
for chain in chains_list:
cropped_chain = _crop_single_chain(
chain,
msa_crop_size=msa_crop_size,
pair_msa_sequences=pair_msa_sequences,
max_templates=max_templates)
cropped_chains.append(cropped_chain)
return cropped_chains
def _crop_single_chain(chain: pipeline.FeatureDict,
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int) -> pipeline.FeatureDict:
"""Crops msa sequences to `msa_crop_size`."""
msa_size = chain['num_alignments']
if pair_msa_sequences:
msa_size_all_seq = chain['num_alignments_all_seq']
msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)
# We reduce the number of un-paired sequences, by the number of times a
# sequence from this chain's MSA is included in the paired MSA. This keeps
# the MSA size for each chain roughly constant.
msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
num_non_gapped_pairs = np.sum(
np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1))
num_non_gapped_pairs = np.minimum(num_non_gapped_pairs,
msa_crop_size_all_seq)
# Restrict the unpaired crop size so that paired+unpaired sequences do not
# exceed msa_seqs_per_chain for each chain.
max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
else:
msa_crop_size = np.minimum(msa_size, msa_crop_size)
include_templates = 'template_aatype' in chain and max_templates
if include_templates:
num_templates = chain['template_aatype'].shape[0]
templates_crop_size = np.minimum(num_templates, max_templates)
for k in chain:
k_split = k.split('_all_seq')[0]
if k_split in msa_pairing.TEMPLATE_FEATURES:
chain[k] = chain[k][:templates_crop_size, :]
elif k_split in msa_pairing.MSA_FEATURES:
if '_all_seq' in k and pair_msa_sequences:
chain[k] = chain[k][:msa_crop_size_all_seq, :]
else:
chain[k] = chain[k][:msa_crop_size, :]
chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
if include_templates:
chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
if pair_msa_sequences:
chain['num_alignments_all_seq'] = np.asarray(
msa_crop_size_all_seq, dtype=np.int32)
return chain
def process_final(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict:
"""Final processing steps in data pipeline, after merging and pairing."""
np_example = _correct_msa_restypes(np_example)
np_example = _make_seq_mask(np_example)
np_example = _make_msa_mask(np_example)
np_example = _filter_features(np_example)
return np_example
def _correct_msa_restypes(np_example):
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0)
np_example['msa'] = np_example['msa'].astype(np.int32)
return np_example
def _make_seq_mask(np_example):
np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
return np_example
def _make_msa_mask(np_example):
"""Mask features are all ones, but will later be zero-padded."""
np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)
seq_mask = (np_example['entity_id'] > 0).astype(np.float32)
np_example['msa_mask'] *= seq_mask[None]
return np_example
def _filter_features(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict:
"""Filters features of example to only those requested."""
return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}
def process_unmerged_features(
all_chain_features: MutableMapping[str, pipeline.FeatureDict]):
"""Postprocessing stage for per-chain features before merging."""
num_chains = len(all_chain_features)
for chain_features in all_chain_features.values():
# Convert deletion matrices to float.
chain_features['deletion_matrix'] = np.asarray(
chain_features.pop('deletion_matrix_int'), dtype=np.float32)
if 'deletion_matrix_int_all_seq' in chain_features:
chain_features['deletion_matrix_all_seq'] = np.asarray(
chain_features.pop('deletion_matrix_int_all_seq'), dtype=np.float32)
chain_features['deletion_mean'] = np.mean(
chain_features['deletion_matrix'], axis=0)
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['aatype']]
chain_features['all_atom_mask'] = all_atom_mask
chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])
# Add assembly_num_chains.
chain_features['assembly_num_chains'] = np.asarray(num_chains)
# Add entity_mask.
for chain_features in all_chain_features.values():
chain_features['entity_mask'] = (
chain_features['entity_id'] != 0).astype(np.int32)
......@@ -15,6 +15,7 @@
"""Parses the mmCIF file format."""
import collections
import dataclasses
import functools
import io
from typing import Any, Mapping, Optional, Sequence, Tuple
......@@ -160,6 +161,7 @@ def mmcif_loop_to_dict(prefix: str,
return {entry[index]: entry for entry in entries}
@functools.lru_cache(16, typed=False)
def parse(*,
file_id: str,
mmcif_string: str,
......@@ -314,7 +316,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
raw_resolution = parsed_info[res_key][0]
header['resolution'] = float(raw_resolution)
except ValueError:
logging.warning('Invalid resolution format: %s', parsed_info[res_key])
logging.debug('Invalid resolution format: %s', parsed_info[res_key])
return header
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for extracting identifiers from MSA sequence descriptions."""
import dataclasses
import re
from typing import Optional
# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
_UNIPROT_PATTERN = re.compile(
r"""
^
# UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
(?:tr|sp)
\|
# A primary accession number of the UniProtKB entry.
(?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
# Occasionally there is a _0 or _1 isoform suffix, which we ignore.
(?:_\d)?
\|
# TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
# protein ID code.
(?:[A-Za-z0-9]+)
_
# A mnemonic species identification code.
(?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
# Small BFD uses a final value after an underscore, which we ignore.
(?:_\d+)?
$
""",
re.VERBOSE)
@dataclasses.dataclass(frozen=True)
class Identifiers:
uniprot_accession_id: str = ''
species_id: str = ''
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
"""Gets accession id and species from an msa sequence identifier.
The sequence identifier has the format specified by _UNIPROT_PATTERN.
An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`
Args:
msa_sequence_identifier: a sequence identifier.
Returns:
An `Identifiers` instance with a uniprot_accession_id and species_id. These
can be empty in the case where no identifier was found.
"""
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches:
return Identifiers(
uniprot_accession_id=matches.group('AccessionIdentifier'),
species_id=matches.group('SpeciesIdentifier'))
return Identifiers()
def _extract_sequence_identifier(description: str) -> Optional[str]:
"""Extracts sequence identifier from description. Returns None if no match."""
split_description = description.split()
if split_description:
return split_description[0].partition('/')[0]
else:
return None
def get_identifiers(description: str) -> Identifiers:
"""Computes extra MSA features from the description."""
sequence_identifier = _extract_sequence_identifier(description)
if sequence_identifier is None:
return Identifiers()
else:
return _parse_sequence_identifier(sequence_identifier)
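# Usage sketch (editor's illustration, not part of this commit), reusing the
# UniProtKB example from the comment above:
#
#   ids = get_identifiers('tr|A0A146SKV9|A0A146SKV9_FUNHE')
#   # ids.uniprot_accession_id == 'A0A146SKV9', ids.species_id == 'FUNHE'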
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pairing logic for multimer data pipeline."""
import collections
import functools
import re
import string
from typing import Any, Dict, Iterable, List, Sequence
from alphafold.common import residue_constants
from alphafold.data import pipeline
import numpy as np
import pandas as pd
import scipy.linalg
ALPHA_ACCESSION_ID_MAP = {x: y for y, x in enumerate(string.ascii_uppercase)}
ALPHANUM_ACCESSION_ID_MAP = {
chr: num for num, chr in enumerate(string.ascii_uppercase + string.digits)
} # A-Z,0-9
NUM_ACCESSION_ID_MAP = {str(x): x for x in range(10)} # 0-9
MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9
MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
'msa_mask_all_seq': 1,
'deletion_matrix_all_seq': 0,
'deletion_matrix_int_all_seq': 0,
'msa': MSA_GAP_IDX,
'msa_mask': 1,
'deletion_matrix': 0,
'deletion_matrix_int': 0}
MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
'all_atom_mask', 'seq_mask', 'between_segment_residues',
'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
'sym_id', 'entity_mask', 'deletion_mean',
'prediction_atom_mask',
'literature_positions', 'atom_indices_to_group_indices',
'rigid_group_default_frame')
TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
'template_all_atom_mask')
CHAIN_FEATURES = ('num_alignments', 'seq_length')
domain_name_pattern = re.compile(
r'''^(?P<pdb>[a-z\d]{4})
\{(?P<bioassembly>[\d+(\+\d+)?])\}
(?P<chain>[a-zA-Z\d]+)
\{(?P<transform_index>\d+)\}$
''', re.VERBOSE)
def create_paired_features(
chains: Iterable[pipeline.FeatureDict],
prokaryotic: bool,
) -> List[pipeline.FeatureDict]:
"""Returns the original chains with paired NUM_SEQ features.
Args:
chains: A list of feature dictionaries for each chain.
prokaryotic: Whether the target complex is from a prokaryotic organism.
Used to determine the distance metric for pairing.
Returns:
A list of feature dictionaries with sequence features including only
rows to be paired.
"""
chains = list(chains)
chain_keys = chains[0].keys()
if len(chains) < 2:
return chains
else:
updated_chains = []
paired_chains_to_paired_row_indices = pair_sequences(
chains, prokaryotic)
paired_rows = reorder_paired_rows(
paired_chains_to_paired_row_indices)
for chain_num, chain in enumerate(chains):
new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
for feature_name in chain_keys:
if feature_name.endswith('_all_seq'):
feats_padded = pad_features(chain[feature_name], feature_name)
new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
new_chain['num_alignments_all_seq'] = np.asarray(
len(paired_rows[:, chain_num]))
updated_chains.append(new_chain)
return updated_chains
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
"""Add a 'padding' row at the end of the features list.
The padding row will be selected as a 'paired' row in the case of partial
alignment - for the chain that doesn't have paired alignment.
Args:
feature: The feature to be padded.
feature_name: The name of the feature to be padded.
Returns:
The feature with an additional padding row.
"""
assert feature.dtype != np.dtype(np.string_)
if feature_name in ('msa_all_seq', 'msa_mask_all_seq',
'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'):
num_res = feature.shape[1]
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
feature.dtype)
elif feature_name in ('msa_uniprot_accession_identifiers_all_seq',
'msa_species_identifiers_all_seq'):
padding = [b'']
else:
return feature
feats_padded = np.concatenate([feature, padding], axis=0)
return feats_padded
def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame:
"""Makes dataframe with msa features needed for msa pairing."""
chain_msa = chain_features['msa_all_seq']
query_seq = chain_msa[0]
per_seq_similarity = np.sum(
query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
msa_df = pd.DataFrame({
'msa_species_identifiers':
chain_features['msa_species_identifiers_all_seq'],
'msa_uniprot_accession_identifiers':
chain_features['msa_uniprot_accession_identifiers_all_seq'],
'msa_row':
np.arange(len(
chain_features['msa_uniprot_accession_identifiers_all_seq'])),
'msa_similarity': per_seq_similarity,
'gap': per_seq_gap
})
return msa_df
def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
"""Creates mapping from species to msa dataframe of that species."""
species_lookup = {}
for species, species_df in msa_df.groupby('msa_species_identifiers'):
species_lookup[species] = species_df
return species_lookup
@functools.lru_cache(maxsize=65536)
def encode_accession(accession_id: str) -> int:
"""Map accession codes to the serial order in which they were assigned."""
alpha = ALPHA_ACCESSION_ID_MAP # A-Z
alphanum = ALPHANUM_ACCESSION_ID_MAP # A-Z,0-9
num = NUM_ACCESSION_ID_MAP # 0-9
coding = 0
# This is based on the uniprot accession id format
# https://www.uniprot.org/help/accession_numbers
if accession_id[0] in {'O', 'P', 'Q'}:
bases = (alpha, num, alphanum, alphanum, alphanum, num)
elif len(accession_id) == 6:
bases = (alpha, num, alpha, alphanum, alphanum, num)
elif len(accession_id) == 10:
bases = (alpha, num, alpha, alphanum, alphanum, num, alpha, alphanum,
alphanum, num)
product = 1
for place, base in zip(reversed(accession_id), reversed(bases)):
coding += base[place] * product
product *= len(base)
return coding
def _calc_id_diff(id_a: bytes, id_b: bytes) -> int:
return abs(encode_accession(id_a.decode()) - encode_accession(id_b.decode()))
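# Usage sketch (editor's illustration, not part of this commit): accession ids
# that are adjacent in assignment order encode to consecutive integers, so the
# absolute difference is a cheap proxy for how close two entries are.
#
#   _calc_id_diff(b'Q12345', b'Q12346')  # == 1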
def _find_all_accession_matches(accession_id_lists: List[List[bytes]],
diff_cutoff: int = 20
) -> List[List[Any]]:
"""Finds accession id matches across the chains based on their difference."""
all_accession_tuples = []
current_tuple = []
tokens_used_in_answer = set()
def _matches_all_in_current_tuple(inp: bytes, diff_cutoff: int) -> bool:
return all((_calc_id_diff(s, inp) < diff_cutoff for s in current_tuple))
def _all_tokens_not_used_before() -> bool:
return all((s not in tokens_used_in_answer for s in current_tuple))
def dfs(level, accession_id, diff_cutoff=diff_cutoff) -> None:
if level == len(accession_id_lists) - 1:
if _all_tokens_not_used_before():
all_accession_tuples.append(list(current_tuple))
for s in current_tuple:
tokens_used_in_answer.add(s)
return
if level == -1:
new_list = accession_id_lists[level+1]
else:
new_list = [(_calc_id_diff(accession_id, s), s) for
s in accession_id_lists[level+1]]
new_list = sorted(new_list)
new_list = [s for d, s in new_list]
for s in new_list:
if (_matches_all_in_current_tuple(s, diff_cutoff) and
s not in tokens_used_in_answer):
current_tuple.append(s)
dfs(level + 1, s)
current_tuple.pop()
dfs(-1, '')
return all_accession_tuples
def _accession_row(msa_df: pd.DataFrame, accession_id: bytes) -> pd.Series:
matched_df = msa_df[msa_df.msa_uniprot_accession_identifiers == accession_id]
return matched_df.iloc[0]
def _match_rows_by_genetic_distance(
this_species_msa_dfs: List[pd.DataFrame],
cutoff: int = 20) -> List[List[int]]:
"""Finds MSA sequence pairings across chains within a genetic distance cutoff.
The genetic distance between two sequences is approximated by taking the
difference in their UniProt accession ids.
Args:
this_species_msa_dfs: a list of dataframes containing MSA features for
sequences for a specific species. If species is missing for a chain, the
dataframe is set to None.
cutoff: the genetic distance cutoff.
Returns:
A list of lists, each containing M indices corresponding to paired MSA rows,
where M is the number of chains.
"""
num_examples = len(this_species_msa_dfs) # N
accession_id_lists = [] # M
match_index_to_chain_index = {}
for chain_index, species_df in enumerate(this_species_msa_dfs):
if species_df is not None:
accession_id_lists.append(
list(species_df.msa_uniprot_accession_identifiers.values))
# Keep track of which of the this_species_msa_dfs are not None.
match_index_to_chain_index[len(accession_id_lists) - 1] = chain_index
all_accession_id_matches = _find_all_accession_matches(
accession_id_lists, cutoff) # [k, M]
all_paired_msa_rows = [] # [k, N]
for accession_id_match in all_accession_id_matches:
paired_msa_rows = []
for match_index, accession_id in enumerate(accession_id_match):
# Map back to chain index.
chain_index = match_index_to_chain_index[match_index]
seq_series = _accession_row(
this_species_msa_dfs[chain_index], accession_id)
if (seq_series.msa_similarity > SEQUENCE_SIMILARITY_CUTOFF or
seq_series.gap > SEQUENCE_GAP_CUTOFF):
continue
else:
paired_msa_rows.append(seq_series.msa_row)
# If a sequence is skipped based on sequence similarity to the respective
# target sequence or a gap cutoff, the lengths of accession_id_match and
# paired_msa_rows will be different. Skip this match.
if len(paired_msa_rows) == len(accession_id_match):
paired_and_non_paired_msa_rows = np.array([-1] * num_examples)
matched_chain_indices = list(match_index_to_chain_index.values())
paired_and_non_paired_msa_rows[matched_chain_indices] = paired_msa_rows
all_paired_msa_rows.append(list(paired_and_non_paired_msa_rows))
return all_paired_msa_rows
def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
) -> List[List[int]]:
"""Finds MSA sequence pairings across chains based on sequence similarity.
Each chain's MSA sequences are first sorted by their sequence similarity to
their respective target sequence. The sequences are then paired, starting
from the sequences most similar to their target sequence.
Args:
this_species_msa_dfs: a list of dataframes containing MSA features for
sequences for a specific species.
Returns:
A list of lists, each containing M indices corresponding to paired MSA rows,
where M is the number of chains.
"""
all_paired_msa_rows = []
num_seqs = [len(species_df) for species_df in this_species_msa_dfs
if species_df is not None]
take_num_seqs = np.min(num_seqs)
sort_by_similarity = (
lambda x: x.sort_values('msa_similarity', axis=0, ascending=False))
for species_df in this_species_msa_dfs:
if species_df is not None:
species_df_sorted = sort_by_similarity(species_df)
msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
else:
msa_rows = [-1] * take_num_seqs # take the last 'padding' row
all_paired_msa_rows.append(msa_rows)
all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
return all_paired_msa_rows
def pair_sequences(examples: List[pipeline.FeatureDict],
prokaryotic: bool) -> Dict[int, np.ndarray]:
"""Returns indices for paired MSA sequences across chains."""
num_examples = len(examples)
all_chain_species_dict = []
common_species = set()
for chain_features in examples:
msa_df = _make_msa_df(chain_features)
species_dict = _create_species_dict(msa_df)
all_chain_species_dict.append(species_dict)
common_species.update(set(species_dict))
common_species = sorted(common_species)
common_species.remove(b'') # Remove target sequence species.
all_paired_msa_rows = [np.zeros(len(examples), int)]
all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]
for species in common_species:
if not species:
continue
this_species_msa_dfs = []
species_dfs_present = 0
for species_dict in all_chain_species_dict:
if species in species_dict:
this_species_msa_dfs.append(species_dict[species])
species_dfs_present += 1
else:
this_species_msa_dfs.append(None)
# Skip species that are present in only one chain.
if species_dfs_present <= 1:
continue
if np.any(
np.array([len(species_df) for species_df in
this_species_msa_dfs if
isinstance(species_df, pd.DataFrame)]) > 600):
continue
# In prokaryotes (and some eukaryotes), interacting genes are often
# co-located on the chromosome into operons. Because of that we can assume
# that if two proteins' intergenic distance is less than a threshold, the
# two proteins will form an interacting pair.
# In most eukaryotes, a single protein's MSA can contain many paralogs.
# Two genes may interact even if they are not close by genomic distance.
# In the case of eukaryotes, some methods pair MSA sequences using a
# sequence similarity method.
# See Jinbo Xu's work:
# https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6030867/#B28.
if prokaryotic:
paired_msa_rows = _match_rows_by_genetic_distance(this_species_msa_dfs)
if not paired_msa_rows:
continue
else:
paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
all_paired_msa_rows.extend(paired_msa_rows)
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
all_paired_msa_rows_dict = {
num_examples: np.array(paired_msa_rows) for
num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
}
return all_paired_msa_rows_dict
def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]
) -> np.ndarray:
"""Creates a list of indices of paired MSA rows across chains.
Args:
all_paired_msa_rows_dict: a mapping from the number of paired chains to the
paired indices.
Returns:
a list of lists, each containing indices of paired MSA rows across chains.
The paired-index lists are ordered by:
1) the number of chains in the paired alignment, i.e., all-chain pairings
will come first.
2) e-values
"""
all_paired_msa_rows = []
for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
paired_rows = all_paired_msa_rows_dict[num_pairings]
paired_rows_product = abs(np.array([np.prod(rows) for rows in paired_rows]))
paired_rows_sort_index = np.argsort(paired_rows_product)
all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])
return np.array(all_paired_msa_rows)
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
"""Like scipy.linalg.block_diag but with an optional padding value."""
ones_arrs = [np.ones_like(x) for x in arrs]
off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs)
diag = scipy.linalg.block_diag(*arrs)
diag += (off_diag_mask * pad_value).astype(diag.dtype)
return diag
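# Minimal illustrative sketch (hypothetical toy arrays): how block_diag above
# differs from scipy.linalg.block_diag when a non-zero pad value is used for
# the off-diagonal blocks.
def _example_block_diag():
  a = np.array([[1, 2], [3, 4]])
  b = np.array([[5]])
  padded = block_diag(a, b, pad_value=21)
  # padded == [[ 1,  2, 21],
  #            [ 3,  4, 21],
  #            [21, 21,  5]]
  return padded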
def _correct_post_merged_feats(
np_example: pipeline.FeatureDict,
np_chains_list: Sequence[pipeline.FeatureDict],
pair_msa_sequences: bool) -> pipeline.FeatureDict:
"""Adds features that need to be computed/recomputed post merging."""
np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0],
dtype=np.int32)
np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0],
dtype=np.int32)
if not pair_msa_sequences:
# Generate a bias that is 1 for the first row of every block in the
# block diagonal MSA - i.e. make sure the cluster stack always includes
# the query sequences for each chain (since the first row is the query
# sequence).
cluster_bias_masks = []
for chain in np_chains_list:
mask = np.zeros(chain['msa'].shape[0])
mask[0] = 1
cluster_bias_masks.append(mask)
np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)
# Initialize Bert mask with masked out off diagonals.
msa_masks = [np.ones(x['msa'].shape, dtype=np.float32)
for x in np_chains_list]
np_example['bert_mask'] = block_diag(
*msa_masks, pad_value=0)
else:
np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
np_example['cluster_bias_mask'][0] = 1
# Initialize Bert mask with masked out off diagonals.
msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for
x in np_chains_list]
msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
x in np_chains_list]
msa_mask_block_diag = block_diag(
*msa_masks, pad_value=0)
msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
np_example['bert_mask'] = np.concatenate(
[msa_mask_all_seq, msa_mask_block_diag], axis=0)
return np_example
def _pad_templates(chains: Sequence[pipeline.FeatureDict],
max_templates: int) -> Sequence[pipeline.FeatureDict]:
"""For each chain pad the number of templates to a fixed size.
Args:
chains: A list of protein chains.
max_templates: Each chain will be padded to have this many templates.
Returns:
The list of chains, updated to have template features padded to
max_templates.
"""
for chain in chains:
for k, v in chain.items():
if k in TEMPLATE_FEATURES:
padding = np.zeros_like(v.shape)
padding[0] = max_templates - v.shape[0]
padding = [(0, p) for p in padding]
chain[k] = np.pad(v, padding, mode='constant')
return chains
def _merge_features_from_multiple_chains(
chains: Sequence[pipeline.FeatureDict],
pair_msa_sequences: bool) -> pipeline.FeatureDict:
"""Merge features from multiple chains.
Args:
chains: A list of feature dictionaries that we want to merge.
pair_msa_sequences: Whether to concatenate MSA features along the
num_res dimension (if True), or to block diagonalize them (if False).
Returns:
A feature dictionary for the merged example.
"""
merged_example = {}
for feature_name in chains[0]:
feats = [x[feature_name] for x in chains]
feature_name_split = feature_name.split('_all_seq')[0]
if feature_name_split in MSA_FEATURES:
if pair_msa_sequences or '_all_seq' in feature_name:
merged_example[feature_name] = np.concatenate(feats, axis=1)
else:
merged_example[feature_name] = block_diag(
*feats, pad_value=MSA_PAD_VALUES[feature_name])
elif feature_name_split in SEQ_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=0)
elif feature_name_split in TEMPLATE_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=1)
elif feature_name_split in CHAIN_FEATURES:
merged_example[feature_name] = np.sum(x for x in feats).astype(np.int32)
else:
merged_example[feature_name] = feats[0]
return merged_example
def _merge_homomers_dense_msa(
chains: Iterable[pipeline.FeatureDict]) -> Sequence[pipeline.FeatureDict]:
"""Merge all identical chains, making the resulting MSA dense.
Args:
chains: An iterable of features for each chain.
Returns:
A list of feature dictionaries. All features with the same entity_id
will be merged - MSA features will be concatenated along the num_res
dimension - making them dense.
"""
entity_chains = collections.defaultdict(list)
for chain in chains:
entity_id = chain['entity_id'][0]
entity_chains[entity_id].append(chain)
grouped_chains = []
for entity_id in sorted(entity_chains):
chains = entity_chains[entity_id]
grouped_chains.append(chains)
chains = [
_merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
for chains in grouped_chains]
return chains
def _concatenate_paired_and_unpaired_features(
example: pipeline.FeatureDict) -> pipeline.FeatureDict:
"""Merges paired and block-diagonalised features."""
features = MSA_FEATURES
for feature_name in features:
if feature_name in example:
feat = example[feature_name]
feat_all_seq = example[feature_name + '_all_seq']
merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
example[feature_name] = merged_feat
example['num_alignments'] = np.array(example['msa'].shape[0],
dtype=np.int32)
return example
def merge_chain_features(np_chains_list: List[pipeline.FeatureDict],
pair_msa_sequences: bool,
max_templates: int) -> pipeline.FeatureDict:
"""Merges features for multiple chains to single FeatureDict.
Args:
np_chains_list: List of FeatureDicts for each chain.
pair_msa_sequences: Whether to merge paired MSAs.
max_templates: The maximum number of templates to include.
Returns:
Single FeatureDict for entire complex.
"""
np_chains_list = _pad_templates(
np_chains_list, max_templates=max_templates)
np_chains_list = _merge_homomers_dense_msa(np_chains_list)
# Unpaired MSA features will always be block-diagonalised; paired MSA
# features will be concatenated.
np_example = _merge_features_from_multiple_chains(
np_chains_list, pair_msa_sequences=False)
if pair_msa_sequences:
np_example = _concatenate_paired_and_unpaired_features(np_example)
np_example = _correct_post_merged_feats(
np_example=np_example,
np_chains_list=np_chains_list,
pair_msa_sequences=pair_msa_sequences)
return np_example
def deduplicate_unpaired_sequences(
np_chains: List[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
"""Removes unpaired sequences which duplicate a paired sequence."""
feature_names = np_chains[0].keys()
msa_features = MSA_FEATURES
for chain in np_chains:
sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
keep_rows = []
# Go through unpaired MSA seqs and remove any rows that correspond to the
# sequences that are already present in the paired MSA.
for row_num, seq in enumerate(chain['msa']):
if tuple(seq) not in sequence_set:
keep_rows.append(row_num)
for feature_name in feature_names:
if feature_name in msa_features:
if keep_rows:
chain[feature_name] = chain[feature_name][keep_rows]
else:
new_shape = list(chain[feature_name].shape)
new_shape[0] = 0
chain[feature_name] = np.zeros(new_shape,
dtype=chain[feature_name].dtype)
chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
return np_chains
......@@ -15,20 +15,47 @@
"""Functions for parsing various file formats."""
import collections
import dataclasses
import itertools
import re
import string
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
DeletionMatrix = Sequence[Sequence[int]]
@dataclasses.dataclass(frozen=True)
class Msa:
"""Class representing a parsed MSA file."""
sequences: Sequence[str]
deletion_matrix: DeletionMatrix
descriptions: Sequence[str]
def __post_init__(self):
if not (len(self.sequences) ==
len(self.deletion_matrix) ==
len(self.descriptions)):
raise ValueError(
'All fields for an MSA must have the same length. '
f'Got {len(self.sequences)} sequences, '
f'{len(self.deletion_matrix)} rows in the deletion matrix and '
f'{len(self.descriptions)} descriptions.')
def __len__(self):
return len(self.sequences)
def truncate(self, max_seqs: int):
return Msa(sequences=self.sequences[:max_seqs],
deletion_matrix=self.deletion_matrix[:max_seqs],
descriptions=self.descriptions[:max_seqs])
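# Minimal illustrative sketch (hypothetical sequences): building an Msa and
# keeping only the first two rows with truncate.
def _example_msa_truncate():
  msa = Msa(sequences=['MKV', 'MRV', 'MKL'],
            deletion_matrix=[[0, 0, 0]] * 3,
            descriptions=['query', 'hit_1', 'hit_2'])
  assert len(msa) == 3
  return msa.truncate(max_seqs=2)  # Keeps 'MKV' and 'MRV' only.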
@dataclasses.dataclass(frozen=True)
class TemplateHit:
"""Class representing a template hit."""
index: int
name: str
aligned_cols: int
sum_probs: float
sum_probs: Optional[float]
query: str
hit_sequence: str
indices_query: List[int]
......@@ -64,9 +91,7 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
return sequences, descriptions
def parse_stockholm(
stockholm_string: str
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
def parse_stockholm(stockholm_string: str) -> Msa:
"""Parses sequences and deletion matrix from stockholm format alignment.
Args:
......@@ -121,10 +146,12 @@ def parse_stockholm(
deletion_count = 0
deletion_matrix.append(deletion_vec)
return msa, deletion_matrix, list(name_to_sequence.keys())
return Msa(sequences=msa,
deletion_matrix=deletion_matrix,
descriptions=list(name_to_sequence.keys()))
def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
def parse_a3m(a3m_string: str) -> Msa:
"""Parses sequences and deletion matrix from a3m format alignment.
Args:
......@@ -138,8 +165,9 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
* A list of descriptions, one per sequence, from the a3m file.
"""
sequences, _ = parse_fasta(a3m_string)
sequences, descriptions = parse_fasta(a3m_string)
deletion_matrix = []
for msa_sequence in sequences:
deletion_vec = []
......@@ -155,7 +183,9 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans('', '', string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences]
return aligned_sequences, deletion_matrix
return Msa(sequences=aligned_sequences,
deletion_matrix=deletion_matrix,
descriptions=descriptions)
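# Minimal illustrative sketch (hypothetical alignment): parsing a tiny A3M
# string. Lowercase letters are insertions relative to the query; they are
# removed from the returned sequences but counted in the deletion matrix.
def _example_parse_a3m():
  msa = parse_a3m('>query\nMKV\n>hit_1\nM-V\n>hit_2\nMkKV\n')
  # msa.sequences == ['MKV', 'M-V', 'MKV']
  # msa.deletion_matrix == [[0, 0, 0], [0, 0, 0], [0, 1, 0]]
  # msa.descriptions == ['query', 'hit_1', 'hit_2']
  return msa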
def _convert_sto_seq_to_a3m(
......@@ -168,7 +198,8 @@ def _convert_sto_seq_to_a3m(
def convert_stockholm_to_a3m(stockholm_format: str,
max_sequences: Optional[int] = None) -> str:
max_sequences: Optional[int] = None,
remove_first_row_gaps: bool = True) -> str:
"""Converts MSA in Stockholm format to the A3M format."""
descriptions = {}
sequences = {}
......@@ -203,18 +234,138 @@ def convert_stockholm_to_a3m(stockholm_format: str,
# Convert sto format to a3m line by line
a3m_sequences = {}
# query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values()))
query_non_gaps = [res != '-' for res in query_sequence]
if remove_first_row_gaps:
# query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values()))
query_non_gaps = [res != '-' for res in query_sequence]
for seqname, sto_sequence in sequences.items():
a3m_sequences[seqname] = ''.join(
_convert_sto_seq_to_a3m(query_non_gaps, sto_sequence))
# Dots are optional in a3m format and are commonly removed.
out_sequence = sto_sequence.replace('.', '')
if remove_first_row_gaps:
out_sequence = ''.join(
_convert_sto_seq_to_a3m(query_non_gaps, out_sequence))
a3m_sequences[seqname] = out_sequence
fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
for k in a3m_sequences)
return '\n'.join(fasta_chunks) + '\n' # Include terminating newline.
def _keep_line(line: str, seqnames: Set[str]) -> bool:
"""Function to decide which lines to keep."""
if not line.strip():
return True
if line.strip() == '//': # End tag
return True
if line.startswith('# STOCKHOLM'): # Start tag
return True
if line.startswith('#=GC RF'): # Reference Annotation Line
return True
if line[:4] == '#=GS': # Description lines - keep if sequence in list.
_, seqname, _ = line.split(maxsplit=2)
return seqname in seqnames
elif line.startswith('#'): # Other markup - filter out
return False
else: # Alignment data - keep if sequence in list.
seqname = line.partition(' ')[0]
return seqname in seqnames
def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str:
"""Truncates a stockholm file to a maximum number of sequences."""
seqnames = set()
filtered_lines = []
for line in stockholm_msa.splitlines():
if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname = line.partition(' ')[0]
seqnames.add(seqname)
if len(seqnames) >= max_sequences:
break
for line in stockholm_msa.splitlines():
if _keep_line(line, seqnames):
filtered_lines.append(line)
return '\n'.join(filtered_lines) + '\n'
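# Minimal illustrative sketch (hypothetical alignment): truncating a tiny
# Stockholm MSA to its first two sequences while keeping generic markup.
def _example_truncate_stockholm_msa():
  sto = ('# STOCKHOLM 1.0\n'
         '#=GS seq3 DE third hit\n'
         'query MKV\n'
         'seq2 MRV\n'
         'seq3 MKL\n'
         '//\n')
  # The 'seq3' alignment line and its '#=GS' description are dropped.
  return truncate_stockholm_msa(sto, max_sequences=2)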
def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
"""Removes empty columns (dashes-only) from a Stockholm MSA."""
processed_lines = {}
unprocessed_lines = {}
for i, line in enumerate(stockholm_msa.splitlines()):
if line.startswith('#=GC RF'):
reference_annotation_i = i
reference_annotation_line = line
# Reached the end of this chunk of the alignment. Process chunk.
_, _, first_alignment = line.rpartition(' ')
mask = []
for j in range(len(first_alignment)):
for _, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
if alignment[j] != '-':
mask.append(True)
break
else: # Every row contained a hyphen - empty column.
mask.append(False)
# Add reference annotation for processing with mask.
unprocessed_lines[reference_annotation_i] = reference_annotation_line
if not any(mask): # All columns were empty. Output empty lines for chunk.
for line_index in unprocessed_lines:
processed_lines[line_index] = ''
else:
for line_index, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
masked_alignment = ''.join(itertools.compress(alignment, mask))
processed_lines[line_index] = f'{prefix} {masked_alignment}'
# Clear the buffer of unprocessed lines for the next alignment chunk.
unprocessed_lines = {}
elif line.strip() and not line.startswith(('#', '//')):
unprocessed_lines[i] = line
else:
processed_lines[i] = line
return '\n'.join((processed_lines[i] for i in range(len(processed_lines))))
def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
"""Remove duplicate sequences (ignoring insertions wrt query)."""
sequence_dict = collections.defaultdict(str)
# First we must extract all sequences from the MSA.
for line in stockholm_msa.splitlines():
# Only consider the alignments - ignore reference annotation, empty lines,
# descriptions or markup.
if line.strip() and not line.startswith(('#', '//')):
line = line.strip()
seqname, alignment = line.split()
sequence_dict[seqname] += alignment
seen_sequences = set()
seqnames = set()
# First alignment is the query.
query_align = next(iter(sequence_dict.values()))
mask = [c != '-' for c in query_align] # Mask is False for insertions.
for seqname, alignment in sequence_dict.items():
# Apply mask to remove all insertions from the string.
masked_alignment = ''.join(itertools.compress(alignment, mask))
if masked_alignment in seen_sequences:
continue
else:
seen_sequences.add(masked_alignment)
seqnames.add(seqname)
filtered_lines = []
for line in stockholm_msa.splitlines():
if _keep_line(line, seqnames):
filtered_lines.append(line)
return '\n'.join(filtered_lines) + '\n'
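# Minimal illustrative sketch (hypothetical alignment): 'dup' matches 'seq2'
# once the columns where the query has a gap are ignored, so it is removed.
def _example_deduplicate_stockholm_msa():
  sto = ('# STOCKHOLM 1.0\n'
         'query MKV-\n'
         'seq2 MRVL\n'
         'dup MRVI\n'
         '//\n')
  return deduplicate_stockholm_msa(sto)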
def _get_hhr_line_regex_groups(
regex_pattern: str, line: str) -> Sequence[Optional[str]]:
match = re.match(regex_pattern, line)
......@@ -264,8 +415,8 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
raise RuntimeError(
'Could not parse section: %s. Expected this: \n%s to contain summary.' %
(detailed_lines, detailed_lines[2]))
(prob_true, e_value, _, aligned_cols, _, _, sum_probs,
neff) = [float(x) for x in match.groups()]
(_, _, _, aligned_cols, _, _, sum_probs, _) = [float(x)
for x in match.groups()]
# The next section reads the detailed comparisons. These are in a 'human
# readable' format which has a fixed length. The strategy employed is to
......@@ -362,3 +513,95 @@ def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
target_name = fields[0]
e_values[target_name] = float(e_value)
return e_values
def _get_indices(sequence: str, start: int) -> List[int]:
"""Returns indices for non-gap/insert residues starting at the given index."""
indices = []
counter = start
for symbol in sequence:
# Skip gaps but add a placeholder so that the alignment is preserved.
if symbol == '-':
indices.append(-1)
# Skip deleted residues, but increase the counter.
elif symbol.islower():
counter += 1
# Normal aligned residue. Increase the counter and append to indices.
else:
indices.append(counter)
counter += 1
return indices
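# Minimal illustrative sketch (hypothetical sequence): gaps ('-') produce a -1
# placeholder, lowercase (deleted) residues advance the counter silently.
def _example_get_indices():
  return _get_indices('AB-cD', start=10)  # == [10, 11, -1, 13]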
@dataclasses.dataclass(frozen=True)
class HitMetadata:
pdb_id: str
chain: str
start: int
end: int
length: int
text: str
def _parse_hmmsearch_description(description: str) -> HitMetadata:
"""Parses the hmmsearch A3M sequence description line."""
# Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text
# Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
match = re.match(
r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
description.strip())
if not match:
raise ValueError(f'Could not parse description: "{description}".')
return HitMetadata(
pdb_id=match[1],
chain=match[2],
start=int(match[3]),
end=int(match[4]),
length=int(match[5]),
text=match[6])
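# Minimal illustrative sketch: parsing the first example description quoted in
# the docstring above.
def _example_parse_hmmsearch_description():
  meta = _parse_hmmsearch_description(
      '>4pqx_A/2-217 [subseq from] mol:protein length:217 Free text')
  # meta == HitMetadata(pdb_id='4pqx', chain='A', start=2, end=217,
  #                     length=217, text='Free text')
  return meta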
def parse_hmmsearch_a3m(query_sequence: str,
a3m_string: str,
skip_first: bool = True) -> Sequence[TemplateHit]:
"""Parses an a3m string produced by hmmsearch.
Args:
query_sequence: The query sequence.
a3m_string: The a3m string produced by hmmsearch.
skip_first: Whether to skip the first sequence in the a3m string.
Returns:
A sequence of `TemplateHit` results.
"""
# Zip the descriptions and MSAs together, skip the first query sequence.
parsed_a3m = list(zip(*parse_fasta(a3m_string)))
if skip_first:
parsed_a3m = parsed_a3m[1:]
indices_query = _get_indices(query_sequence, start=0)
hits = []
for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
if 'mol:protein' not in hit_description:
continue # Skip non-protein chains.
metadata = _parse_hmmsearch_description(hit_description)
# Aligned columns are only the match states.
aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence])
indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)
hit = TemplateHit(
index=i,
name=f'{metadata.pdb_id}_{metadata.chain}',
aligned_cols=aligned_cols,
sum_probs=None,
query=query_sequence,
hit_sequence=hit_sequence.upper(),
indices_query=indices_query,
indices_hit=indices_hit,
)
hits.append(hit)
return hits
......@@ -15,19 +15,22 @@
"""Functions for building the input features for the AlphaFold model."""
import os
from typing import Mapping, Optional, Sequence
from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union
from absl import logging
from alphafold.common import residue_constants
from alphafold.data import msa_identifiers
from alphafold.data import parsers
from alphafold.data import templates
from alphafold.data.tools import hhblits
from alphafold.data.tools import hhsearch
from alphafold.data.tools import hmmsearch
from alphafold.data.tools import jackhmmer
import numpy as np
# Internal import (7716).
FeatureDict = Mapping[str, np.ndarray]
FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
def make_sequence_features(
......@@ -47,55 +50,78 @@ def make_sequence_features(
return features
def make_msa_features(
msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
"""Constructs a feature dict of MSA features."""
if not msas:
raise ValueError('At least one MSA must be provided.')
int_msa = []
deletion_matrix = []
uniprot_accession_ids = []
species_ids = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
for sequence_index, sequence in enumerate(msa):
for sequence_index, sequence in enumerate(msa.sequences):
if sequence in seen_sequences:
continue
seen_sequences.add(sequence)
int_msa.append(
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
num_res = len(msas[0][0])
deletion_matrix.append(msa.deletion_matrix[sequence_index])
identifiers = msa_identifiers.get_identifiers(
msa.descriptions[sequence_index])
uniprot_accession_ids.append(
identifiers.uniprot_accession_id.encode('utf-8'))
species_ids.append(identifiers.species_id.encode('utf-8'))
num_res = len(msas[0].sequences[0])
num_alignments = len(int_msa)
features = {}
features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
features['msa'] = np.array(int_msa, dtype=np.int32)
features['num_alignments'] = np.array(
[num_alignments] * num_res, dtype=np.int32)
features['msa_uniprot_accession_identifiers'] = np.array(
uniprot_accession_ids, dtype=np.object_)
features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_)
return features
def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str,
msa_format: str, use_precomputed_msas: bool,
) -> Mapping[str, Any]:
"""Runs an MSA tool, checking if output already exists first."""
if not use_precomputed_msas or not os.path.exists(msa_out_path):
result = msa_runner.query(input_fasta_path)[0]
with open(msa_out_path, 'w') as f:
f.write(result[msa_format])
else:
logging.warning('Reading MSA from file %s', msa_out_path)
with open(msa_out_path, 'r') as f:
result = {msa_format: f.read()}
return result
class DataPipeline:
"""Runs the alignment tools and assembles the input features."""
def __init__(self,
jackhmmer_binary_path: str,
hhblits_binary_path: str,
hhsearch_binary_path: str,
uniref90_database_path: str,
mgnify_database_path: str,
bfd_database_path: Optional[str],
uniclust30_database_path: Optional[str],
small_bfd_database_path: Optional[str],
pdb70_database_path: str,
template_searcher: TemplateSearcher,
template_featurizer: templates.TemplateHitFeaturizer,
use_small_bfd: bool,
mgnify_max_hits: int = 501,
uniref_max_hits: int = 10000):
"""Constructs a feature dict for a given FASTA file."""
uniref_max_hits: int = 10000,
use_precomputed_msas: bool = False):
"""Initializes the data pipeline."""
self._use_small_bfd = use_small_bfd
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
......@@ -111,12 +137,11 @@ class DataPipeline:
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path)
self.hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path=hhsearch_binary_path,
databases=[pdb70_database_path])
self.template_searcher = template_searcher
self.template_featurizer = template_featurizer
self.mgnify_max_hits = mgnify_max_hits
self.uniref_max_hits = uniref_max_hits
self.use_precomputed_msas = use_precomputed_msas
def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
"""Runs alignment tools on the input sequence and creates features."""
......@@ -130,72 +155,68 @@ class DataPipeline:
input_description = input_descs[0]
num_res = len(input_sequence)
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
input_fasta_path)[0]
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
input_fasta_path)[0]
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits)
hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
with open(uniref90_out_path, 'w') as f:
f.write(jackhmmer_uniref90_result['sto'])
jackhmmer_uniref90_result = run_msa_tool(
self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path,
'sto', self.use_precomputed_msas)
mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
with open(mgnify_out_path, 'w') as f:
f.write(jackhmmer_mgnify_result['sto'])
pdb70_out_path = os.path.join(msa_output_dir, 'pdb70_hits.hhr')
with open(pdb70_out_path, 'w') as f:
f.write(hhsearch_result)
jackhmmer_mgnify_result = run_msa_tool(
self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto',
self.use_precomputed_msas)
msa_for_templates = jackhmmer_uniref90_result['sto']
msa_for_templates = parsers.truncate_stockholm_msa(
msa_for_templates, max_sequences=self.uniref_max_hits)
msa_for_templates = parsers.deduplicate_stockholm_msa(
msa_for_templates)
msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
msa_for_templates)
if self.template_searcher.input_format == 'sto':
pdb_templates_result = self.template_searcher.query(msa_for_templates)
elif self.template_searcher.input_format == 'a3m':
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
else:
raise ValueError('Unrecognized template input format: '
f'{self.template_searcher.input_format}')
uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm(
jackhmmer_uniref90_result['sto'])
mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm(
jackhmmer_mgnify_result['sto'])
hhsearch_hits = parsers.parse_hhr(hhsearch_result)
mgnify_msa = mgnify_msa[:self.mgnify_max_hits]
mgnify_deletion_matrix = mgnify_deletion_matrix[:self.mgnify_max_hits]
pdb_hits_out_path = os.path.join(
msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
with open(pdb_hits_out_path, 'w') as f:
f.write(pdb_templates_result)
if self._use_small_bfd:
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
input_fasta_path)[0]
uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)
mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)
bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.a3m')
with open(bfd_out_path, 'w') as f:
f.write(jackhmmer_small_bfd_result['sto'])
pdb_template_hits = self.template_searcher.get_template_hits(
output_string=pdb_templates_result, input_sequence=input_sequence)
bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
jackhmmer_small_bfd_result['sto'])
if self._use_small_bfd:
bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
jackhmmer_small_bfd_result = run_msa_tool(
self.jackhmmer_small_bfd_runner, input_fasta_path, bfd_out_path,
'sto', self.use_precomputed_msas)
bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
else:
hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(
input_fasta_path)
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
with open(bfd_out_path, 'w') as f:
f.write(hhblits_bfd_uniclust_result['a3m'])
bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(
hhblits_bfd_uniclust_result['a3m'])
hhblits_bfd_uniclust_result = run_msa_tool(
self.hhblits_bfd_uniclust_runner, input_fasta_path, bfd_out_path,
'a3m', self.use_precomputed_msas)
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=None,
hits=hhsearch_hits)
hits=pdb_template_hits)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res)
msa_features = make_msa_features(
msas=(uniref90_msa, bfd_msa, mgnify_msa),
deletion_matrices=(uniref90_deletion_matrix,
bfd_deletion_matrix,
mgnify_deletion_matrix))
msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa))
logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa))
logging.info('BFD MSA size: %d sequences.', len(bfd_msa))
......
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for building the features for the AlphaFold multimer model."""
import collections
import contextlib
import copy
import dataclasses
import json
import os
import tempfile
from typing import Mapping, MutableMapping, Sequence
from absl import logging
from alphafold.common import protein
from alphafold.common import residue_constants
from alphafold.data import feature_processing
from alphafold.data import msa_pairing
from alphafold.data import parsers
from alphafold.data import pipeline
from alphafold.data.tools import jackhmmer
import numpy as np
# Internal import (7716).
@dataclasses.dataclass(frozen=True)
class _FastaChain:
sequence: str
description: str
def _make_chain_id_map(*,
sequences: Sequence[str],
descriptions: Sequence[str],
) -> Mapping[str, _FastaChain]:
"""Makes a mapping from PDB-format chain ID to sequence and description."""
if len(sequences) != len(descriptions):
raise ValueError('sequences and descriptions must have equal length. '
f'Got {len(sequences)} != {len(descriptions)}.')
if len(sequences) > protein.PDB_MAX_CHAINS:
raise ValueError('Cannot process more chains than the PDB format supports. '
f'Got {len(sequences)} chains.')
chain_id_map = {}
for chain_id, sequence, description in zip(
protein.PDB_CHAIN_IDS, sequences, descriptions):
chain_id_map[chain_id] = _FastaChain(
sequence=sequence, description=description)
return chain_id_map
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
fasta_file.write(fasta_str)
fasta_file.seek(0)
yield fasta_file.name
def convert_monomer_features(
monomer_features: pipeline.FeatureDict,
chain_id: str) -> pipeline.FeatureDict:
"""Reshapes and modifies monomer features for multimer models."""
converted = {}
converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
unnecessary_leading_dim_feats = {
'sequence', 'domain_name', 'num_alignments', 'seq_length'}
for feature_name, feature in monomer_features.items():
if feature_name in unnecessary_leading_dim_feats:
# asarray ensures it's a np.ndarray.
feature = np.asarray(feature[0], dtype=feature.dtype)
elif feature_name == 'aatype':
# The multimer model performs the one-hot operation itself.
feature = np.argmax(feature, axis=-1).astype(np.int32)
elif feature_name == 'template_aatype':
feature = np.argmax(feature, axis=-1).astype(np.int32)
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
elif feature_name == 'template_all_atom_masks':
feature_name = 'template_all_atom_mask'
converted[feature_name] = feature
return converted
def int_id_to_str_id(num: int) -> str:
"""Encodes a number as a string, using reverse spreadsheet style naming.
Args:
num: A positive integer.
Returns:
A string that encodes the positive integer using reverse spreadsheet style
naming, e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
usual way to encode chain IDs in mmCIF files.
"""
if num <= 0:
raise ValueError(f'Only positive integers allowed, got {num}.')
num = num - 1 # 1-based indexing.
output = []
while num >= 0:
output.append(chr(num % 26 + ord('A')))
num = num // 26 - 1
return ''.join(output)
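# Minimal illustrative sketch: a few values of the reverse spreadsheet-style
# encoding described above.
def _example_int_id_to_str_id():
  return [int_id_to_str_id(n) for n in (1, 2, 26, 27, 28)]
  # == ['A', 'B', 'Z', 'AA', 'BA']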
def add_assembly_features(
all_chain_features: MutableMapping[str, pipeline.FeatureDict],
) -> MutableMapping[str, pipeline.FeatureDict]:
"""Add features to distinguish between chains.
Args:
all_chain_features: A dictionary which maps chain_id to a dictionary of
features for each chain.
Returns:
all_chain_features: A dictionary which maps strings of the form
`<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
chains from a homodimer would have keys A_1 and A_2. Two chains from a
heterodimer would have keys A_1 and B_1.
"""
# Group the chains by sequence
seq_to_entity_id = {}
grouped_chains = collections.defaultdict(list)
for chain_id, chain_features in all_chain_features.items():
seq = str(chain_features['sequence'])
if seq not in seq_to_entity_id:
seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
grouped_chains[seq_to_entity_id[seq]].append(chain_features)
new_all_chain_features = {}
chain_id = 1
for entity_id, group_chain_features in grouped_chains.items():
for sym_id, chain_features in enumerate(group_chain_features, start=1):
new_all_chain_features[
f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features
seq_length = chain_features['seq_length']
chain_features['asym_id'] = chain_id * np.ones(seq_length)
chain_features['sym_id'] = sym_id * np.ones(seq_length)
chain_features['entity_id'] = entity_id * np.ones(seq_length)
chain_id += 1
return new_all_chain_features
def pad_msa(np_example, min_num_seq):
np_example = dict(np_example)
num_seq = np_example['msa'].shape[0]
if num_seq < min_num_seq:
for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
np_example[feat] = np.pad(
np_example[feat], ((0, min_num_seq - num_seq), (0, 0)))
np_example['cluster_bias_mask'] = np.pad(
np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq),))
return np_example
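# Minimal illustrative sketch (hypothetical toy features): padding an example
# whose MSA has 3 rows up to a minimum of 5 rows so extra_msa is never empty.
def _example_pad_msa():
  toy = {
      'msa': np.zeros((3, 4), dtype=np.int32),
      'deletion_matrix': np.zeros((3, 4), dtype=np.int32),
      'bert_mask': np.ones((3, 4), dtype=np.float32),
      'msa_mask': np.ones((3, 4), dtype=np.float32),
      'cluster_bias_mask': np.array([1., 0., 0.]),
  }
  padded = pad_msa(toy, min_num_seq=5)
  # padded['msa'].shape == (5, 4); the two added rows are all zeros.
  return padded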
class DataPipeline:
"""Runs the alignment tools and assembles the input features."""
def __init__(self,
monomer_data_pipeline: pipeline.DataPipeline,
jackhmmer_binary_path: str,
uniprot_database_path: str,
max_uniprot_hits: int = 50000,
use_precomputed_msas: bool = False):
"""Initializes the data pipeline.
Args:
monomer_data_pipeline: An instance of pipeline.DataPipeline that runs
the data pipeline for the monomer AlphaFold system.
jackhmmer_binary_path: Location of the jackhmmer binary.
uniprot_database_path: Location of the unclustered uniprot sequences that
will be searched with jackhmmer and used for MSA pairing.
max_uniprot_hits: The maximum number of hits to return from uniprot.
use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold.
"""
self._monomer_data_pipeline = monomer_data_pipeline
self._uniprot_msa_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniprot_database_path)
self._max_uniprot_hits = max_uniprot_hits
self.use_precomputed_msas = use_precomputed_msas
def _process_single_chain(
self,
chain_id: str,
sequence: str,
description: str,
msa_output_dir: str,
is_homomer_or_monomer: bool) -> pipeline.FeatureDict:
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str = f'>{description}\n{sequence}\n'
chain_msa_output_dir = os.path.join(msa_output_dir, chain_id)
if not os.path.exists(chain_msa_output_dir):
os.makedirs(chain_msa_output_dir)
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
logging.info('Running monomer pipeline on chain %s: %s',
chain_id, description)
chain_features = self._monomer_data_pipeline.process(
input_fasta_path=chain_fasta_path,
msa_output_dir=chain_msa_output_dir)
# We only construct the pairing features if there are 2 or more unique
# sequences.
if not is_homomer_or_monomer:
all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path,
chain_msa_output_dir)
chain_features.update(all_seq_msa_features)
return chain_features
def _all_seq_msa_features(self, input_fasta_path, msa_output_dir):
"""Get MSA features for unclustered uniprot, for pairing."""
out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
result = pipeline.run_msa_tool(
self._uniprot_msa_runner, input_fasta_path, out_path, 'sto',
self.use_precomputed_msas)
msa = parsers.parse_stockholm(result['sto'])
msa = msa.truncate(max_seqs=self._max_uniprot_hits)
all_seq_features = pipeline.make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_uniprot_accession_identifiers',
'msa_species_identifiers',
)
feats = {f'{k}_all_seq': v for k, v in all_seq_features.items()
if k in valid_feats}
return feats
def process(self,
input_fasta_path: str,
msa_output_dir: str,
is_prokaryote: bool = False) -> pipeline.FeatureDict:
"""Runs alignment tools on the input sequences and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
chain_id_map = _make_chain_id_map(sequences=input_seqs,
descriptions=input_descs)
chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json')
with open(chain_id_map_path, 'w') as f:
chain_id_map_dict = {chain_id: dataclasses.asdict(fasta_chain)
for chain_id, fasta_chain in chain_id_map.items()}
json.dump(chain_id_map_dict, f, indent=4, sort_keys=True)
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1
for chain_id, fasta_chain in chain_id_map.items():
if fasta_chain.sequence in sequence_features:
all_chain_features[chain_id] = copy.deepcopy(
sequence_features[fasta_chain.sequence])
continue
chain_features = self._process_single_chain(
chain_id=chain_id,
sequence=fasta_chain.sequence,
description=fasta_chain.description,
msa_output_dir=msa_output_dir,
is_homomer_or_monomer=is_homomer_or_monomer)
chain_features = convert_monomer_features(chain_features,
chain_id=chain_id)
all_chain_features[chain_id] = chain_features
sequence_features[fasta_chain.sequence] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing.pair_and_merge(
all_chain_features=all_chain_features,
is_prokaryote=is_prokaryote,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
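# Minimal usage sketch (all paths below are hypothetical placeholders): wiring
# a monomer pipeline.DataPipeline into the multimer DataPipeline above and
# featurizing a two-chain FASTA.
def _example_run_multimer_pipeline(
    monomer_pipeline: pipeline.DataPipeline) -> pipeline.FeatureDict:
  multimer_pipeline = DataPipeline(
      monomer_data_pipeline=monomer_pipeline,
      jackhmmer_binary_path='/usr/bin/jackhmmer',  # Hypothetical path.
      uniprot_database_path='/data/uniprot/uniprot.fasta',  # Hypothetical path.
      use_precomputed_msas=False)
  return multimer_pipeline.process(
      input_fasta_path='/data/targets/heterodimer.fasta',  # Hypothetical path.
      msa_output_dir='/data/output/heterodimer_msas')  # Hypothetical path.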
......@@ -13,8 +13,10 @@
# limitations under the License.
"""Functions for getting templates and calculating template features."""
import abc
import dataclasses
import datetime
import functools
import glob
import os
import re
......@@ -71,10 +73,6 @@ class DateError(PrefilterError):
"""An error indicating that the hit date was after the max allowed date."""
class PdbIdError(PrefilterError):
"""An error indicating that the hit PDB ID was identical to the query."""
class AlignRatioError(PrefilterError):
"""An error indicating that the hit align ratio to the query was too small."""
......@@ -128,7 +126,6 @@ def _is_after_cutoff(
else:
# Since this is just a quick prefilter to reduce the number of mmCIF files
# we need to parse, we don't have to worry about returning True here.
logging.warning('Template structure not in release dates dict: %s', pdb_id)
return False
......@@ -177,7 +174,6 @@ def _assess_hhsearch_hit(
hit: parsers.TemplateHit,
hit_pdb_code: str,
query_sequence: str,
query_pdb_code: Optional[str],
release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: datetime.datetime,
max_subsequence_ratio: float = 0.95,
......@@ -190,7 +186,6 @@ def _assess_hhsearch_hit(
different from the value in the actual hit since the original pdb might
have become obsolete.
query_sequence: Amino acid sequence of the query.
query_pdb_code: 4 letter pdb code of the query.
release_dates: Dictionary mapping pdb codes to their structure release
dates.
release_date_cutoff: Max release date that is valid for this query.
......@@ -202,7 +197,6 @@ def _assess_hhsearch_hit(
Raises:
DateError: If the hit date was after the max allowed date.
PdbIdError: If the hit PDB ID was identical to the query.
AlignRatioError: If the hit align ratio to the query was too small.
DuplicateError: If the hit was an exact subsequence of the query.
LengthError: If the hit was too short.
......@@ -222,10 +216,6 @@ def _assess_hhsearch_hit(
raise DateError(f'Date ({release_dates[hit_pdb_code]}) > max template date '
f'({release_date_cutoff}).')
if query_pdb_code is not None:
if query_pdb_code.lower() == hit_pdb_code.lower():
raise PdbIdError('PDB code identical to Query PDB code.')
if align_ratio <= min_align_ratio:
raise AlignRatioError('Proportion of residues aligned to query too small. '
f'Align ratio: {align_ratio}.')
......@@ -368,8 +358,9 @@ def _realign_pdb_template_to_query(
'protein chain.')
try:
(old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
parsed_a3m = parsers.parse_a3m(
aligner.align([old_template_sequence, new_template_sequence]))
old_aligned_template, new_aligned_template = parsed_a3m.sequences
except Exception as e:
raise QueryToTemplateAlignError(
'Could not align old template %s to template %s (%s_%s). Error: %s' %
......@@ -472,6 +463,18 @@ def _get_atom_positions(
pos[residue_constants.atom_order['SD']] = [x, y, z]
mask[residue_constants.atom_order['SD']] = 1.0
# Fix naming errors in arginine residues where NH2 is incorrectly
# assigned to be closer to CD than NH1.
cd = residue_constants.atom_order['CD']
nh1 = residue_constants.atom_order['NH1']
nh2 = residue_constants.atom_order['NH2']
if (res.get_resname() == 'ARG' and
all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
(np.linalg.norm(pos[nh1] - pos[cd]) >
np.linalg.norm(pos[nh2] - pos[cd]))):
pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()
all_positions[res_index] = pos
all_positions_mask[res_index] = mask
_check_residue_distances(
......@@ -673,9 +676,15 @@ class SingleHitResult:
warning: Optional[str]
@functools.lru_cache(16, typed=False)
def _read_file(path):
with open(path, 'r') as f:
file_data = f.read()
return file_data
def _process_single_hit(
query_sequence: str,
query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
mmcif_dir: str,
max_template_date: datetime.datetime,
......@@ -702,14 +711,12 @@ def _process_single_hit(
hit=hit,
hit_pdb_code=hit_pdb_code,
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
release_dates=release_dates,
release_date_cutoff=max_template_date)
except PrefilterError as e:
msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}'
logging.info('%s: %s', query_pdb_code, msg)
if strict_error_check and isinstance(
e, (DateError, PdbIdError, DuplicateError)):
logging.info(msg)
if strict_error_check and isinstance(e, (DateError, DuplicateError)):
# In strict mode we treat some prefilter cases as errors.
return SingleHitResult(features=None, error=msg, warning=None)
......@@ -724,11 +731,10 @@ def _process_single_hit(
template_sequence = hit.hit_sequence.replace('-', '')
cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif')
logging.info('Reading PDB entry from %s. Query: %s, template: %s',
cif_path, query_sequence, template_sequence)
logging.debug('Reading PDB entry from %s. Query: %s, template: %s', cif_path,
query_sequence, template_sequence)
# Fail if we can't find the mmCIF file.
with open(cif_path, 'r') as cif_file:
cif_string = cif_file.read()
cif_string = _read_file(cif_path)
parsing_result = mmcif_parsing.parse(
file_id=hit_pdb_code, mmcif_string=cif_string)
......@@ -742,7 +748,7 @@ def _process_single_hit(
if strict_error_check:
return SingleHitResult(features=None, error=error, warning=None)
else:
logging.warning(error)
logging.debug(error)
return SingleHitResult(features=None, error=None, warning=None)
try:
......@@ -754,7 +760,10 @@ def _process_single_hit(
query_sequence=query_sequence,
template_chain_id=hit_chain_id,
kalign_binary_path=kalign_binary_path)
features['template_sum_probs'] = [hit.sum_probs]
if hit.sum_probs is None:
features['template_sum_probs'] = [0]
else:
features['template_sum_probs'] = [hit.sum_probs]
# It is possible there were some errors when parsing the other chains in the
# mmCIF file, but the template features for the chain we want were still
......@@ -765,7 +774,7 @@ def _process_single_hit(
TemplateAtomMaskAllZerosError) as e:
# These 3 errors indicate missing mmCIF experimental data rather than a
# problem with the template search, so turn them into warnings.
warning = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: '
warning = ('%s_%s (sum_probs: %s, rank: %s): feature extracting errors: '
'%s, mmCIF parsing errors: %s'
% (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
str(e), parsing_result.errors))
......@@ -788,8 +797,8 @@ class TemplateSearchResult:
warnings: Sequence[str]
class TemplateHitFeaturizer:
"""A class for turning hhr hits to template features."""
class TemplateHitFeaturizer(abc.ABC):
"""An abstract base class for turning template hits to template features."""
def __init__(
self,
......@@ -850,29 +859,28 @@ class TemplateHitFeaturizer:
else:
self._obsolete_pdbs = {}
@abc.abstractmethod
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
"""Computes the templates for given query sequence."""
class HhsearchHitFeaturizer(TemplateHitFeaturizer):
"""A class for turning a3m hits from hhsearch to template features."""
def get_templates(
self,
query_sequence: str,
query_pdb_code: Optional[str],
query_release_date: Optional[datetime.datetime],
hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
logging.info('Searching for template for: %s', query_pdb_code)
logging.info('Searching for template for: %s', query_sequence)
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
# Always use a max_template_date. Set to query_release_date minus 60 days
# if that's earlier.
template_cutoff_date = self._max_template_date
if query_release_date:
delta = datetime.timedelta(days=60)
if query_release_date - delta < template_cutoff_date:
template_cutoff_date = query_release_date - delta
assert template_cutoff_date < query_release_date
assert template_cutoff_date <= self._max_template_date
num_hits = 0
errors = []
warnings = []
......@@ -884,10 +892,9 @@ class TemplateHitFeaturizer:
result = _process_single_hit(
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
hit=hit,
mmcif_dir=self._mmcif_dir,
max_template_date=template_cutoff_date,
max_template_date=self._max_template_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
......@@ -920,3 +927,84 @@ class TemplateHitFeaturizer:
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings)
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
"""A class for turning a3m hits from hmmsearch to template features."""
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
logging.info('Searching for template for: %s', query_sequence)
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
already_seen = set()
errors = []
warnings = []
if not hits or hits[0].sum_probs is None:
sorted_hits = hits
else:
sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)
for hit in sorted_hits:
# We got all the templates we wanted, stop processing hits.
if len(already_seen) >= self._max_hits:
break
result = _process_single_hit(
query_sequence=query_sequence,
hit=hit,
mmcif_dir=self._mmcif_dir,
max_template_date=self._max_template_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path)
if result.error:
errors.append(result.error)
# There could be an error even if there are some results, e.g. thrown by
# other unparsable chains in the same mmCIF file.
if result.warning:
warnings.append(result.warning)
if result.features is None:
logging.debug('Skipped invalid hit %s, error: %s, warning: %s',
hit.name, result.error, result.warning)
else:
already_seen_key = result.features['template_sequence']
if already_seen_key in already_seen:
continue
# Increment the hit counter, since we got features out of this hit.
already_seen.add(already_seen_key)
for k in template_features:
template_features[k].append(result.features[k])
if already_seen:
for name in template_features:
template_features[name] = np.stack(
template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
else:
num_res = len(query_sequence)
# Construct a default template with all zeros.
template_features = {
'template_aatype': np.zeros(
(1, num_res, len(residue_constants.restypes_with_x_and_gap)),
np.float32),
'template_all_atom_masks': np.zeros(
(1, num_res, residue_constants.atom_type_num), np.float32),
'template_all_atom_positions': np.zeros(
(1, num_res, residue_constants.atom_type_num, 3), np.float32),
'template_domain_names': np.array([''.encode()], dtype=np.object),
'template_sequence': np.array([''.encode()], dtype=np.object),
'template_sum_probs': np.array([0], dtype=np.float32)
}
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings)
......@@ -17,7 +17,7 @@
import glob
import os
import subprocess
from typing import Any, Mapping, Optional, Sequence
from typing import Any, List, Mapping, Optional, Sequence
from absl import logging
from alphafold.data.tools import utils
......@@ -94,9 +94,9 @@ class HHBlits:
self.p = p
self.z = z
def query(self, input_fasta_path: str) -> Mapping[str, Any]:
def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
"""Queries the database using HHblits."""
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
db_cmd = []
......@@ -152,4 +152,4 @@ class HHBlits:
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value)
return raw_output
return [raw_output]
......@@ -21,6 +21,7 @@ from typing import Sequence
from absl import logging
from alphafold.data import parsers
from alphafold.data.tools import utils
# Internal import (7716).
......@@ -55,9 +56,17 @@ class HHSearch:
logging.error('Could not find HHsearch database %s', database_path)
raise ValueError(f'Could not find HHsearch database {database_path}')
@property
def output_format(self) -> str:
return 'hhr'
@property
def input_format(self) -> str:
return 'a3m'
def query(self, a3m: str) -> str:
"""Queries the database using HHsearch using a given a3m."""
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, 'query.a3m')
hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
with open(input_path, 'w') as f:
......@@ -89,3 +98,10 @@ class HHSearch:
with open(hhr_path) as f:
hhr = f.read()
return hhr
def get_template_hits(self,
output_string: str,
input_sequence: str) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
del input_sequence # Used by hmmseach but not needed for hhsearch.
return parsers.parse_hhr(output_string)
......@@ -98,7 +98,7 @@ class Hmmbuild(object):
raise ValueError(f'Invalid model_construction {model_construction} - only'
'hand and fast supported.')
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
input_query = os.path.join(query_tmp_dir, 'query.msa')
output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')
......
......@@ -19,6 +19,8 @@ import subprocess
from typing import Optional, Sequence
from absl import logging
from alphafold.data import parsers
from alphafold.data.tools import hmmbuild
from alphafold.data.tools import utils
# Internal import (7716).
......@@ -29,12 +31,15 @@ class Hmmsearch(object):
def __init__(self,
*,
binary_path: str,
hmmbuild_binary_path: str,
database_path: str,
flags: Optional[Sequence[str]] = None):
"""Initializes the Python hmmsearch wrapper.
Args:
binary_path: The path to the hmmsearch executable.
hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
an HMM profile from the input MSA.
database_path: The path to the hmmsearch database (FASTA format).
flags: List of flags to be used by hmmsearch.
......@@ -42,18 +47,42 @@ class Hmmsearch(object):
RuntimeError: If hmmsearch binary not found within the path.
"""
self.binary_path = binary_path
self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
self.database_path = database_path
if flags is None:
# Default hmmsearch run settings.
flags = ['--F1', '0.1',
'--F2', '0.1',
'--F3', '0.1',
'--incE', '100',
'-E', '100',
'--domE', '100',
'--incdomE', '100']
self.flags = flags
if not os.path.exists(self.database_path):
logging.error('Could not find hmmsearch database %s', database_path)
raise ValueError(f'Could not find hmmsearch database {database_path}')
def query(self, hmm: str) -> str:
@property
def output_format(self) -> str:
return 'sto'
@property
def input_format(self) -> str:
return 'sto'
def query(self, msa_sto: str) -> str:
"""Queries the database using hmmsearch using a given stockholm msa."""
hmm = self.hmmbuild_runner.build_profile_from_sto(msa_sto,
model_construction='hand')
return self.query_with_hmm(hmm)
def query_with_hmm(self, hmm: str) -> str:
"""Queries the database using hmmsearch using a given hmm."""
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
a3m_out_path = os.path.join(query_tmp_dir, 'output.a3m')
out_path = os.path.join(query_tmp_dir, 'output.sto')
with open(hmm_input_path, 'w') as f:
f.write(hmm)
......@@ -66,7 +95,7 @@ class Hmmsearch(object):
if self.flags:
cmd.extend(self.flags)
cmd.extend([
'-A', a3m_out_path,
'-A', out_path,
hmm_input_path,
self.database_path,
])
......@@ -84,7 +113,19 @@ class Hmmsearch(object):
'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr.decode('utf-8')))
with open(a3m_out_path) as f:
a3m_out = f.read()
with open(out_path) as f:
out_msa = f.read()
return out_msa
return a3m_out
def get_template_hits(self,
output_string: str,
input_sequence: str) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
a3m_string = parsers.convert_stockholm_to_a3m(output_string,
remove_first_row_gaps=False)
template_hits = parsers.parse_hmmsearch_a3m(
query_sequence=input_sequence,
a3m_string=a3m_string,
skip_first=False)
return template_hits
......@@ -89,7 +89,7 @@ class Jackhmmer:
def _query_chunk(self, input_fasta_path: str, database_path: str
) -> Mapping[str, Any]:
"""Queries the database chunk using Jackhmmer."""
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, 'output.sto')
# The F1/F2/F3 are the expected proportion to pass each of the filtering
......@@ -192,7 +192,10 @@ class Jackhmmer:
# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
future = next_future
# Do not set next_future for the last chunk so that this works even for
# databases with only 1 chunk.
if i < self.num_streamed_chunks:
future = next_future
if self.streaming_callback:
self.streaming_callback(i)
return chunked_output
......@@ -70,7 +70,7 @@ class Kalign:
raise ValueError('Kalign requires all sequences to be at least 6 '
'residues long. Got %s (%d residues).' % (s, len(s)))
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
with utils.tmpdir_manager() as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
......
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ops for all atom representations."""
from typing import Dict, Text
from alphafold.common import residue_constants
from alphafold.model import geometry
from alphafold.model import utils
import jax
import jax.numpy as jnp
import numpy as np
def squared_difference(x, y):
return jnp.square(x - y)
def _make_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in residue_constants.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
atom indices default to 0.
"""
chi_atom_indices = []
for residue_name in residue_constants.restypes:
residue_name = residue_constants.restype_1to3[residue_name]
residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[residue_constants.atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return np.array(chi_atom_indices)
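# Illustrative example of one table entry (not part of the original module):
# serine has a single chi angle defined by the atoms N, CA, CB, OG, so its row
# holds the atom37 indices of those four atoms followed by three [0, 0, 0, 0]
# padding rows; the final row of the table (unknown residue) is all zeros.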
def _make_renaming_matrices():
"""Matrices to map atoms to symmetry partners in ambiguous case."""
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative groundtruth coordinates where the naming is swapped
restype_3 = [
residue_constants.restype_1to3[res] for res in residue_constants.restypes
]
restype_3 += ['UNK']
# Matrices for renaming ambiguous atoms.
all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
correspondences = np.arange(14)
for source_atom_swap, target_atom_swap in swap.items():
source_index = residue_constants.restype_name_to_atom14_names[
resname].index(source_atom_swap)
target_index = residue_constants.restype_name_to_atom14_names[
resname].index(target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.
all_matrices[resname] = renaming_matrix.astype(np.float32)
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
return renaming_matrices
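# Illustrative example (not part of the original module): aspartate's OD1 and
# OD2 atoms are name-ambiguous, so its renaming matrix is the 14x14 identity
# with the OD1 and OD2 slots swapped; multiplying atom14 coordinates by it
# yields the alternative ground truth with the two oxygens exchanged.
# Residues without ambiguous atoms keep the identity matrix.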
def _make_restype_atom37_mask():
"""Mask of which atoms are present for which residue type in atom37."""
# create the corresponding mask
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
return restype_atom37_mask
def _make_restype_atom14_mask():
"""Mask of which atoms are present for which residue type in atom14."""
restype_atom14_mask = []
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]]
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
restype_atom14_mask.append([0.] * 14)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
return restype_atom14_mask
def _make_restype_atom37_to_atom14():
"""Map from atom37 to atom14 per residue type."""
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]]
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types
])
restype_atom37_to_atom14.append([0] * 37)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
return restype_atom37_to_atom14
def _make_restype_atom14_to_atom37():
"""Map from atom14 to atom37 per residue type."""
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]]
restype_atom14_to_atom37.append([
(residue_constants.atom_order[name] if name else 0)
for name in atom_names
])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
return restype_atom14_to_atom37
def _make_restype_atom14_is_ambiguous():
"""Mask which atoms are ambiguous in atom14."""
# create an ambiguous atoms mask. shape: (21, 14)
restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = residue_constants.restype_order[
residue_constants.restype_3to1[resname]]
atom_idx1 = residue_constants.restype_name_to_atom14_names[resname].index(
atom_name1)
atom_idx2 = residue_constants.restype_name_to_atom14_names[resname].index(
atom_name2)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
return restype_atom14_is_ambiguous
def _make_restype_rigidgroup_base_atom37_idx():
"""Create Map from rigidgroups to atom37 indices."""
# Create an array with the atom names.
# shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
base_atom_names = np.full([21, 8, 3], '', dtype=object)
# 0: backbone frame
base_atom_names[:, 0, :] = ['C', 'CA', 'N']
# 3: 'psi-group'
base_atom_names[:, 3, :] = ['CA', 'C', 'O']
# 4,5,6,7: 'chi1,2,3,4-group'
for restype, restype_letter in enumerate(residue_constants.restypes):
resname = residue_constants.restype_1to3[restype_letter]
for chi_idx in range(4):
if residue_constants.chi_angles_mask[restype][chi_idx]:
atom_names = residue_constants.chi_angles_atoms[resname][chi_idx]
base_atom_names[restype, chi_idx + 4, :] = atom_names[1:]
# Translate atom names into atom37 indices.
lookuptable = residue_constants.atom_order.copy()
lookuptable[''] = 0
restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])(
base_atom_names)
return restype_rigidgroup_base_atom37_idx
CHI_ATOM_INDICES = _make_chi_atom_indices()
RENAMING_MATRICES = _make_renaming_matrices()
RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()
RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()
RESTYPE_ATOM37_MASK = _make_restype_atom37_mask()
RESTYPE_ATOM14_MASK = _make_restype_atom14_mask()
RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous()
RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx()
# Create mask for existing rigid groups.
RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32)
RESTYPE_RIGIDGROUP_MASK[:, 0] = 1
RESTYPE_RIGIDGROUP_MASK[:, 3] = 1
RESTYPE_RIGIDGROUP_MASK[:20, 4:] = residue_constants.chi_angles_mask
def get_atom37_mask(aatype):
return utils.batched_gather(jnp.asarray(RESTYPE_ATOM37_MASK), aatype)
def get_atom14_mask(aatype):
return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_MASK), aatype)
def get_atom14_is_ambiguous(aatype):
return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_IS_AMBIGUOUS), aatype)
def get_atom14_to_atom37_map(aatype):
return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_TO_ATOM37), aatype)
def get_atom37_to_atom14_map(aatype):
return utils.batched_gather(jnp.asarray(RESTYPE_ATOM37_TO_ATOM14), aatype)
def atom14_to_atom37(atom14_data: jnp.ndarray, # (N, 14, ...)
aatype: jnp.ndarray
) -> jnp.ndarray: # (N, 37, ...)
"""Convert atom14 to atom37 representation."""
assert len(atom14_data.shape) in [2, 3]
idx_atom37_to_atom14 = get_atom37_to_atom14_map(aatype)
atom37_data = utils.batched_gather(
atom14_data, idx_atom37_to_atom14, batch_dims=1)
atom37_mask = get_atom37_mask(aatype)
if len(atom14_data.shape) == 2:
atom37_data *= atom37_mask
elif len(atom14_data.shape) == 3:
atom37_data *= atom37_mask[:, :, None].astype(atom37_data.dtype)
return atom37_data
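# A minimal shape sketch (not part of the original module) showing how
# atom14_to_atom37 is called; the input values are dummy zeros.
def _example_atom14_to_atom37_shapes():
  """Scatters atom14 data for three alanine residues into the atom37 layout."""
  aatype = jnp.array([0, 0, 0])        # Three residues of restype index 0 (ALA).
  atom14_xyz = jnp.zeros((3, 14, 3))   # (N, 14, xyz)
  atom37_xyz = atom14_to_atom37(atom14_xyz, aatype)
  # Output has shape (3, 37, 3); slots outside each residue's atom set are
  # zeroed by the atom37 mask.
  return atom37_xyz.shape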
def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
"""Convert Atom37 positions to Atom14 positions."""
residx_atom14_to_atom37 = utils.batched_gather(
jnp.asarray(RESTYPE_ATOM14_TO_ATOM37), aatype)
atom14_mask = utils.batched_gather(
all_atom_mask, residx_atom14_to_atom37, batch_dims=1).astype(jnp.float32)
# create a mask for known groundtruth positions
atom14_mask *= utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_MASK), aatype)
# gather the groundtruth positions
atom14_positions = jax.tree_map(
lambda x: utils.batched_gather(x, residx_atom14_to_atom37, batch_dims=1),
all_atom_pos)
atom14_positions = atom14_mask * atom14_positions
return atom14_positions, atom14_mask
def get_alt_atom14(aatype, positions: geometry.Vec3Array, mask):
"""Get alternative atom14 positions."""
# pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14)
renaming_transform = utils.batched_gather(
jnp.asarray(RENAMING_MATRICES), aatype)
alternative_positions = jax.tree_map(
lambda x: jnp.sum(x, axis=1), positions[:, :, None] * renaming_transform)
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position)
alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1)
return alternative_positions, alternative_mask
def atom37_to_frames(
aatype: jnp.ndarray, # (...)
all_atom_positions: geometry.Vec3Array, # (..., 37)
all_atom_mask: jnp.ndarray, # (..., 37)
) -> Dict[Text, jnp.ndarray]:
"""Computes the frames for the up to 8 rigid groups for each residue."""
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
aatype_in_shape = aatype.shape
# If there is a batch axis, just flatten it away, and reshape everything
# back at the end of the function.
aatype = jnp.reshape(aatype, [-1])
all_atom_positions = jax.tree_map(lambda x: jnp.reshape(x, [-1, 37]),
all_atom_positions)
all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])
# Compute the gather indices for all residues in the chain.
# shape (N, 8, 3)
residx_rigidgroup_base_atom37_idx = utils.batched_gather(
RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype)
# Gather the base atom positions for each rigid group.
base_atom_pos = jax.tree_map(
lambda x: utils.batched_gather( # pylint: disable=g-long-lambda
x, residx_rigidgroup_base_atom37_idx, batch_dims=1),
all_atom_positions)
# Compute the Rigids.
point_on_neg_x_axis = base_atom_pos[:, :, 0]
origin = base_atom_pos[:, :, 1]
point_on_xy_plane = base_atom_pos[:, :, 2]
gt_rotation = geometry.Rot3Array.from_two_vectors(
origin - point_on_neg_x_axis, point_on_xy_plane - origin)
gt_frames = geometry.Rigid3Array(gt_rotation, origin)
# Compute a mask whether the group exists.
# (N, 8)
group_exists = utils.batched_gather(RESTYPE_RIGIDGROUP_MASK, aatype)
# Compute a mask whether ground truth exists for the group
gt_atoms_exist = utils.batched_gather( # shape (N, 8, 3)
all_atom_mask.astype(jnp.float32),
residx_rigidgroup_base_atom37_idx,
batch_dims=1)
gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists # (N, 8)
# Adapt backbone frame to old convention (mirror x-axis and z-axis).
rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
rots[0, 0, 0] = -1
rots[0, 2, 2] = -1
gt_frames = gt_frames.compose_rotation(
geometry.Rot3Array.from_array(rots))
# The frames for ambiguous rigid groups are just rotated by 180 degree around
# the x-axis. The ambiguous group is always the last chi-group.
restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1])
for resname, _ in residue_constants.residue_atom_renaming_swaps.items():
restype = residue_constants.restype_order[
residue_constants.restype_3to1[resname]]
chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1
# Gather the ambiguity information for each residue.
residx_rigidgroup_is_ambiguous = utils.batched_gather(
restype_rigidgroup_is_ambiguous, aatype)
ambiguity_rot = utils.batched_gather(restype_rigidgroup_rots, aatype)
ambiguity_rot = geometry.Rot3Array.from_array(ambiguity_rot)
# Create the alternative ground truth frames.
alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
fix_shape = lambda x: jnp.reshape(x, aatype_in_shape + (8,))
# reshape back to original residue layout
gt_frames = jax.tree_map(fix_shape, gt_frames)
gt_exists = fix_shape(gt_exists)
group_exists = fix_shape(group_exists)
residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous)
alt_gt_frames = jax.tree_map(fix_shape, alt_gt_frames)
return {
'rigidgroups_gt_frames': gt_frames, # Rigid (..., 8)
'rigidgroups_gt_exists': gt_exists, # (..., 8)
'rigidgroups_group_exists': group_exists, # (..., 8)
'rigidgroups_group_is_ambiguous':
residx_rigidgroup_is_ambiguous, # (..., 8)
'rigidgroups_alt_gt_frames': alt_gt_frames, # Rigid (..., 8)
}
def torsion_angles_to_frames(
aatype: jnp.ndarray, # (N)
backb_to_global: geometry.Rigid3Array, # (N)
torsion_angles_sin_cos: jnp.ndarray # (N, 7, 2)
) -> geometry.Rigid3Array: # (N, 8)
"""Compute rigid group frames from torsion angles."""
assert len(aatype.shape) == 1, (
f'Expected array of rank 1, got array with shape: {aatype.shape}.')
assert len(backb_to_global.rotation.shape) == 1, (
f'Expected array of rank 1, got array with shape: '
f'{backb_to_global.rotation.shape}')
assert len(torsion_angles_sin_cos.shape) == 3, (
f'Expected array of rank 3, got array with shape: '
f'{torsion_angles_sin_cos.shape}')
assert torsion_angles_sin_cos.shape[1] == 7, (
f'wrong shape {torsion_angles_sin_cos.shape}')
assert torsion_angles_sin_cos.shape[2] == 2, (
f'wrong shape {torsion_angles_sin_cos.shape}')
# Gather the default frames for all rigid groups.
# geometry.Rigid3Array with shape (N, 8)
m = utils.batched_gather(residue_constants.restype_rigid_group_default_frame,
aatype)
default_frames = geometry.Rigid3Array.from_array4x4(m)
# Create the rotation matrices according to the given angles (each frame is
# defined such that its rotation is around the x-axis).
sin_angles = torsion_angles_sin_cos[..., 0]
cos_angles = torsion_angles_sin_cos[..., 1]
# insert zero rotation for backbone group.
num_residues, = aatype.shape
sin_angles = jnp.concatenate([jnp.zeros([num_residues, 1]), sin_angles],
axis=-1)
cos_angles = jnp.concatenate([jnp.ones([num_residues, 1]), cos_angles],
axis=-1)
zeros = jnp.zeros_like(sin_angles)
ones = jnp.ones_like(sin_angles)
# all_rots are geometry.Rot3Array with shape (N, 8)
all_rots = geometry.Rot3Array(ones, zeros, zeros,
zeros, cos_angles, -sin_angles,
zeros, sin_angles, cos_angles)
# Apply rotations to the frames.
all_frames = default_frames.compose_rotation(all_rots)
# chi2, chi3, and chi4 frames do not transform to the backbone frame but to
# the previous frame. So chain them up accordingly.
chi1_frame_to_backb = all_frames[:, 4]
chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[:, 5]
chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6]
chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7]
all_frames_to_backb = jax.tree_multimap(
lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5],
chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None],
chi4_frame_to_backb[:, None])
# Create the global frames.
# shape (N, 8)
all_frames_to_global = backb_to_global[:, None] @ all_frames_to_backb
return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
aatype: jnp.ndarray, # (N)
all_frames_to_global: geometry.Rigid3Array # (N, 8)
) -> geometry.Vec3Array: # (N, 14)
"""Put atom literature positions (atom14 encoding) in each rigid group."""
# Pick the appropriate transform for every atom.
residx_to_group_idx = utils.batched_gather(
residue_constants.restype_atom14_to_rigid_group, aatype)
group_mask = jax.nn.one_hot(
residx_to_group_idx, num_classes=8) # shape (N, 14, 8)
# geometry.Rigid3Array with shape (N, 14)
map_atoms_to_global = jax.tree_map(
lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1),
all_frames_to_global)
# Gather the literature atom positions for each residue.
# geometry.Vec3Array with shape (N, 14)
lit_positions = geometry.Vec3Array.from_array(
utils.batched_gather(
residue_constants.restype_atom14_rigid_group_positions, aatype))
# Transform each atom from its local frame to the global frame.
# geometry.Vec3Array with shape (N, 14)
pred_positions = map_atoms_to_global.apply_to_point(lit_positions)
# Mask out non-existing atoms.
mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype)
pred_positions = pred_positions * mask
return pred_positions
def extreme_ca_ca_distance_violations(
positions: geometry.Vec3Array, # (N, 37(14))
mask: jnp.ndarray, # (N, 37(14))
residue_index: jnp.ndarray, # (N)
max_angstrom_tolerance=1.5
) -> jnp.ndarray:
"""Counts residues whose Ca is a large distance from its neighbor."""
this_ca_pos = positions[:-1, 1] # (N - 1,)
this_ca_mask = mask[:-1, 1] # (N - 1)
next_ca_pos = positions[1:, 1] # (N - 1,)
next_ca_mask = mask[1:, 1] # (N - 1)
has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(
jnp.float32)
ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, 1e-6)
violations = (ca_ca_distance -
residue_constants.ca_ca) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
return utils.mask_mean(mask=mask, value=violations)
def between_residue_bond_loss(
pred_atom_positions: geometry.Vec3Array, # (N, 37(14))
pred_atom_mask: jnp.ndarray, # (N, 37(14))
residue_index: jnp.ndarray, # (N)
aatype: jnp.ndarray, # (N)
tolerance_factor_soft=12.0,
tolerance_factor_hard=12.0) -> Dict[Text, jnp.ndarray]:
"""Flat-bottom loss to penalize structural violations between residues."""
assert len(pred_atom_positions.shape) == 2
assert len(pred_atom_mask.shape) == 2
assert len(residue_index.shape) == 1
assert len(aatype.shape) == 1
# Get the positions of the relevant backbone atoms.
this_ca_pos = pred_atom_positions[:-1, 1] # (N - 1)
this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1)
this_c_pos = pred_atom_positions[:-1, 2] # (N - 1)
this_c_mask = pred_atom_mask[:-1, 2] # (N - 1)
next_n_pos = pred_atom_positions[1:, 0] # (N - 1)
next_n_mask = pred_atom_mask[1:, 0] # (N - 1)
next_ca_pos = pred_atom_positions[1:, 1] # (N - 1)
next_ca_mask = pred_atom_mask[1:, 1] # (N - 1)
has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(
jnp.float32)
# Compute loss for the C--N bond.
c_n_bond_length = geometry.euclidean_distance(this_c_pos, next_n_pos, 1e-6)
# The C-N bond to proline has a slightly different length because of the ring.
next_is_proline = (
aatype[1:] == residue_constants.restype_order['P']).astype(jnp.float32)
gt_length = (
(1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
+ next_is_proline * residue_constants.between_res_bond_length_c_n[1])
gt_stddev = (
(1. - next_is_proline) *
residue_constants.between_res_bond_length_stddev_c_n[0] +
next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1])
c_n_bond_length_error = jnp.sqrt(1e-6 +
jnp.square(c_n_bond_length - gt_length))
c_n_loss_per_residue = jax.nn.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev)
mask = this_c_mask * next_n_mask * has_no_gap_mask
c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
c_n_violation_mask = mask * (
c_n_bond_length_error > (tolerance_factor_hard * gt_stddev))
# Compute loss for the angles.
c_ca_unit_vec = (this_ca_pos - this_c_pos).normalized(1e-6)
c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length
n_ca_unit_vec = (next_ca_pos - next_n_pos).normalized(1e-6)
ca_c_n_cos_angle = c_ca_unit_vec.dot(c_n_unit_vec)
gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
ca_c_n_cos_angle_error = jnp.sqrt(
1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle))
ca_c_n_loss_per_residue = jax.nn.relu(
ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev)
mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
(tolerance_factor_hard * gt_stddev))
c_n_ca_cos_angle = (-c_n_unit_vec).dot(n_ca_unit_vec)
gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
c_n_ca_cos_angle_error = jnp.sqrt(
1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle))
c_n_ca_loss_per_residue = jax.nn.relu(
c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev)
mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) + 1e-6)
c_n_ca_violation_mask = mask * (
c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev))
# Compute a per residue loss (equally distribute the loss to both
# neighbouring residues).
per_residue_loss_sum = (c_n_loss_per_residue +
ca_c_n_loss_per_residue +
c_n_ca_loss_per_residue)
per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) +
jnp.pad(per_residue_loss_sum, [[1, 0]]))
# Compute hard violations.
violation_mask = jnp.max(
jnp.stack([c_n_violation_mask,
ca_c_n_violation_mask,
c_n_ca_violation_mask]), axis=0)
violation_mask = jnp.maximum(
jnp.pad(violation_mask, [[0, 1]]),
jnp.pad(violation_mask, [[1, 0]]))
return {'c_n_loss_mean': c_n_loss, # shape ()
'ca_c_n_loss_mean': ca_c_n_loss, # shape ()
'c_n_ca_loss_mean': c_n_ca_loss, # shape ()
'per_residue_loss_sum': per_residue_loss_sum, # shape (N)
'per_residue_violation_mask': violation_mask # shape (N)
}
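# In equation form (a restatement of the C--N bond term above): with predicted
# bond length b, literature length b0 and literature standard deviation s,
#
#   err  = sqrt((b - b0)^2 + 1e-6)
#   loss = relu(err - tolerance_factor_soft * s)
#
# averaged over residue pairs where both atoms exist and there is no chain
# break; a hard violation is flagged where err > tolerance_factor_hard * s.
# The two angle terms follow the same flat-bottom pattern on cosines.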
def between_residue_clash_loss(
pred_positions: geometry.Vec3Array, # (N, 14)
atom_exists: jnp.ndarray, # (N, 14)
atom_radius: jnp.ndarray, # (N, 14)
residue_index: jnp.ndarray, # (N)
overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5) -> Dict[Text, jnp.ndarray]:
"""Loss to penalize steric clashes between residues."""
assert len(pred_positions.shape) == 2
assert len(atom_exists.shape) == 2
assert len(atom_radius.shape) == 2
assert len(residue_index.shape) == 1
# Create the distance matrix.
# (N, N, 14, 14)
dists = geometry.euclidean_distance(pred_positions[:, None, :, None],
pred_positions[None, :, None, :], 1e-10)
# Create the mask for valid distances.
# shape (N, N, 14, 14)
dists_mask = (atom_exists[:, None, :, None] * atom_exists[None, :, None, :])
# Mask out all the duplicate entries in the lower triangular matrix.
# Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
# are handled separately.
dists_mask *= (
residue_index[:, None, None, None] < residue_index[None, :, None, None])
# Backbone C--N bond between subsequent residues is no clash.
c_one_hot = jax.nn.one_hot(2, num_classes=14)
n_one_hot = jax.nn.one_hot(0, num_classes=14)
neighbour_mask = ((residue_index[:, None, None, None] +
1) == residue_index[None, :, None, None])
c_n_bonds = neighbour_mask * c_one_hot[None, None, :,
None] * n_one_hot[None, None, None, :]
dists_mask *= (1. - c_n_bonds)
# Disulfide bridge between two cysteines is no clash.
cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index('SG')
cys_sg_one_hot = jax.nn.one_hot(cys_sg_idx, num_classes=14)
disulfide_bonds = (cys_sg_one_hot[None, None, :, None] *
cys_sg_one_hot[None, None, None, :])
dists_mask *= (1. - disulfide_bonds)
# Compute the lower bound for the allowed distances.
# shape (N, N, 14, 14)
dists_lower_bound = dists_mask * (
atom_radius[:, None, :, None] + atom_radius[None, :, None, :])
# Compute the error.
# shape (N, N, 14, 14)
dists_to_low_error = dists_mask * jax.nn.relu(
dists_lower_bound - overlap_tolerance_soft - dists)
# Compute the mean loss.
# shape ()
mean_loss = (jnp.sum(dists_to_low_error)
/ (1e-6 + jnp.sum(dists_mask)))
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum = (jnp.sum(dists_to_low_error, axis=[0, 2]) +
jnp.sum(dists_to_low_error, axis=[1, 3]))
# Compute the hard clash mask.
# shape (N, N, 14, 14)
clash_mask = dists_mask * (
dists < (dists_lower_bound - overlap_tolerance_hard))
# Compute the per atom clash.
# shape (N, 14)
per_atom_clash_mask = jnp.maximum(
jnp.max(clash_mask, axis=[0, 2]),
jnp.max(clash_mask, axis=[1, 3]))
return {'mean_loss': mean_loss, # shape ()
'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14)
'per_atom_clash_mask': per_atom_clash_mask # shape (N, 14)
}
def within_residue_violations(
pred_positions: geometry.Vec3Array, # (N, 14)
atom_exists: jnp.ndarray, # (N, 14)
dists_lower_bound: jnp.ndarray, # (N, 14, 14)
dists_upper_bound: jnp.ndarray, # (N, 14, 14)
tighten_bounds_for_loss=0.0,
) -> Dict[Text, jnp.ndarray]:
"""Find within-residue violations."""
assert len(pred_positions.shape) == 2
assert len(atom_exists.shape) == 2
assert len(dists_lower_bound.shape) == 3
assert len(dists_upper_bound.shape) == 3
# Compute the mask for each residue.
# shape (N, 14, 14)
dists_masks = (1. - jnp.eye(14, 14)[None])
dists_masks *= (atom_exists[:, :, None] * atom_exists[:, None, :])
# Distance matrix
# shape (N, 14, 14)
dists = geometry.euclidean_distance(pred_positions[:, :, None],
pred_positions[:, None, :], 1e-10)
# Compute the loss.
# shape (N, 14, 14)
dists_to_low_error = jax.nn.relu(
dists_lower_bound + tighten_bounds_for_loss - dists)
dists_to_high_error = jax.nn.relu(
dists + tighten_bounds_for_loss - dists_upper_bound)
loss = dists_masks * (dists_to_low_error + dists_to_high_error)
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum = (jnp.sum(loss, axis=1) +
jnp.sum(loss, axis=2))
# Compute the violations mask.
# shape (N, 14, 14)
violations = dists_masks * ((dists < dists_lower_bound) |
(dists > dists_upper_bound))
# Compute the per atom violations.
# shape (N, 14)
per_atom_violations = jnp.maximum(
jnp.max(violations, axis=1), jnp.max(violations, axis=2))
return {'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14)
'per_atom_violations': per_atom_violations # shape (N, 14)
}
def find_optimal_renaming(
gt_positions: geometry.Vec3Array, # (N, 14)
alt_gt_positions: geometry.Vec3Array, # (N, 14)
atom_is_ambiguous: jnp.ndarray, # (N, 14)
gt_exists: jnp.ndarray, # (N, 14)
pred_positions: geometry.Vec3Array, # (N, 14)
) -> jnp.ndarray: # (N):
"""Find optimal renaming for ground truth that maximizes LDDT."""
assert len(gt_positions.shape) == 2
assert len(alt_gt_positions.shape) == 2
assert len(atom_is_ambiguous.shape) == 2
assert len(gt_exists.shape) == 2
assert len(pred_positions.shape) == 2
# Create the pred distance matrix.
# shape (N, N, 14, 14)
pred_dists = geometry.euclidean_distance(pred_positions[:, None, :, None],
pred_positions[None, :, None, :],
1e-10)
# Compute distances for ground truth with original and alternative names.
# shape (N, N, 14, 14)
gt_dists = geometry.euclidean_distance(gt_positions[:, None, :, None],
gt_positions[None, :, None, :], 1e-10)
alt_gt_dists = geometry.euclidean_distance(alt_gt_positions[:, None, :, None],
alt_gt_positions[None, :, None, :],
1e-10)
# Compute LDDT's.
# shape (N, N, 14, 14)
lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, gt_dists))
alt_lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, alt_gt_dists))
# Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms
# in cols.
# shape (N ,N, 14, 14)
mask = (
gt_exists[:, None, :, None] * # rows
atom_is_ambiguous[:, None, :, None] * # rows
gt_exists[None, :, None, :] * # cols
(1. - atom_is_ambiguous[None, :, None, :])) # cols
# Aggregate distances for each residue to the non-ambiguous atoms.
# shape (N)
per_res_lddt = jnp.sum(mask * lddt, axis=[1, 2, 3])
alt_per_res_lddt = jnp.sum(mask * alt_lddt, axis=[1, 2, 3])
# Decide for each residue, whether alternative naming is better.
# shape (N)
alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).astype(jnp.float32)
return alt_naming_is_better # shape (N)
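# Illustrative follow-up (a sketch, not the original loss code): the returned
# per-residue indicator can be used to blend the two ground-truth namings,
# picking the alternative naming wherever it matches the prediction better:
#
#   use_alt = alt_naming_is_better[:, None]        # (N, 1)
#   renamed_gt_positions = jax.tree_map(
#       lambda x, y: (1. - use_alt) * x + use_alt * y,
#       gt_positions, alt_gt_positions)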
def frame_aligned_point_error(
pred_frames: geometry.Rigid3Array, # shape (num_frames)
target_frames: geometry.Rigid3Array, # shape (num_frames)
frames_mask: jnp.ndarray, # shape (num_frames)
pred_positions: geometry.Vec3Array, # shape (num_positions)
target_positions: geometry.Vec3Array, # shape (num_positions)
positions_mask: jnp.ndarray, # shape (num_positions)
pair_mask: jnp.ndarray, # shape (num_frames, num_positions)
l1_clamp_distance: float,
length_scale=20.,
epsilon=1e-4) -> jnp.ndarray: # shape ()
"""Measure point error under different alignements.
Computes error between two structures with B points
under A alignments derived form the given pairs of frames.
Args:
pred_frames: num_frames reference frames for 'pred_positions'.
target_frames: num_frames reference frames for 'target_positions'.
frames_mask: Mask for frame pairs to use.
pred_positions: num_positions predicted positions of the structure.
target_positions: num_positions target positions of the structure.
positions_mask: Mask on which positions to score.
pair_mask: A (num_frames, num_positions) mask to use in the loss, useful
for separating intra from inter chain losses.
l1_clamp_distance: Distance cutoff on error beyond which gradients will
be zero.
length_scale: length scale to divide loss by.
epsilon: small value used to regularize denominator for masked average.
Returns:
Masked Frame aligned point error.
"""
# For now we do not allow any batch dimensions.
assert len(pred_frames.rotation.shape) == 1
assert len(target_frames.rotation.shape) == 1
assert frames_mask.ndim == 1
assert pred_positions.x.ndim == 1
assert target_positions.x.ndim == 1
assert positions_mask.ndim == 1
# Compute array of predicted positions in the predicted frames.
# geometry.Vec3Array (num_frames, num_positions)
local_pred_pos = pred_frames[:, None].inverse().apply_to_point(
pred_positions[None, :])
# Compute array of target positions in the target frames.
# geometry.Vec3Array (num_frames, num_positions)
local_target_pos = target_frames[:, None].inverse().apply_to_point(
target_positions[None, :])
# Compute errors between the structures.
# jnp.ndarray (num_frames, num_positions)
error_dist = geometry.euclidean_distance(local_pred_pos, local_target_pos,
epsilon)
clipped_error_dist = jnp.clip(error_dist, 0, l1_clamp_distance)
normed_error = clipped_error_dist / length_scale
normed_error *= jnp.expand_dims(frames_mask, axis=-1)
normed_error *= jnp.expand_dims(positions_mask, axis=-2)
if pair_mask is not None:
normed_error *= pair_mask
mask = (jnp.expand_dims(frames_mask, axis=-1) *
jnp.expand_dims(positions_mask, axis=-2))
if pair_mask is not None:
mask *= pair_mask
normalization_factor = jnp.sum(mask, axis=(-1, -2))
return (jnp.sum(normed_error, axis=(-2, -1)) /
(epsilon + normalization_factor))
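# In equation form (a restatement of the code above), for every frame i and
# point j that survive the masks:
#
#   x_ij    = T_i_pred^{-1} o x_j_pred
#   x_ij_gt = T_i_target^{-1} o x_j_target
#   d_ij    = ||x_ij - x_ij_gt||   (computed with a small epsilon)
#   FAPE    = mean_ij [ min(d_ij, l1_clamp_distance) / length_scale ]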
def get_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in residue_constants.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
atom indices default to 0.
"""
chi_atom_indices = []
for residue_name in residue_constants.restypes:
residue_name = residue_constants.restype_1to3[residue_name]
residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[residue_constants.atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return jnp.asarray(chi_atom_indices)
def compute_chi_angles(positions: geometry.Vec3Array,
mask: jnp.ndarray,
aatype: jnp.ndarray):
"""Computes the chi angles given all atom positions and the amino acid type.
Args:
positions: A Vec3Array of shape
[num_res, residue_constants.atom_type_num], with positions of
atoms needed to calculate chi angles. Supports up to 1 batch dimension.
mask: An optional tensor of shape
[num_res, residue_constants.atom_type_num] that masks which atom
positions are set for each residue. If given, then the chi mask will be
set to 1 for a chi angle only if the amino acid has that chi angle and all
the chi atoms needed to calculate that chi angle are set. If not given
(set to None), the chi mask will be set to 1 for a chi angle as long as the
amino acid has that chi angle, regardless of whether the atoms needed to
calculate it are present.
aatype: A tensor of shape [num_res] with amino acid type integer
code (0 to 21). Supports up to 1 batch dimension.
Returns:
A tuple of tensors (chi_angles, mask), where both have shape
[num_res, 4]. The mask masks out unused chi angles for amino acid
types that have fewer than 4 chi angles. If the atom mask is provided, the
chi mask will also mask out uncomputable chi angles.
"""
# Don't assert on the num_res and batch dimensions as they might be unknown.
assert positions.shape[-1] == residue_constants.atom_type_num
assert mask.shape[-1] == residue_constants.atom_type_num
# Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
chi_atom_indices = get_chi_atom_indices()
# Select atoms to compute chis. Shape: [num_res, chis=4, atoms=4].
atom_indices = utils.batched_gather(
params=chi_atom_indices, indices=aatype, axis=0)
# Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3].
chi_angle_atoms = jax.tree_map(
lambda x: utils.batched_gather( # pylint: disable=g-long-lambda
params=x, indices=atom_indices, axis=-1, batch_dims=1), positions)
a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)]
chi_angles = geometry.dihedral_angle(a, b, c, d)
# Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
chi_angles_mask = list(residue_constants.chi_angles_mask)
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
chi_angles_mask = jnp.asarray(chi_angles_mask)
# Compute the chi angle mask. Shape [num_res, chis=4].
chi_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype,
axis=0)
# The chi_mask is set to 1 only when all necessary chi angle atoms were set.
# Gather the chi angle atoms mask. Shape: [num_res, chis=4, atoms=4].
chi_angle_atoms_mask = utils.batched_gather(
params=mask, indices=atom_indices, axis=-1, batch_dims=1)
# Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4].
chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1])
chi_mask = chi_mask * chi_angle_atoms_mask.astype(jnp.float32)
return chi_angles, chi_mask
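# A minimal shape sketch (not part of the original module) showing how
# compute_chi_angles is called; the positions are dummy zeros.
def _example_compute_chi_angles_shapes():
  """With a full atom mask for two residues, both outputs have shape (2, 4)."""
  num_res = 2
  positions = geometry.Vec3Array.from_array(
      jnp.zeros((num_res, residue_constants.atom_type_num, 3)))
  atom_mask = jnp.ones((num_res, residue_constants.atom_type_num))
  aatype = jnp.zeros((num_res,), dtype=jnp.int32)
  chi_angles, chi_mask = compute_chi_angles(positions, atom_mask, aatype)
  return chi_angles.shape, chi_mask.shape  # Both (2, 4).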
def make_transform_from_reference(
a_xyz: geometry.Vec3Array,
b_xyz: geometry.Vec3Array,
c_xyz: geometry.Vec3Array) -> geometry.Rigid3Array:
"""Returns rotation and translation matrices to convert from reference.
Note that this method does not take care of symmetries. If you provide the
coordinates in the non-standard way, the A atom will end up on the negative
y-axis rather than on the positive y-axis. You need to take care of such
cases in your code.
Args:
a_xyz: A Vec3Array.
b_xyz: A Vec3Array.
c_xyz: A Vec3Array.
Returns:
A Rigid3Array which, when applied to coordinates in a canonicalized
reference frame, will give coordinates approximately equal to
the original coordinates (in the global frame).
"""
rotation = geometry.Rot3Array.from_two_vectors(c_xyz - b_xyz,
a_xyz - b_xyz)
return geometry.Rigid3Array(rotation, b_xyz)
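# A minimal sketch (not part of the original module) showing how
# make_transform_from_reference is called; the three atom positions are
# hypothetical single-residue coordinates.
def _example_make_transform_from_reference():
  """Builds a rigid frame whose origin is b_xyz, with c_xyz on the +x axis."""
  a_xyz = geometry.Vec3Array.from_array(jnp.array([[0.0, 1.5, 0.0]]))
  b_xyz = geometry.Vec3Array.from_array(jnp.array([[0.0, 0.0, 0.0]]))
  c_xyz = geometry.Vec3Array.from_array(jnp.array([[1.3, 0.0, 0.0]]))
  # The returned Rigid3Array has its x-axis pointing from b_xyz towards c_xyz
  # and a_xyz lying in the x-y plane.
  return make_transform_from_reference(a_xyz=a_xyz, b_xyz=b_xyz, c_xyz=c_xyz)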