Commit 0be2b30b authored by Augustin-Zidek

Add code for AlphaFold-Multimer.

PiperOrigin-RevId: 407076987
parent 1d43aaff
@@ -7,10 +7,17 @@
v2.0. This is a completely new model that was entered in CASP14 and published in
Nature. For simplicity, we refer to this model as AlphaFold throughout the rest
of this document.
We also provide an implementation of AlphaFold-Multimer. This represents a work
in progress and AlphaFold-Multimer isn't expected to be as stable as our monomer
AlphaFold system.
[Read the guide](#updating-existing-alphafold-installation-to-include-alphafold-multimers)
for how to upgrade and update code.

Any publication that discloses findings arising from using this source code or the model parameters should [cite](#citing-this-work) the
[AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2) and, if
applicable, the [AlphaFold-Multimer paper](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1).

Please also refer to the
[Supplementary Information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf)
for a detailed description of the method.
@@ -45,18 +52,25 @@ The following steps are required in order to run AlphaFold:
   or take a look at the following
   [NVIDIA Docker issue](https://github.com/NVIDIA/nvidia-docker/issues/1447#issuecomment-801479573).
If you wish to run AlphaFold using Singularity (a common containerization
platform on HPC systems), we recommend using one of the third-party Singularity
setups linked in
https://github.com/deepmind/alphafold/issues/10 or
https://github.com/deepmind/alphafold/issues/24.
### Genetic databases

This step requires `aria2c` to be installed on your machine.

AlphaFold needs multiple genetic (sequence) databases to run:
* [BFD](https://bfd.mmseqs.com/),
* [MGnify](https://www.ebi.ac.uk/metagenomics/),
* [PDB70](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/),
* [PDB](https://www.rcsb.org/) (structures in the mmCIF format),
* [PDB seqres](https://www.rcsb.org/) – only for AlphaFold-Multimer,
* [Uniclust30](https://uniclust.mmseqs.com/),
* [UniProt](https://www.uniprot.org/uniprot/) – only for AlphaFold-Multimer,
* [UniRef90](https://www.uniprot.org/help/uniref).
We provide a script `scripts/download_all_data.sh` that can be used to download
and set up all of these databases:

@@ -76,9 +90,13 @@ and set up all of these databases:
```
will download a reduced version of the databases to be used with the
`reduced_dbs` database preset.
:ledger: **Note: The download directory `<DOWNLOAD_DIR>` should _not_ be a
subdirectory in the AlphaFold repository directory.** If it is, the Docker build
will be slow as the large databases will be copied during the image creation.

We don't provide exactly the database versions used in CASP14 – see the [note on
reproducibility](#note-on-reproducibility). Some of the databases are mirrored
for speed, see [mirrored databases](#mirrored-databases).
@@ -87,8 +105,8 @@ and the total size when unzipped is 2.2 TB. Please make sure you have a large
enough hard drive space, bandwidth and time to download. We recommend using an
SSD for better genetic search performance.**

The `download_all_data.sh` script will also download the model parameter files.
Once the script has finished, you should have the following directory structure:
```
$DOWNLOAD_DIR/                             # Total: ~ 2.2 TB (download: 438 GB)
@@ -99,24 +117,29 @@ $DOWNLOAD_DIR/ # Total: ~ 2.2 TB (download: 438 GB)
    params/                                # ~ 3.5 GB (download: 3.5 GB)
        # 5 CASP14 models,
        # 5 pTM models,
        # 5 AlphaFold-Multimer models,
        # LICENSE,
        # = 16 files.
    pdb70/                                 # ~ 56 GB (download: 19.5 GB)
        # 9 files.
    pdb_mmcif/                             # ~ 206 GB (download: 46 GB)
        mmcif_files/
            # About 180,000 .cif files.
        obsolete.dat
    pdb_seqres/                            # ~ 0.2 GB (download: 0.2 GB)
        pdb_seqres.txt
    small_bfd/                             # ~ 17 GB (download: 9.6 GB)
        bfd-first_non_consensus_sequences.fasta
    uniclust30/                            # ~ 86 GB (download: 24.9 GB)
        uniclust30_2018_08/
            # 13 files.
    uniprot/                               # ~ 98.3 GB (download: 49 GB)
        uniprot.fasta
    uniref90/                              # ~ 58 GB (download: 29.7 GB)
        uniref90.fasta
```
`bfd/` is only downloaded if you download the full databases, and `small_bfd/`
is only downloaded if you download the reduced databases.
### Model parameters

@@ -127,7 +150,7 @@ CC BY-NC 4.0 license. Please see the [Disclaimer](#license-and-disclaimer) below
for more detail.

The AlphaFold parameters are available from
https://storage.googleapis.com/alphafold/alphafold_params_2021-10-27.tar, and
are downloaded as part of the `scripts/download_all_data.sh` script. This script
will download parameters for:

@@ -135,8 +158,46 @@ will download parameters for:
  structure prediction quality (see Jumper et al. 2021, Suppl. Methods 1.12
  for details).
* 5 pTM models, which were fine-tuned to produce pTM (predicted TM-score) and
  predicted aligned error (PAE) values alongside their structure predictions
  (see Jumper et al. 2021, Suppl. Methods 1.9.7 for details).
* 5 AlphaFold-Multimer models that produce pTM and PAE values alongside their
  structure predictions.

### Updating existing AlphaFold installation to include AlphaFold-Multimers
If you have AlphaFold v2.0.0 or v2.0.1, you can either reinstall AlphaFold fully
from scratch (remove everything and run the setup again), or do an incremental
update that is significantly faster but requires a bit more work. Make sure you
follow these steps in the exact order they are listed below:
1. **Update the code.**
   * Go to the directory with the cloned AlphaFold repository and run
     `git fetch origin main` to get all code updates.
1. **Download the UniProt and PDB seqres databases.**
   * Run `scripts/download_uniprot.sh <DOWNLOAD_DIR>`.
   * Remove `<DOWNLOAD_DIR>/pdb_mmcif`. This is needed because PDB seqres and
     the PDB mmCIF files have to come from exactly the same date. Skipping this
     step may result in errors when searching for templates while running
     AlphaFold-Multimer.
   * Run `scripts/download_pdb_mmcif.sh <DOWNLOAD_DIR>`.
   * Run `scripts/download_pdb_seqres.sh <DOWNLOAD_DIR>`.
1. **Update the model parameters.**
   * Remove the old model parameters in `<DOWNLOAD_DIR>/params`.
   * Download new model parameters using
     `scripts/download_alphafold_params.sh <DOWNLOAD_DIR>`.
1. **Follow [Running AlphaFold](#running-alphafold).**
#### API changes between v2.0.0 and v2.1.0
We tried to keep the API as backwards compatible as possible, but we had to
change the following:

* `RunModel.predict()` now needs a `random_seed` argument, as MSA sampling
  happens inside the Multimer model.
* The `preset` flag in `run_alphafold.py` and `run_docker.py` was split into
  `db_preset` and `model_preset`.
* Setting the `data_dir` flag is now needed when using `run_docker.py`.
## Running AlphaFold

@@ -151,8 +212,6 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
   git clone https://github.com/deepmind/alphafold.git
   ```

1. Build the Docker image:

   ```bash
@@ -168,14 +227,19 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
   pip3 install -r docker/requirements.txt
   ```
1. Run `run_docker.py` pointing to a FASTA file containing the protein
   sequence(s) for which you wish to predict the structure. If you are
   predicting the structure of a protein that is already in PDB and you wish to
   avoid using it as a template, then `max_template_date` must be set to be
   before the release date of the structure. You must also provide the path to
   the directory containing the downloaded databases. For example, for the
   T1050 CASP14 target:

   ```bash
   python3 docker/run_docker.py \
     --fasta_paths=T1050.fasta \
     --max_template_date=2020-05-14 \
     --data_dir=$DOWNLOAD_DIR
   ```
   By default, AlphaFold will attempt to use all visible GPU devices. To use a
@@ -184,33 +248,76 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
   [GPU enumeration](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#gpu-enumeration)
   for more details.
1. You can control which AlphaFold model to run by adding the
   `--model_preset=` flag. We provide the following models:

   * **monomer**: This is the original model used at CASP14 with no ensembling.

   * **monomer\_casp14**: This is the original model used at CASP14 with
     `num_ensemble=8`, matching our CASP14 configuration. This is largely
     provided for reproducibility as it is 8x more computationally
     expensive for limited accuracy gain (+0.1 average GDT gain on CASP14
     domains).

   * **monomer\_ptm**: This is the original CASP14 model fine tuned with the
     pTM head, providing a pairwise confidence measure. It is slightly less
     accurate than the normal monomer model.

   * **multimer**: This is the [AlphaFold-Multimer](#citing-this-work) model.
     To use this model, provide a multi-sequence FASTA file. In addition, the
     UniProt database should have been downloaded.

1. You can control MSA speed/quality tradeoff by adding
   `--db_preset=reduced_dbs` or `--db_preset=full_dbs` to the run command. We
   provide the following presets:

   * **reduced\_dbs**: This preset is optimized for speed and lower hardware
     requirements. It runs with a reduced version of the BFD database.
     It requires 8 CPU cores (vCPUs), 8 GB of RAM, and 600 GB of disk space.

   * **full\_dbs**: This runs with all genetic databases used at CASP14.

   Running the command above with the `monomer` model preset and the
   `reduced_dbs` data preset would look like this:

   ```bash
   python3 docker/run_docker.py \
     --fasta_paths=T1050.fasta \
     --max_template_date=2020-05-14 \
     --model_preset=monomer \
     --db_preset=reduced_dbs \
     --data_dir=$DOWNLOAD_DIR
   ```
### Running AlphaFold-Multimer
All steps are the same as when running the monomer system, but you will have to

* provide an input fasta with multiple sequences,
* set `--model_preset=multimer`,
* optionally set the `--is_prokaryote_list` flag with booleans that determine
  whether all input sequences in the given fasta file are prokaryotic. If that
  is not the case or the origin is unknown, set to `false` for that fasta.
An example that folds two protein complexes `multimer1` and `multimer2` where
the first is prokaryotic and the second isn't:
```bash
python3 docker/run_docker.py \
--fasta_paths=multimer1.fasta,multimer2.fasta \
--is_prokaryote_list=true,false \
--max_template_date=2020-05-14 \
--model_preset=multimer \
--data_dir=$DOWNLOAD_DIR
```
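
For reference, the multimer input is a single FASTA file per complex with one
record per chain. A hypothetical two-chain input (the description lines and
sequences below are placeholders, not taken from any real target) could look
like this:

```
>chain_A_hypothetical
MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ
>chain_B_hypothetical
GSHMSDKEVLTAAQNRAERLGQ
```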
### AlphaFold output

The outputs will be saved in a subdirectory of the directory provided via the
`--output_dir` flag of `run_docker.py` (defaults to `/tmp/alphafold/`). The
outputs include the computed MSAs, unrelaxed structures, relaxed structures,
ranked structures, raw model outputs, prediction metadata, and section timings.
The `--output_dir` directory will have the following structure:
```
<target_name>/
@@ -299,7 +406,7 @@ develop on top of the `RunModel.predict` method with a parallel system for
precomputing multi-sequence alignments. Alternatively, this script can be run
repeatedly with only moderate overhead.
## Note on CASP14 reproducibility

AlphaFold's output for a small number of proteins has high inter-run variance,
and may be affected by changes in the input data. The CASP14 target T1064 is a
@@ -346,6 +453,21 @@ If you use the code or data in this package, please cite:
}
```
In addition, if you use the AlphaFold-Multimer mode, please cite:
```bibtex
@article {AlphaFold-Multimer2021,
  author = {Evans, Richard and O{\textquoteright}Neill, Michael and Pritzel, Alexander and Antropova, Natasha and Senior, Andrew and Green, Tim and {\v{Z}}{\'\i}dek, Augustin and Bates, Russ and Blackwell, Sam and Yim, Jason and Ronneberger, Olaf and Bodenstein, Sebastian and Zielinski, Michal and Bridgland, Alex and Potapenko, Anna and Cowie, Andrew and Tunyasuvunakool, Kathryn and Jain, Rishub and Clancy, Ellen and Kohli, Pushmeet and Jumper, John and Hassabis, Demis},
  journal = {bioRxiv},
  title = {Protein complex prediction with AlphaFold-Multimer},
  year = {2021},
  elocation-id = {2021.10.04.463034},
  doi = {10.1101/2021.10.04.463034},
  URL = {https://www.biorxiv.org/content/early/2021/10/04/2021.10.04.463034},
  eprint = {https://www.biorxiv.org/content/early/2021/10/04/2021.10.04.463034.full.pdf},
}
```
## Community contributions

Colab notebooks provided by the community (please note that these notebooks may
@@ -378,6 +500,7 @@ and packages:
* [NumPy](https://numpy.org)
* [OpenMM](https://github.com/openmm/openmm)
* [OpenStructure](https://openstructure.org)
* [pandas](https://pandas.pydata.org/)
* [pymol3d](https://github.com/avirshup/py3dmol)
* [SciPy](https://scipy.org)
* [Sonnet](https://github.com/deepmind/sonnet)
@@ -111,8 +111,10 @@ def compute_predicted_aligned_error(
def predicted_tm_score(
    logits: np.ndarray,
    breaks: np.ndarray,
    residue_weights: Optional[np.ndarray] = None,
    asym_id: Optional[np.ndarray] = None,
    interface: bool = False) -> np.ndarray:
  """Computes predicted TM alignment or predicted interface TM alignment score.

  Args:
    logits: [num_res, num_res, num_bins] the logits output from
@@ -120,9 +122,12 @@ def predicted_tm_score(
    breaks: [num_bins] the error bins.
    residue_weights: [num_res] the per residue weights to use for the
      expectation.
    asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
      ipTM calculation, i.e. when interface=True.
    interface: If True, interface predicted TM score is computed.

  Returns:
    ptm_score: The predicted TM alignment or the predicted iTM score.
  """
  # residue_weights has to be in [0, 1], but can be floating-point, i.e. the
@@ -132,24 +137,32 @@ def predicted_tm_score(
  bin_centers = _calculate_bin_centers(breaks)

  num_res = int(np.sum(residue_weights))
  # Clip num_res to avoid negative/undefined d0.
  clipped_num_res = max(num_res, 19)

  # Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
  # "Scoring function for automated assessment of protein structure template
  # quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
  d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8

  # Convert logits to probs.
  probs = scipy.special.softmax(logits, axis=-1)

  # TM-Score term for every bin.
  tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0))
  # E_distances tm(distance).
  predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1)

  pair_mask = np.ones(shape=(num_res, num_res), dtype=bool)
  if interface:
    pair_mask *= asym_id[:, None] != asym_id[None, :]

  predicted_tm_term *= pair_mask

  pair_residue_weights = pair_mask * (
      residue_weights[None, :] * residue_weights[:, None])
  normed_residue_mask = pair_residue_weights / (1e-8 + np.sum(
      pair_residue_weights, axis=-1, keepdims=True))
  per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1)
  return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()])
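
As a quick, hedged illustration of the new arguments (random logits, an assumed
bin layout and a made-up two-chain split, purely to show the expected shapes):

```python
import numpy as np
from alphafold.common import confidence

# Hypothetical example: 10 residues split into two chains of 5, 64 error bins.
num_res, num_bins = 10, 64
logits = np.random.default_rng(0).normal(size=(num_res, num_res, num_bins))
breaks = np.linspace(0., 31., num_bins - 1)  # assumed bin edges, illustration only
residue_weights = np.ones(num_res)
asym_id = np.array([0] * 5 + [1] * 5)        # per-residue chain ID

ptm = confidence.predicted_tm_score(logits, breaks, residue_weights)
iptm = confidence.predicted_tm_score(
    logits, breaks, residue_weights, asym_id=asym_id, interface=True)
print(float(ptm), float(iptm))
```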
@@ -23,6 +23,10 @@ import numpy as np
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any]  # Is a nested dict.
# Complete sequence of chain IDs supported by the PDB format.
PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62.
@dataclasses.dataclass(frozen=True)
class Protein:
@@ -43,11 +47,21 @@ class Protein:
  # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
  residue_index: np.ndarray  # [num_res]

  # 0-indexed number corresponding to the chain in the protein that this residue
  # belongs to.
  chain_index: np.ndarray  # [num_res]

  # B-factors, or temperature factors, of each residue (in sq. angstroms units),
  # representing the displacement of the residue from its ground truth mean
  # value.
  b_factors: np.ndarray  # [num_res, num_atom_type]

  def __post_init__(self):
    if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
      raise ValueError(
          f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
          'because these cannot be written to PDB format.')
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
  """Takes a PDB string and constructs a Protein object.

@@ -57,9 +71,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
  Args:
    pdb_str: The contents of the pdb file
    chain_id: If chain_id is specified (e.g. A), then only that chain
      is parsed. Otherwise all chains are parsed.

  Returns:
    A new `Protein` parsed from the pdb contents.
@@ -73,57 +86,63 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
        f'Only single model PDBs are supported. Found {len(models)} models.')
  model = models[0]
  atom_positions = []
  aatype = []
  atom_mask = []
  residue_index = []
  chain_ids = []
  b_factors = []

  for chain in model:
    if chain_id is not None and chain.id != chain_id:
      continue
    for res in chain:
      if res.id[2] != ' ':
        raise ValueError(
            f'PDB contains an insertion code at chain {chain.id} and residue '
            f'index {res.id[1]}. These are not supported.')
      res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
      restype_idx = residue_constants.restype_order.get(
          res_shortname, residue_constants.restype_num)
      pos = np.zeros((residue_constants.atom_type_num, 3))
      mask = np.zeros((residue_constants.atom_type_num,))
      res_b_factors = np.zeros((residue_constants.atom_type_num,))
      for atom in res:
        if atom.name not in residue_constants.atom_types:
          continue
        pos[residue_constants.atom_order[atom.name]] = atom.coord
        mask[residue_constants.atom_order[atom.name]] = 1.
        res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
      if np.sum(mask) < 0.5:
        # If no known atom positions are reported for the residue then skip it.
        continue
      aatype.append(restype_idx)
      atom_positions.append(pos)
      atom_mask.append(mask)
      residue_index.append(res.id[1])
      chain_ids.append(chain.id)
      b_factors.append(res_b_factors)

  # Chain IDs are usually characters so map these to ints.
  unique_chain_ids = np.unique(chain_ids)
  chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
  chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

  return Protein(
      atom_positions=np.array(atom_positions),
      atom_mask=np.array(atom_mask),
      aatype=np.array(aatype),
      residue_index=np.array(residue_index),
      chain_index=chain_index,
      b_factors=np.array(b_factors))
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
  chain_end = 'TER'
  return (f'{chain_end:<6}{atom_index:>5}      {end_resname:>3} '
          f'{chain_name:>1}{residue_index:>4}')
def to_pdb(prot: Protein) -> str:
  """Converts a `Protein` instance to a PDB string.

@@ -143,16 +162,33 @@ def to_pdb(prot: Protein) -> str:
  aatype = prot.aatype
  atom_positions = prot.atom_positions
  residue_index = prot.residue_index.astype(np.int32)
  chain_index = prot.chain_index.astype(np.int32)
  b_factors = prot.b_factors

  if np.any(aatype > residue_constants.restype_num):
    raise ValueError('Invalid aatypes.')

  # Construct a mapping from chain integer indices to chain ID strings.
  chain_ids = {}
  for i in np.unique(chain_index):  # np.unique gives sorted output.
    if i >= PDB_MAX_CHAINS:
      raise ValueError(
          f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
    chain_ids[i] = PDB_CHAIN_IDS[i]

  pdb_lines.append('MODEL     1')
  atom_index = 1
  last_chain_index = chain_index[0]
  # Add all atom sites.
  for i in range(aatype.shape[0]):
    # Close the previous chain if in a multichain PDB.
    if last_chain_index != chain_index[i]:
      pdb_lines.append(_chain_end(
          atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]],
          residue_index[i - 1]))
      last_chain_index = chain_index[i]
      atom_index += 1  # Atom index increases at the TER symbol.

    res_name_3 = res_1to3(aatype[i])
    for atom_name, pos, mask, b_factor in zip(
        atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
@@ -168,7 +204,7 @@ def to_pdb(prot: Protein) -> str:
      charge = ''
      # PDB is a columnar format, every space matters here!
      atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
                   f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
                   f'{residue_index[i]:>4}{insertion_code:>1}   '
                   f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
                   f'{occupancy:>6.2f}{b_factor:>6.2f}          '
@@ -176,17 +212,15 @@ def to_pdb(prot: Protein) -> str:
      pdb_lines.append(atom_line)
      atom_index += 1

  # Close the final chain.
  pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
                              chain_ids[chain_index[-1]], residue_index[-1]))
  pdb_lines.append('ENDMDL')
  pdb_lines.append('END')

  # Pad all lines to 80 characters.
  pdb_lines = [line.ljust(80) for line in pdb_lines]
  return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.
def ideal_atom_mask(prot: Protein) -> np.ndarray:
@@ -205,25 +239,40 @@ def ideal_atom_mask(prot: Protein) -> np.ndarray:
  return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    remove_leading_feature_dimension: bool = True) -> Protein:
  """Assembles a protein from a prediction.

  Args:
    features: Dictionary holding model inputs.
    result: Dictionary holding model outputs.
    b_factors: (Optional) B-factors to use for the protein.
    remove_leading_feature_dimension: Whether to remove the leading dimension
      of the `features` values.

  Returns:
    A protein instance.
  """
  fold_output = result['structure_module']

  def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
    return arr[0] if remove_leading_feature_dimension else arr

  if 'asym_id' in features:
    chain_index = _maybe_remove_leading_dim(features['asym_id'])
  else:
    chain_index = np.zeros_like(_maybe_remove_leading_dim(features['aatype']))

  if b_factors is None:
    b_factors = np.zeros_like(fold_output['final_atom_mask'])

  return Protein(
      aatype=_maybe_remove_leading_dim(features['aatype']),
      atom_positions=fold_output['final_atom_positions'],
      atom_mask=fold_output['final_atom_mask'],
      residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1,
      chain_index=chain_index,
      b_factors=b_factors)
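
A minimal, hypothetical sketch of how `from_prediction` and `to_pdb` fit
together, using dummy zero-filled arrays purely to exercise the shapes (real
`features`/`result` dictionaries come from the data pipeline and the model):

```python
import numpy as np
from alphafold.common import protein, residue_constants

num_res = 4
num_atoms = residue_constants.atom_type_num  # 37 atom types

# Dummy model inputs (with the leading batch dimension) and outputs.
features = {
    'aatype': np.zeros((1, num_res), dtype=np.int32),
    'residue_index': np.arange(num_res)[None],
    'asym_id': np.array([[0, 0, 1, 1]]),  # two chains of two residues each
}
result = {
    'structure_module': {
        'final_atom_positions': np.zeros((num_res, num_atoms, 3)),
        'final_atom_mask': np.ones((num_res, num_atoms)),
    }
}

prot = protein.from_prediction(features, result)
pdb_string = protein.to_pdb(prot)  # chains A and B, each closed with a TER record
print(pdb_string.splitlines()[0])  # every line is padded to 80 characters
```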
@@ -35,11 +35,17 @@ class ProteinTest(parameterized.TestCase):
    self.assertEqual((num_res,), prot.aatype.shape)
    self.assertEqual((num_res, num_atoms), prot.atom_mask.shape)
    self.assertEqual((num_res,), prot.residue_index.shape)
    self.assertEqual((num_res,), prot.chain_index.shape)
    self.assertEqual((num_res, num_atoms), prot.b_factors.shape)

  @parameterized.named_parameters(
      dict(testcase_name='chain_A',
           pdb_file='2rbg.pdb', chain_id='A', num_res=282, num_chains=1),
      dict(testcase_name='chain_B',
           pdb_file='2rbg.pdb', chain_id='B', num_res=282, num_chains=1),
      dict(testcase_name='multichain',
           pdb_file='2rbg.pdb', chain_id=None, num_res=564, num_chains=2))
  def test_from_pdb_str(self, pdb_file, chain_id, num_res, num_chains):
    pdb_file = os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
                            pdb_file)
    with open(pdb_file) as f:
@@ -49,14 +55,19 @@ class ProteinTest(parameterized.TestCase):
    self.assertGreaterEqual(prot.aatype.min(), 0)
    # Allow equal since unknown restypes have index equal to restype_num.
    self.assertLessEqual(prot.aatype.max(), residue_constants.restype_num)
    self.assertLen(np.unique(prot.chain_index), num_chains)

  def test_to_pdb(self):
    with open(
        os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
                     '2rbg.pdb')) as f:
      pdb_string = f.read()
    prot = protein.from_pdb_string(pdb_string)
    pdb_string_reconstr = protein.to_pdb(prot)

    for line in pdb_string_reconstr.splitlines():
      self.assertLen(line, 80)

    prot_reconstr = protein.from_pdb_string(pdb_string_reconstr)

    np.testing.assert_array_equal(prot_reconstr.aatype, prot.aatype)
@@ -66,6 +77,8 @@ class ProteinTest(parameterized.TestCase):
        prot_reconstr.atom_mask, prot.atom_mask)
    np.testing.assert_array_equal(
        prot_reconstr.residue_index, prot.residue_index)
    np.testing.assert_array_equal(
        prot_reconstr.chain_index, prot.chain_index)
    np.testing.assert_array_almost_equal(
        prot_reconstr.b_factors, prot.b_factors)

@@ -74,9 +87,9 @@ class ProteinTest(parameterized.TestCase):
        os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
                     '2rbg.pdb')) as f:
      pdb_string = f.read()
    prot = protein.from_pdb_string(pdb_string)
    ideal_mask = protein.ideal_atom_mask(prot)
    non_ideal_residues = set([102] + list(range(127, 286)))
    for i, (res, atom_mask) in enumerate(
        zip(prot.residue_index, prot.atom_mask)):
      if res in non_ideal_residues:
@@ -84,6 +97,18 @@ class ProteinTest(parameterized.TestCase):
      else:
        self.assertTrue(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')

  def test_too_many_chains(self):
    num_res = protein.PDB_MAX_CHAINS + 1
    num_atom_type = residue_constants.atom_type_num
    with self.assertRaises(ValueError):
      _ = protein.Protein(
          atom_positions=np.random.random([num_res, num_atom_type, 3]),
          aatype=np.random.randint(0, 21, [num_res]),
          atom_mask=np.random.randint(0, 2, [num_res]).astype(np.float32),
          residue_index=np.arange(1, num_res+1),
          chain_index=np.arange(num_res),
          b_factors=np.random.uniform(1, 100, [num_res]))


if __name__ == '__main__':
  absltest.main()
@@ -16,6 +16,7 @@
import collections
import functools
import os
from typing import List, Mapping, Tuple

import numpy as np
@@ -398,12 +399,13 @@ def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]],
    ("residue_virtual_bonds").

  Returns:
    residue_bonds: Dict that maps resname -> list of Bond tuples.
    residue_virtual_bonds: Dict that maps resname -> list of Bond tuples.
    residue_bond_angles: Dict that maps resname -> list of BondAngle tuples.
  """
  stereo_chemical_props_path = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt'
  )
  with open(stereo_chemical_props_path, 'rt') as f:
    stereo_chemical_props = f.read()
  lines_iter = iter(stereo_chemical_props.splitlines())
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data pipeline."""
from typing import Iterable, MutableMapping, List
from alphafold.common import residue_constants
from alphafold.data import msa_pairing
from alphafold.data import pipeline
import numpy as np
REQUIRED_FEATURES = frozenset({
'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments',
'num_templates', 'queue_size', 'residue_index', 'resolution',
'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
'template_all_atom_mask', 'template_all_atom_positions'
})
MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048
def _is_homomer_or_monomer(chains: Iterable[pipeline.FeatureDict]) -> bool:
"""Checks if a list of chains represents a homomer/monomer example."""
# Note that an entity_id of 0 indicates padding.
num_unique_chains = len(np.unique(np.concatenate(
[np.unique(chain['entity_id'][chain['entity_id'] > 0]) for
chain in chains])))
return num_unique_chains == 1
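
A tiny, hypothetical illustration of the `entity_id` convention used above (an
`entity_id` of 0 marks padding and is ignored; chains sharing a single entity
ID count as a homomer):

```python
import numpy as np
from alphafold.data import feature_processing

chains = [
    {'entity_id': np.array([1, 1, 1, 0])},  # padded copy of entity 1
    {'entity_id': np.array([1, 1, 1])},     # second copy of the same entity
]
print(feature_processing._is_homomer_or_monomer(chains))  # True
```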
def pair_and_merge(
all_chain_features: MutableMapping[str, pipeline.FeatureDict],
is_prokaryote: bool) -> pipeline.FeatureDict:
"""Runs processing on features to augment, pair and merge.
Args:
all_chain_features: A MutableMap of dictionaries of features for each chain.
is_prokaryote: Whether the target complex is from a prokaryotic or
eukaryotic organism.
Returns:
A dictionary of features.
"""
process_unmerged_features(all_chain_features)
np_chains_list = list(all_chain_features.values())
pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)
if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features(
chains=np_chains_list, prokaryotic=is_prokaryote)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
np_chains_list = crop_chains(
np_chains_list,
msa_crop_size=MSA_CROP_SIZE,
pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES)
np_example = msa_pairing.merge_chain_features(
np_chains_list=np_chains_list, pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES)
np_example = process_final(np_example)
return np_example
def crop_chains(
chains_list: List[pipeline.FeatureDict],
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int) -> List[pipeline.FeatureDict]:
"""Crops the MSAs for a set of chains.
Args:
chains_list: A list of chains to be cropped.
msa_crop_size: The total number of sequences to crop from the MSA.
pair_msa_sequences: Whether we are operating in sequence-pairing mode.
max_templates: The maximum templates to use per chain.
Returns:
The chains cropped.
"""
# Apply the cropping.
cropped_chains = []
for chain in chains_list:
cropped_chain = _crop_single_chain(
chain,
msa_crop_size=msa_crop_size,
pair_msa_sequences=pair_msa_sequences,
max_templates=max_templates)
cropped_chains.append(cropped_chain)
return cropped_chains
def _crop_single_chain(chain: pipeline.FeatureDict,
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int) -> pipeline.FeatureDict:
"""Crops msa sequences to `msa_crop_size`."""
msa_size = chain['num_alignments']
if pair_msa_sequences:
msa_size_all_seq = chain['num_alignments_all_seq']
msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)
# We reduce the number of un-paired sequences, by the number of times a
# sequence from this chain's MSA is included in the paired MSA. This keeps
# the MSA size for each chain roughly constant.
msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
num_non_gapped_pairs = np.sum(
np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1))
num_non_gapped_pairs = np.minimum(num_non_gapped_pairs,
msa_crop_size_all_seq)
# Restrict the unpaired crop size so that paired+unpaired sequences do not
# exceed msa_seqs_per_chain for each chain.
max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
else:
msa_crop_size = np.minimum(msa_size, msa_crop_size)
include_templates = 'template_aatype' in chain and max_templates
if include_templates:
num_templates = chain['template_aatype'].shape[0]
templates_crop_size = np.minimum(num_templates, max_templates)
for k in chain:
k_split = k.split('_all_seq')[0]
if k_split in msa_pairing.TEMPLATE_FEATURES:
chain[k] = chain[k][:templates_crop_size, :]
elif k_split in msa_pairing.MSA_FEATURES:
if '_all_seq' in k and pair_msa_sequences:
chain[k] = chain[k][:msa_crop_size_all_seq, :]
else:
chain[k] = chain[k][:msa_crop_size, :]
chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
if include_templates:
chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
if pair_msa_sequences:
chain['num_alignments_all_seq'] = np.asarray(
msa_crop_size_all_seq, dtype=np.int32)
return chain
def process_final(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict:
"""Final processing steps in data pipeline, after merging and pairing."""
np_example = _correct_msa_restypes(np_example)
np_example = _make_seq_mask(np_example)
np_example = _make_msa_mask(np_example)
np_example = _filter_features(np_example)
return np_example
def _correct_msa_restypes(np_example):
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0)
np_example['msa'] = np_example['msa'].astype(np.int32)
return np_example
def _make_seq_mask(np_example):
np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
return np_example
def _make_msa_mask(np_example):
"""Mask features are all ones, but will later be zero-padded."""
np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)
seq_mask = (np_example['entity_id'] > 0).astype(np.float32)
np_example['msa_mask'] *= seq_mask[None]
return np_example
def _filter_features(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict:
"""Filters features of example to only those requested."""
return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}
def process_unmerged_features(
all_chain_features: MutableMapping[str, pipeline.FeatureDict]):
"""Postprocessing stage for per-chain features before merging."""
num_chains = len(all_chain_features)
for chain_features in all_chain_features.values():
# Convert deletion matrices to float.
chain_features['deletion_matrix'] = np.asarray(
chain_features.pop('deletion_matrix_int'), dtype=np.float32)
if 'deletion_matrix_int_all_seq' in chain_features:
chain_features['deletion_matrix_all_seq'] = np.asarray(
chain_features.pop('deletion_matrix_int_all_seq'), dtype=np.float32)
chain_features['deletion_mean'] = np.mean(
chain_features['deletion_matrix'], axis=0)
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['aatype']]
chain_features['all_atom_mask'] = all_atom_mask
chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])
# Add assembly_num_chains.
chain_features['assembly_num_chains'] = np.asarray(num_chains)
# Add entity_mask.
for chain_features in all_chain_features.values():
chain_features['entity_mask'] = (
chain_features['entity_id'] != 0).astype(np.int32)
@@ -15,6 +15,7 @@
"""Parses the mmCIF file format."""
import collections
import dataclasses
import functools
import io
from typing import Any, Mapping, Optional, Sequence, Tuple

@@ -160,6 +161,7 @@ def mmcif_loop_to_dict(prefix: str,
  return {entry[index]: entry for entry in entries}


@functools.lru_cache(16, typed=False)
def parse(*,
          file_id: str,
          mmcif_string: str,
@@ -314,7 +316,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
        raw_resolution = parsed_info[res_key][0]
        header['resolution'] = float(raw_resolution)
      except ValueError:
        logging.debug('Invalid resolution format: %s', parsed_info[res_key])

  return header
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for extracting identifiers from MSA sequence descriptions."""
import dataclasses
import re
from typing import Optional
# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
_UNIPROT_PATTERN = re.compile(
r"""
^
# UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
(?:tr|sp)
\|
# A primary accession number of the UniProtKB entry.
(?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
# Occasionally there is a _0 or _1 isoform suffix, which we ignore.
(?:_\d)?
\|
# TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
# protein ID code.
(?:[A-Za-z0-9]+)
_
# A mnemonic species identification code.
(?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
# Small BFD uses a final value after an underscore, which we ignore.
(?:_\d+)?
$
""",
re.VERBOSE)
@dataclasses.dataclass(frozen=True)
class Identifiers:
uniprot_accession_id: str = ''
species_id: str = ''
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
"""Gets accession id and species from an msa sequence identifier.
The sequence identifier has the format specified by
_UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`
Args:
msa_sequence_identifier: a sequence identifier.
Returns:
An `Identifiers` instance with a uniprot_accession_id and species_id. These
can be empty in the case where no identifier was found.
"""
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches:
return Identifiers(
uniprot_accession_id=matches.group('AccessionIdentifier'),
species_id=matches.group('SpeciesIdentifier'))
return Identifiers()
def _extract_sequence_identifier(description: str) -> Optional[str]:
"""Extracts sequence identifier from description. Returns None if no match."""
split_description = description.split()
if split_description:
return split_description[0].partition('/')[0]
else:
return None
def get_identifiers(description: str) -> Identifiers:
"""Computes extra MSA features from the description."""
sequence_identifier = _extract_sequence_identifier(description)
if sequence_identifier is None:
return Identifiers()
else:
return _parse_sequence_identifier(sequence_identifier)
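
A small usage example (the description string is hypothetical, following the
UniProtKB format documented above):

```python
from alphafold.data import msa_identifiers

ids = msa_identifiers.get_identifiers(
    'tr|A0A146SKV9|A0A146SKV9_FUNHE hypothetical description text')
print(ids.uniprot_accession_id)  # A0A146SKV9
print(ids.species_id)            # FUNHE
```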
@@ -15,20 +15,47 @@
"""Functions for parsing various file formats."""
import collections
import dataclasses
import itertools
import re
import string
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set

DeletionMatrix = Sequence[Sequence[int]]
@dataclasses.dataclass(frozen=True)
class Msa:
"""Class representing a parsed MSA file."""
sequences: Sequence[str]
deletion_matrix: DeletionMatrix
descriptions: Sequence[str]
def __post_init__(self):
if not (len(self.sequences) ==
len(self.deletion_matrix) ==
len(self.descriptions)):
raise ValueError(
'All fields for an MSA must have the same length. '
f'Got {len(self.sequences)} sequences, '
f'{len(self.deletion_matrix)} rows in the deletion matrix and '
f'{len(self.descriptions)} descriptions.')
def __len__(self):
return len(self.sequences)
def truncate(self, max_seqs: int):
return Msa(sequences=self.sequences[:max_seqs],
deletion_matrix=self.deletion_matrix[:max_seqs],
descriptions=self.descriptions[:max_seqs])
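
A small, hypothetical illustration of the new container (mismatched field
lengths raise a ValueError in `__post_init__`):

```python
from alphafold.data import parsers

msa = parsers.Msa(
    sequences=['MKTAYIA', 'MK-AYIA', 'MKTAYIC'],
    deletion_matrix=[[0] * 7, [0] * 7, [0] * 7],
    descriptions=['query', 'hit1', 'hit2'])

print(len(msa))                   # 3
print(msa.truncate(2).sequences)  # ['MKTAYIA', 'MK-AYIA']
```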
@dataclasses.dataclass(frozen=True)
class TemplateHit:
  """Class representing a template hit."""
  index: int
  name: str
  aligned_cols: int
  sum_probs: Optional[float]
  query: str
  hit_sequence: str
  indices_query: List[int]
@@ -64,9 +91,7 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
  return sequences, descriptions


def parse_stockholm(stockholm_string: str) -> Msa:
  """Parses sequences and deletion matrix from stockholm format alignment.

  Args:
@@ -121,10 +146,12 @@ def parse_stockholm(
          deletion_count = 0
    deletion_matrix.append(deletion_vec)

  return Msa(sequences=msa,
             deletion_matrix=deletion_matrix,
             descriptions=list(name_to_sequence.keys()))
def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: def parse_a3m(a3m_string: str) -> Msa:
"""Parses sequences and deletion matrix from a3m format alignment. """Parses sequences and deletion matrix from a3m format alignment.
Args: Args:
...@@ -138,8 +165,9 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: ...@@ -138,8 +165,9 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
* The deletion matrix for the alignment as a list of lists. The element * The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j. the aligned sequence i at residue position j.
* A list of descriptions, one per sequence, from the a3m file.
""" """
sequences, _ = parse_fasta(a3m_string) sequences, descriptions = parse_fasta(a3m_string)
deletion_matrix = [] deletion_matrix = []
for msa_sequence in sequences: for msa_sequence in sequences:
deletion_vec = [] deletion_vec = []
...@@ -155,7 +183,9 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: ...@@ -155,7 +183,9 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
# Make the MSA matrix out of aligned (deletion-free) sequences. # Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans('', '', string.ascii_lowercase) deletion_table = str.maketrans('', '', string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences] aligned_sequences = [s.translate(deletion_table) for s in sequences]
return aligned_sequences, deletion_matrix return Msa(sequences=aligned_sequences,
deletion_matrix=deletion_matrix,
descriptions=descriptions)
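# Illustrative sketch (not part of this commit): parse_a3m on a toy alignment.
# The lowercase 'c' in the hit is an insertion relative to the query, so the
# documented deletion-matrix semantics place a count of 1 at the following
# aligned residue (expected values shown in the asserts below).
def _example_parse_a3m():
  msa = parse_a3m('>query\nMKV\n>hit_1\nMcKV\n')
  assert msa.sequences == ['MKV', 'MKV']
  assert msa.deletion_matrix == [[0, 0, 0], [0, 1, 0]]
  assert msa.descriptions == ['query', 'hit_1']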
def _convert_sto_seq_to_a3m(
...@@ -168,7 +198,8 @@ def _convert_sto_seq_to_a3m(
def convert_stockholm_to_a3m(stockholm_format: str,
                             max_sequences: Optional[int] = None,
                             remove_first_row_gaps: bool = True) -> str:
  """Converts MSA in Stockholm format to the A3M format."""
  descriptions = {}
  sequences = {}
...@@ -203,18 +234,138 @@ def convert_stockholm_to_a3m(stockholm_format: str,
  # Convert sto format to a3m line by line
  a3m_sequences = {}
  if remove_first_row_gaps:
    # query_sequence is assumed to be the first sequence
    query_sequence = next(iter(sequences.values()))
    query_non_gaps = [res != '-' for res in query_sequence]
  for seqname, sto_sequence in sequences.items():
    # Dots are optional in a3m format and are commonly removed.
    out_sequence = sto_sequence.replace('.', '')
    if remove_first_row_gaps:
      out_sequence = ''.join(
          _convert_sto_seq_to_a3m(query_non_gaps, out_sequence))
    a3m_sequences[seqname] = out_sequence
  fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
                  for k in a3m_sequences)
  return '\n'.join(fasta_chunks) + '\n'  # Include terminating newline.
def _keep_line(line: str, seqnames: Set[str]) -> bool:
"""Function to decide which lines to keep."""
if not line.strip():
return True
if line.strip() == '//': # End tag
return True
if line.startswith('# STOCKHOLM'): # Start tag
return True
if line.startswith('#=GC RF'): # Reference Annotation Line
return True
if line[:4] == '#=GS': # Description lines - keep if sequence in list.
_, seqname, _ = line.split(maxsplit=2)
return seqname in seqnames
elif line.startswith('#'): # Other markup - filter out
return False
else: # Alignment data - keep if sequence in list.
seqname = line.partition(' ')[0]
return seqname in seqnames
def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str:
"""Truncates a stockholm file to a maximum number of sequences."""
seqnames = set()
filtered_lines = []
for line in stockholm_msa.splitlines():
if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname = line.partition(' ')[0]
seqnames.add(seqname)
if len(seqnames) >= max_sequences:
break
for line in stockholm_msa.splitlines():
if _keep_line(line, seqnames):
filtered_lines.append(line)
return '\n'.join(filtered_lines) + '\n'
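# Illustrative sketch (not part of this commit): truncating a toy Stockholm
# alignment to its first two sequences keeps the markup lines and drops the
# remaining rows.
def _example_truncate_stockholm_msa():
  sto = ('# STOCKHOLM 1.0\n'
         'query MKV\n'
         'hit_1 MRV\n'
         'hit_2 MLV\n'
         '//\n')
  truncated = truncate_stockholm_msa(sto, max_sequences=2)
  assert 'hit_1' in truncated
  assert 'hit_2' not in truncated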
def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
"""Removes empty columns (dashes-only) from a Stockholm MSA."""
processed_lines = {}
unprocessed_lines = {}
for i, line in enumerate(stockholm_msa.splitlines()):
if line.startswith('#=GC RF'):
reference_annotation_i = i
reference_annotation_line = line
# Reached the end of this chunk of the alignment. Process chunk.
_, _, first_alignment = line.rpartition(' ')
mask = []
for j in range(len(first_alignment)):
for _, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
if alignment[j] != '-':
mask.append(True)
break
else: # Every row contained a hyphen - empty column.
mask.append(False)
# Add reference annotation for processing with mask.
unprocessed_lines[reference_annotation_i] = reference_annotation_line
if not any(mask): # All columns were empty. Output empty lines for chunk.
for line_index in unprocessed_lines:
processed_lines[line_index] = ''
else:
for line_index, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
masked_alignment = ''.join(itertools.compress(alignment, mask))
processed_lines[line_index] = f'{prefix} {masked_alignment}'
# Clear raw_alignments.
unprocessed_lines = {}
elif line.strip() and not line.startswith(('#', '//')):
unprocessed_lines[i] = line
else:
processed_lines[i] = line
return '\n'.join((processed_lines[i] for i in range(len(processed_lines))))
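# Illustrative sketch (not part of this commit): the all-gap middle column is
# removed from every alignment row and from the '#=GC RF' reference line.
def _example_remove_empty_columns():
  sto = ('# STOCKHOLM 1.0\n'
         'query M-V\n'
         'hit_1 M-L\n'
         '#=GC RF x-x\n'
         '//\n')
  cleaned = remove_empty_columns_from_stockholm_msa(sto)
  assert 'query MV' in cleaned
  assert '#=GC RF xx' in cleaned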
def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
"""Remove duplicate sequences (ignoring insertions wrt query)."""
sequence_dict = collections.defaultdict(str)
# First we must extract all sequences from the MSA.
for line in stockholm_msa.splitlines():
# Only consider the alignments - ignore reference annotation, empty lines,
# descriptions or markup.
if line.strip() and not line.startswith(('#', '//')):
line = line.strip()
seqname, alignment = line.split()
sequence_dict[seqname] += alignment
seen_sequences = set()
seqnames = set()
# First alignment is the query.
query_align = next(iter(sequence_dict.values()))
mask = [c != '-' for c in query_align] # Mask is False for insertions.
for seqname, alignment in sequence_dict.items():
# Apply mask to remove all insertions from the string.
masked_alignment = ''.join(itertools.compress(alignment, mask))
if masked_alignment in seen_sequences:
continue
else:
seen_sequences.add(masked_alignment)
seqnames.add(seqname)
filtered_lines = []
for line in stockholm_msa.splitlines():
if _keep_line(line, seqnames):
filtered_lines.append(line)
return '\n'.join(filtered_lines) + '\n'
def _get_hhr_line_regex_groups(
    regex_pattern: str, line: str) -> Sequence[Optional[str]]:
  match = re.match(regex_pattern, line)
...@@ -264,8 +415,8 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
    raise RuntimeError(
        'Could not parse section: %s. Expected this: \n%s to contain summary.' %
        (detailed_lines, detailed_lines[2]))
  (_, _, _, aligned_cols, _, _, sum_probs, _) = [float(x)
                                                 for x in match.groups()]
  # The next section reads the detailed comparisons. These are in a 'human
  # readable' format which has a fixed length. The strategy employed is to
...@@ -362,3 +513,95 @@ def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
    target_name = fields[0]
    e_values[target_name] = float(e_value)
  return e_values
def _get_indices(sequence: str, start: int) -> List[int]:
"""Returns indices for non-gap/insert residues starting at the given index."""
indices = []
counter = start
for symbol in sequence:
# Skip gaps but add a placeholder so that the alignment is preserved.
if symbol == '-':
indices.append(-1)
# Skip deleted residues, but increase the counter.
elif symbol.islower():
counter += 1
# Normal aligned residue. Increase the counter and append to indices.
else:
indices.append(counter)
counter += 1
return indices
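# Illustrative sketch (not part of this commit): gaps map to -1, lowercase
# (inserted) residues advance the counter without emitting an index, and the
# counter starts at the given offset.
def _example_get_indices():
  assert _get_indices('AB-cD', start=0) == [0, 1, -1, 3]
  assert _get_indices('AB', start=10) == [10, 11]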
@dataclasses.dataclass(frozen=True)
class HitMetadata:
pdb_id: str
chain: str
start: int
end: int
length: int
text: str
def _parse_hmmsearch_description(description: str) -> HitMetadata:
"""Parses the hmmsearch A3M sequence description line."""
# Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text
# Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
match = re.match(
r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
description.strip())
if not match:
raise ValueError(f'Could not parse description: "{description}".')
return HitMetadata(
pdb_id=match[1],
chain=match[2],
start=int(match[3]),
end=int(match[4]),
length=int(match[5]),
text=match[6])
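# Illustrative sketch (not part of this commit): parsing the first example
# description quoted in the docstring above.
def _example_parse_hmmsearch_description():
  meta = _parse_hmmsearch_description(
      '>4pqx_A/2-217 [subseq from] mol:protein length:217 Free text')
  assert meta.pdb_id == '4pqx' and meta.chain == 'A'
  assert (meta.start, meta.end, meta.length) == (2, 217, 217)
  assert meta.text == 'Free text'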
def parse_hmmsearch_a3m(query_sequence: str,
a3m_string: str,
skip_first: bool = True) -> Sequence[TemplateHit]:
"""Parses an a3m string produced by hmmsearch.
Args:
query_sequence: The query sequence.
a3m_string: The a3m string produced by hmmsearch.
skip_first: Whether to skip the first sequence in the a3m string.
Returns:
A sequence of `TemplateHit` results.
"""
# Zip the descriptions and MSAs together, skip the first query sequence.
parsed_a3m = list(zip(*parse_fasta(a3m_string)))
if skip_first:
parsed_a3m = parsed_a3m[1:]
indices_query = _get_indices(query_sequence, start=0)
hits = []
for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
if 'mol:protein' not in hit_description:
continue # Skip non-protein chains.
metadata = _parse_hmmsearch_description(hit_description)
# Aligned columns are only the match states.
aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence])
indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)
hit = TemplateHit(
index=i,
name=f'{metadata.pdb_id}_{metadata.chain}',
aligned_cols=aligned_cols,
sum_probs=None,
query=query_sequence,
hit_sequence=hit_sequence.upper(),
indices_query=indices_query,
indices_hit=indices_hit,
)
hits.append(hit)
return hits
...@@ -15,19 +15,22 @@
"""Functions for building the input features for the AlphaFold model."""
import os
from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union
from absl import logging
from alphafold.common import residue_constants
from alphafold.data import msa_identifiers
from alphafold.data import parsers
from alphafold.data import templates
from alphafold.data.tools import hhblits
from alphafold.data.tools import hhsearch
from alphafold.data.tools import hmmsearch
from alphafold.data.tools import jackhmmer
import numpy as np
# Internal import (7716).
FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
def make_sequence_features(
...@@ -47,55 +50,78 @@ def make_sequence_features(
  return features
def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
  """Constructs a feature dict of MSA features."""
  if not msas:
    raise ValueError('At least one MSA must be provided.')
  int_msa = []
  deletion_matrix = []
  uniprot_accession_ids = []
  species_ids = []
  seen_sequences = set()
  for msa_index, msa in enumerate(msas):
    if not msa:
      raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
    for sequence_index, sequence in enumerate(msa.sequences):
      if sequence in seen_sequences:
        continue
      seen_sequences.add(sequence)
      int_msa.append(
          [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
      deletion_matrix.append(msa.deletion_matrix[sequence_index])
      identifiers = msa_identifiers.get_identifiers(
          msa.descriptions[sequence_index])
      uniprot_accession_ids.append(
          identifiers.uniprot_accession_id.encode('utf-8'))
      species_ids.append(identifiers.species_id.encode('utf-8'))
  num_res = len(msas[0].sequences[0])
  num_alignments = len(int_msa)
  features = {}
  features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
  features['msa'] = np.array(int_msa, dtype=np.int32)
  features['num_alignments'] = np.array(
      [num_alignments] * num_res, dtype=np.int32)
  features['msa_uniprot_accession_identifiers'] = np.array(
      uniprot_accession_ids, dtype=np.object_)
  features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_)
  return features
def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str,
msa_format: str, use_precomputed_msas: bool,
) -> Mapping[str, Any]:
"""Runs an MSA tool, checking if output already exists first."""
if not use_precomputed_msas or not os.path.exists(msa_out_path):
result = msa_runner.query(input_fasta_path)[0]
with open(msa_out_path, 'w') as f:
f.write(result[msa_format])
else:
logging.warning('Reading MSA from file %s', msa_out_path)
with open(msa_out_path, 'r') as f:
result = {msa_format: f.read()}
return result
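# Illustrative usage sketch (not part of this commit): all paths below are
# hypothetical placeholders. With use_precomputed_msas=True, an existing
# 'msas/uniref90_hits.sto' is read back instead of rerunning jackhmmer.
def _example_run_msa_tool(uniref90_runner: jackhmmer.Jackhmmer) -> str:
  result = run_msa_tool(
      msa_runner=uniref90_runner,
      input_fasta_path='query.fasta',
      msa_out_path='msas/uniref90_hits.sto',
      msa_format='sto',
      use_precomputed_msas=True)
  return result['sto']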
class DataPipeline:
  """Runs the alignment tools and assembles the input features."""
  def __init__(self,
               jackhmmer_binary_path: str,
               hhblits_binary_path: str,
               uniref90_database_path: str,
               mgnify_database_path: str,
               bfd_database_path: Optional[str],
               uniclust30_database_path: Optional[str],
               small_bfd_database_path: Optional[str],
               template_searcher: TemplateSearcher,
               template_featurizer: templates.TemplateHitFeaturizer,
               use_small_bfd: bool,
               mgnify_max_hits: int = 501,
               uniref_max_hits: int = 10000,
               use_precomputed_msas: bool = False):
    """Initializes the data pipeline."""
    self._use_small_bfd = use_small_bfd
    self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
        binary_path=jackhmmer_binary_path,
...@@ -111,12 +137,11 @@ class DataPipeline:
    self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
        binary_path=jackhmmer_binary_path,
        database_path=mgnify_database_path)
    self.template_searcher = template_searcher
    self.template_featurizer = template_featurizer
    self.mgnify_max_hits = mgnify_max_hits
    self.uniref_max_hits = uniref_max_hits
    self.use_precomputed_msas = use_precomputed_msas
  def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
    """Runs alignment tools on the input sequence and creates features."""
...@@ -130,72 +155,68 @@ class DataPipeline:
    input_description = input_descs[0]
    num_res = len(input_sequence)
    uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
    jackhmmer_uniref90_result = run_msa_tool(
        self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path,
        'sto', self.use_precomputed_msas)
    mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
    jackhmmer_mgnify_result = run_msa_tool(
        self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto',
        self.use_precomputed_msas)
    msa_for_templates = jackhmmer_uniref90_result['sto']
    msa_for_templates = parsers.truncate_stockholm_msa(
        msa_for_templates, max_sequences=self.uniref_max_hits)
    msa_for_templates = parsers.deduplicate_stockholm_msa(
        msa_for_templates)
    msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
        msa_for_templates)
    if self.template_searcher.input_format == 'sto':
      pdb_templates_result = self.template_searcher.query(msa_for_templates)
    elif self.template_searcher.input_format == 'a3m':
      uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
      pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
    else:
      raise ValueError('Unrecognized template input format: '
                       f'{self.template_searcher.input_format}')
    pdb_hits_out_path = os.path.join(
        msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
    with open(pdb_hits_out_path, 'w') as f:
      f.write(pdb_templates_result)
    uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
    uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)
    mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
    mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)
    pdb_template_hits = self.template_searcher.get_template_hits(
        output_string=pdb_templates_result, input_sequence=input_sequence)
    if self._use_small_bfd:
      bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
      jackhmmer_small_bfd_result = run_msa_tool(
          self.jackhmmer_small_bfd_runner, input_fasta_path, bfd_out_path,
          'sto', self.use_precomputed_msas)
      bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
    else:
      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
      hhblits_bfd_uniclust_result = run_msa_tool(
          self.hhblits_bfd_uniclust_runner, input_fasta_path, bfd_out_path,
          'a3m', self.use_precomputed_msas)
      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
    templates_result = self.template_featurizer.get_templates(
        query_sequence=input_sequence,
        hits=pdb_template_hits)
    sequence_features = make_sequence_features(
        sequence=input_sequence,
        description=input_description,
        num_res=num_res)
    msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa))
    logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa))
    logging.info('BFD MSA size: %d sequences.', len(bfd_msa))
......
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for building the features for the AlphaFold multimer model."""
import collections
import contextlib
import copy
import dataclasses
import json
import os
import tempfile
from typing import Mapping, MutableMapping, Sequence
from absl import logging
from alphafold.common import protein
from alphafold.common import residue_constants
from alphafold.data import feature_processing
from alphafold.data import msa_pairing
from alphafold.data import parsers
from alphafold.data import pipeline
from alphafold.data.tools import jackhmmer
import numpy as np
# Internal import (7716).
@dataclasses.dataclass(frozen=True)
class _FastaChain:
sequence: str
description: str
def _make_chain_id_map(*,
sequences: Sequence[str],
descriptions: Sequence[str],
) -> Mapping[str, _FastaChain]:
"""Makes a mapping from PDB-format chain ID to sequence and description."""
if len(sequences) != len(descriptions):
raise ValueError('sequences and descriptions must have equal length. '
f'Got {len(sequences)} != {len(descriptions)}.')
if len(sequences) > protein.PDB_MAX_CHAINS:
raise ValueError('Cannot process more chains than the PDB format supports. '
f'Got {len(sequences)} chains.')
chain_id_map = {}
for chain_id, sequence, description in zip(
protein.PDB_CHAIN_IDS, sequences, descriptions):
chain_id_map[chain_id] = _FastaChain(
sequence=sequence, description=description)
return chain_id_map
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
fasta_file.write(fasta_str)
fasta_file.seek(0)
yield fasta_file.name
def convert_monomer_features(
monomer_features: pipeline.FeatureDict,
chain_id: str) -> pipeline.FeatureDict:
"""Reshapes and modifies monomer features for multimer models."""
converted = {}
converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
unnecessary_leading_dim_feats = {
'sequence', 'domain_name', 'num_alignments', 'seq_length'}
for feature_name, feature in monomer_features.items():
if feature_name in unnecessary_leading_dim_feats:
# asarray ensures it's a np.ndarray.
feature = np.asarray(feature[0], dtype=feature.dtype)
elif feature_name == 'aatype':
# The multimer model performs the one-hot operation itself.
feature = np.argmax(feature, axis=-1).astype(np.int32)
elif feature_name == 'template_aatype':
feature = np.argmax(feature, axis=-1).astype(np.int32)
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
elif feature_name == 'template_all_atom_masks':
feature_name = 'template_all_atom_mask'
converted[feature_name] = feature
return converted
def int_id_to_str_id(num: int) -> str:
"""Encodes a number as a string, using reverse spreadsheet style naming.
Args:
num: A positive integer.
Returns:
    A string that encodes the positive integer using reverse spreadsheet-style
    naming, e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
    usual way to encode chain IDs in mmCIF files.
"""
if num <= 0:
raise ValueError(f'Only positive integers allowed, got {num}.')
num = num - 1 # 1-based indexing.
output = []
while num >= 0:
output.append(chr(num % 26 + ord('A')))
num = num // 26 - 1
return ''.join(output)
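# Illustrative sketch (not part of this commit): a few values of the reverse
# spreadsheet-style chain naming described in the docstring above.
def _example_int_id_to_str_id():
  assert [int_id_to_str_id(i) for i in (1, 2, 26, 27, 28)] == [
      'A', 'B', 'Z', 'AA', 'BA']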
def add_assembly_features(
all_chain_features: MutableMapping[str, pipeline.FeatureDict],
) -> MutableMapping[str, pipeline.FeatureDict]:
"""Add features to distinguish between chains.
Args:
all_chain_features: A dictionary which maps chain_id to a dictionary of
features for each chain.
Returns:
all_chain_features: A dictionary which maps strings of the form
`<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
chains from a homodimer would have keys A_1 and A_2. Two chains from a
heterodimer would have keys A_1 and B_1.
"""
# Group the chains by sequence
seq_to_entity_id = {}
grouped_chains = collections.defaultdict(list)
for chain_id, chain_features in all_chain_features.items():
seq = str(chain_features['sequence'])
if seq not in seq_to_entity_id:
seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
grouped_chains[seq_to_entity_id[seq]].append(chain_features)
new_all_chain_features = {}
chain_id = 1
for entity_id, group_chain_features in grouped_chains.items():
for sym_id, chain_features in enumerate(group_chain_features, start=1):
new_all_chain_features[
f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features
seq_length = chain_features['seq_length']
chain_features['asym_id'] = chain_id * np.ones(seq_length)
chain_features['sym_id'] = sym_id * np.ones(seq_length)
chain_features['entity_id'] = entity_id * np.ones(seq_length)
chain_id += 1
return new_all_chain_features
def pad_msa(np_example, min_num_seq):
np_example = dict(np_example)
num_seq = np_example['msa'].shape[0]
if num_seq < min_num_seq:
for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
np_example[feat] = np.pad(
np_example[feat], ((0, min_num_seq - num_seq), (0, 0)))
np_example['cluster_bias_mask'] = np.pad(
np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq),))
return np_example
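# Illustrative sketch (not part of this commit): padding a toy example with
# three MSA rows up to a minimum of eight rows; every listed MSA feature is
# padded along its first dimension.
def _example_pad_msa():
  toy = {
      'msa': np.zeros((3, 10)),
      'deletion_matrix': np.zeros((3, 10)),
      'bert_mask': np.zeros((3, 10)),
      'msa_mask': np.zeros((3, 10)),
      'cluster_bias_mask': np.zeros((3,)),
  }
  padded = pad_msa(toy, min_num_seq=8)
  assert padded['msa'].shape == (8, 10)
  assert padded['cluster_bias_mask'].shape == (8,)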
class DataPipeline:
"""Runs the alignment tools and assembles the input features."""
def __init__(self,
monomer_data_pipeline: pipeline.DataPipeline,
jackhmmer_binary_path: str,
uniprot_database_path: str,
max_uniprot_hits: int = 50000,
use_precomputed_msas: bool = False):
"""Initializes the data pipeline.
Args:
monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs
the data pipeline for the monomer AlphaFold system.
jackhmmer_binary_path: Location of the jackhmmer binary.
      uniprot_database_path: Location of the unclustered uniprot sequences that
will be searched with jackhmmer and used for MSA pairing.
max_uniprot_hits: The maximum number of hits to return from uniprot.
use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold.
"""
self._monomer_data_pipeline = monomer_data_pipeline
self._uniprot_msa_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniprot_database_path)
self._max_uniprot_hits = max_uniprot_hits
self.use_precomputed_msas = use_precomputed_msas
def _process_single_chain(
self,
chain_id: str,
sequence: str,
description: str,
msa_output_dir: str,
is_homomer_or_monomer: bool) -> pipeline.FeatureDict:
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str = f'>{description}\n{sequence}\n'
chain_msa_output_dir = os.path.join(msa_output_dir, chain_id)
if not os.path.exists(chain_msa_output_dir):
os.makedirs(chain_msa_output_dir)
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
logging.info('Running monomer pipeline on chain %s: %s',
chain_id, description)
chain_features = self._monomer_data_pipeline.process(
input_fasta_path=chain_fasta_path,
msa_output_dir=chain_msa_output_dir)
# We only construct the pairing features if there are 2 or more unique
# sequences.
if not is_homomer_or_monomer:
all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path,
chain_msa_output_dir)
chain_features.update(all_seq_msa_features)
return chain_features
def _all_seq_msa_features(self, input_fasta_path, msa_output_dir):
"""Get MSA features for unclustered uniprot, for pairing."""
out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
result = pipeline.run_msa_tool(
self._uniprot_msa_runner, input_fasta_path, out_path, 'sto',
self.use_precomputed_msas)
msa = parsers.parse_stockholm(result['sto'])
msa = msa.truncate(max_seqs=self._max_uniprot_hits)
all_seq_features = pipeline.make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_uniprot_accession_identifiers',
'msa_species_identifiers',
)
feats = {f'{k}_all_seq': v for k, v in all_seq_features.items()
if k in valid_feats}
return feats
def process(self,
input_fasta_path: str,
msa_output_dir: str,
is_prokaryote: bool = False) -> pipeline.FeatureDict:
"""Runs alignment tools on the input sequences and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
chain_id_map = _make_chain_id_map(sequences=input_seqs,
descriptions=input_descs)
chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json')
with open(chain_id_map_path, 'w') as f:
chain_id_map_dict = {chain_id: dataclasses.asdict(fasta_chain)
for chain_id, fasta_chain in chain_id_map.items()}
json.dump(chain_id_map_dict, f, indent=4, sort_keys=True)
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1
for chain_id, fasta_chain in chain_id_map.items():
if fasta_chain.sequence in sequence_features:
all_chain_features[chain_id] = copy.deepcopy(
sequence_features[fasta_chain.sequence])
continue
chain_features = self._process_single_chain(
chain_id=chain_id,
sequence=fasta_chain.sequence,
description=fasta_chain.description,
msa_output_dir=msa_output_dir,
is_homomer_or_monomer=is_homomer_or_monomer)
chain_features = convert_monomer_features(chain_features,
chain_id=chain_id)
all_chain_features[chain_id] = chain_features
sequence_features[fasta_chain.sequence] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing.pair_and_merge(
all_chain_features=all_chain_features,
is_prokaryote=is_prokaryote,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
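# Illustrative wiring sketch (not part of this commit): this multimer pipeline
# wraps an already-configured monomer pipeline.DataPipeline. The binary and
# database paths below are hypothetical placeholders; see run_alphafold for
# the real flag-driven setup.
def _example_multimer_pipeline(monomer_pipeline: pipeline.DataPipeline):
  multimer_pipeline = DataPipeline(
      monomer_data_pipeline=monomer_pipeline,
      jackhmmer_binary_path='/usr/bin/jackhmmer',
      uniprot_database_path='/data/uniprot/uniprot.fasta')
  return multimer_pipeline.process(
      input_fasta_path='multimer.fasta', msa_output_dir='msas')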
...@@ -13,8 +13,10 @@
# limitations under the License.
"""Functions for getting templates and calculating template features."""
import abc
import dataclasses
import datetime
import functools
import glob
import os
import re
...@@ -71,10 +73,6 @@ class DateError(PrefilterError):
  """An error indicating that the hit date was after the max allowed date."""
class AlignRatioError(PrefilterError):
  """An error indicating that the hit align ratio to the query was too small."""
...@@ -128,7 +126,6 @@ def _is_after_cutoff(
  else:
    # Since this is just a quick prefilter to reduce the number of mmCIF files
    # we need to parse, we don't have to worry about returning True here.
    return False
...@@ -177,7 +174,6 @@ def _assess_hhsearch_hit(
    hit: parsers.TemplateHit,
    hit_pdb_code: str,
    query_sequence: str,
    release_dates: Mapping[str, datetime.datetime],
    release_date_cutoff: datetime.datetime,
    max_subsequence_ratio: float = 0.95,
...@@ -190,7 +186,6 @@ def _assess_hhsearch_hit(
      different from the value in the actual hit since the original pdb might
      have become obsolete.
    query_sequence: Amino acid sequence of the query.
    release_dates: Dictionary mapping pdb codes to their structure release
      dates.
    release_date_cutoff: Max release date that is valid for this query.
...@@ -202,7 +197,6 @@ def _assess_hhsearch_hit(
  Raises:
    DateError: If the hit date was after the max allowed date.
    AlignRatioError: If the hit align ratio to the query was too small.
    DuplicateError: If the hit was an exact subsequence of the query.
    LengthError: If the hit was too short.
...@@ -222,10 +216,6 @@ def _assess_hhsearch_hit(
    raise DateError(f'Date ({release_dates[hit_pdb_code]}) > max template date '
                    f'({release_date_cutoff}).')
  if align_ratio <= min_align_ratio:
    raise AlignRatioError('Proportion of residues aligned to query too small. '
                          f'Align ratio: {align_ratio}.')
...@@ -368,8 +358,9 @@ def _realign_pdb_template_to_query(
        'protein chain.')
  try:
    parsed_a3m = parsers.parse_a3m(
        aligner.align([old_template_sequence, new_template_sequence]))
    old_aligned_template, new_aligned_template = parsed_a3m.sequences
  except Exception as e:
    raise QueryToTemplateAlignError(
        'Could not align old template %s to template %s (%s_%s). Error: %s' %
...@@ -472,6 +463,18 @@ def _get_atom_positions(
        pos[residue_constants.atom_order['SD']] = [x, y, z]
        mask[residue_constants.atom_order['SD']] = 1.0
# Fix naming errors in arginine residues where NH2 is incorrectly
# assigned to be closer to CD than NH1.
cd = residue_constants.atom_order['CD']
nh1 = residue_constants.atom_order['NH1']
nh2 = residue_constants.atom_order['NH2']
if (res.get_resname() == 'ARG' and
all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
(np.linalg.norm(pos[nh1] - pos[cd]) >
np.linalg.norm(pos[nh2] - pos[cd]))):
pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()
    all_positions[res_index] = pos
    all_positions_mask[res_index] = mask
  _check_residue_distances(
...@@ -673,9 +676,15 @@ class SingleHitResult:
  warning: Optional[str]
@functools.lru_cache(16, typed=False)
def _read_file(path):
with open(path, 'r') as f:
file_data = f.read()
return file_data
def _process_single_hit(
    query_sequence: str,
    hit: parsers.TemplateHit,
    mmcif_dir: str,
    max_template_date: datetime.datetime,
...@@ -702,14 +711,12 @@ def _process_single_hit(
        hit=hit,
        hit_pdb_code=hit_pdb_code,
        query_sequence=query_sequence,
        release_dates=release_dates,
        release_date_cutoff=max_template_date)
  except PrefilterError as e:
    msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}'
    logging.info(msg)
    if strict_error_check and isinstance(e, (DateError, DuplicateError)):
      # In strict mode we treat some prefilter cases as errors.
      return SingleHitResult(features=None, error=msg, warning=None)
...@@ -724,11 +731,10 @@ def _process_single_hit(
  template_sequence = hit.hit_sequence.replace('-', '')
  cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif')
  logging.debug('Reading PDB entry from %s. Query: %s, template: %s', cif_path,
                query_sequence, template_sequence)
  # Fail if we can't find the mmCIF file.
  cif_string = _read_file(cif_path)
  parsing_result = mmcif_parsing.parse(
      file_id=hit_pdb_code, mmcif_string=cif_string)
...@@ -742,7 +748,7 @@ def _process_single_hit(
    if strict_error_check:
      return SingleHitResult(features=None, error=error, warning=None)
    else:
      logging.debug(error)
      return SingleHitResult(features=None, error=None, warning=None)
  try:
...@@ -754,7 +760,10 @@ def _process_single_hit(
        query_sequence=query_sequence,
        template_chain_id=hit_chain_id,
        kalign_binary_path=kalign_binary_path)
    if hit.sum_probs is None:
      features['template_sum_probs'] = [0]
    else:
      features['template_sum_probs'] = [hit.sum_probs]
    # It is possible there were some errors when parsing the other chains in the
    # mmCIF file, but the template features for the chain we want were still
...@@ -765,7 +774,7 @@ def _process_single_hit(
          TemplateAtomMaskAllZerosError) as e:
    # These 3 errors indicate missing mmCIF experimental data rather than a
    # problem with the template search, so turn them into warnings.
    warning = ('%s_%s (sum_probs: %s, rank: %s): feature extracting errors: '
               '%s, mmCIF parsing errors: %s'
               % (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
                  str(e), parsing_result.errors))
...@@ -788,8 +797,8 @@ class TemplateSearchResult:
  warnings: Sequence[str]
class TemplateHitFeaturizer(abc.ABC):
  """An abstract base class for turning template hits to template features."""
  def __init__(
      self,
...@@ -850,29 +859,28 @@ class TemplateHitFeaturizer:
    else:
      self._obsolete_pdbs = {}
@abc.abstractmethod
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
"""Computes the templates for given query sequence."""
class HhsearchHitFeaturizer(TemplateHitFeaturizer):
"""A class for turning a3m hits from hhsearch to template features."""
  def get_templates(
      self,
      query_sequence: str,
      hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
    """Computes the templates for given query sequence (more details above)."""
    logging.info('Searching for template for: %s', query_sequence)
    template_features = {}
    for template_feature_name in TEMPLATE_FEATURES:
      template_features[template_feature_name] = []
    num_hits = 0
    errors = []
    warnings = []
...@@ -884,10 +892,9 @@ class TemplateHitFeaturizer:
      result = _process_single_hit(
          query_sequence=query_sequence,
          hit=hit,
          mmcif_dir=self._mmcif_dir,
          max_template_date=self._max_template_date,
          release_dates=self._release_dates,
          obsolete_pdbs=self._obsolete_pdbs,
          strict_error_check=self._strict_error_check,
...@@ -920,3 +927,84 @@ class TemplateHitFeaturizer:
    return TemplateSearchResult(
        features=template_features, errors=errors, warnings=warnings)
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
"""A class for turning a3m hits from hmmsearch to template features."""
def get_templates(
self,
query_sequence: str,
hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
logging.info('Searching for template for: %s', query_sequence)
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
already_seen = set()
errors = []
warnings = []
if not hits or hits[0].sum_probs is None:
sorted_hits = hits
else:
sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)
for hit in sorted_hits:
# We got all the templates we wanted, stop processing hits.
if len(already_seen) >= self._max_hits:
break
result = _process_single_hit(
query_sequence=query_sequence,
hit=hit,
mmcif_dir=self._mmcif_dir,
max_template_date=self._max_template_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path)
if result.error:
errors.append(result.error)
# There could be an error even if there are some results, e.g. thrown by
# other unparsable chains in the same mmCIF file.
if result.warning:
warnings.append(result.warning)
if result.features is None:
logging.debug('Skipped invalid hit %s, error: %s, warning: %s',
hit.name, result.error, result.warning)
else:
already_seen_key = result.features['template_sequence']
if already_seen_key in already_seen:
continue
# Increment the hit counter, since we got features out of this hit.
already_seen.add(already_seen_key)
for k in template_features:
template_features[k].append(result.features[k])
if already_seen:
for name in template_features:
template_features[name] = np.stack(
template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
else:
num_res = len(query_sequence)
# Construct a default template with all zeros.
template_features = {
'template_aatype': np.zeros(
(1, num_res, len(residue_constants.restypes_with_x_and_gap)),
np.float32),
'template_all_atom_masks': np.zeros(
(1, num_res, residue_constants.atom_type_num), np.float32),
'template_all_atom_positions': np.zeros(
(1, num_res, residue_constants.atom_type_num, 3), np.float32),
'template_domain_names': np.array([''.encode()], dtype=np.object),
'template_sequence': np.array([''.encode()], dtype=np.object),
'template_sum_probs': np.array([0], dtype=np.float32)
}
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings)
...@@ -17,7 +17,7 @@
import glob
import os
import subprocess
from typing import Any, List, Mapping, Optional, Sequence
from absl import logging
from alphafold.data.tools import utils
...@@ -94,9 +94,9 @@ class HHBlits:
    self.p = p
    self.z = z
  def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
    """Queries the database using HHblits."""
    with utils.tmpdir_manager() as query_tmp_dir:
      a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
      db_cmd = []
...@@ -152,4 +152,4 @@ class HHBlits:
          stderr=stderr,
          n_iter=self.n_iter,
          e_value=self.e_value)
    return [raw_output]
...@@ -21,6 +21,7 @@ from typing import Sequence
from absl import logging
from alphafold.data import parsers
from alphafold.data.tools import utils
# Internal import (7716).
...@@ -55,9 +56,17 @@ class HHSearch:
      logging.error('Could not find HHsearch database %s', database_path)
      raise ValueError(f'Could not find HHsearch database {database_path}')
@property
def output_format(self) -> str:
return 'hhr'
@property
def input_format(self) -> str:
return 'a3m'
  def query(self, a3m: str) -> str:
    """Queries the database using HHsearch using a given a3m."""
    with utils.tmpdir_manager() as query_tmp_dir:
      input_path = os.path.join(query_tmp_dir, 'query.a3m')
      hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
      with open(input_path, 'w') as f:
...@@ -89,3 +98,10 @@ class HHSearch:
    with open(hhr_path) as f:
      hhr = f.read()
    return hhr
def get_template_hits(self,
output_string: str,
input_sequence: str) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
    del input_sequence  # Used by hmmsearch but not needed for hhsearch.
return parsers.parse_hhr(output_string)
...@@ -98,7 +98,7 @@ class Hmmbuild(object):
      raise ValueError(f'Invalid model_construction {model_construction} - only'
                       'hand and fast supported.')
    with utils.tmpdir_manager() as query_tmp_dir:
      input_query = os.path.join(query_tmp_dir, 'query.msa')
      output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')
......
...@@ -19,6 +19,8 @@ import subprocess
from typing import Optional, Sequence
from absl import logging
from alphafold.data import parsers
from alphafold.data.tools import hmmbuild
from alphafold.data.tools import utils
# Internal import (7716).
...@@ -29,12 +31,15 @@ class Hmmsearch(object):
  def __init__(self,
               *,
               binary_path: str,
               hmmbuild_binary_path: str,
               database_path: str,
               flags: Optional[Sequence[str]] = None):
    """Initializes the Python hmmsearch wrapper.
    Args:
      binary_path: The path to the hmmsearch executable.
      hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
        an hmm from an input a3m.
      database_path: The path to the hmmsearch database (FASTA format).
      flags: List of flags to be used by hmmsearch.
...@@ -42,18 +47,42 @@ class Hmmsearch(object):
      RuntimeError: If hmmsearch binary not found within the path.
    """
    self.binary_path = binary_path
    self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
    self.database_path = database_path
if flags is None:
# Default hmmsearch run settings.
flags = ['--F1', '0.1',
'--F2', '0.1',
'--F3', '0.1',
'--incE', '100',
'-E', '100',
'--domE', '100',
'--incdomE', '100']
    self.flags = flags
    if not os.path.exists(self.database_path):
      logging.error('Could not find hmmsearch database %s', database_path)
      raise ValueError(f'Could not find hmmsearch database {database_path}')
  @property
def output_format(self) -> str:
return 'sto'
@property
def input_format(self) -> str:
return 'sto'
def query(self, msa_sto: str) -> str:
"""Queries the database using hmmsearch using a given stockholm msa."""
hmm = self.hmmbuild_runner.build_profile_from_sto(msa_sto,
model_construction='hand')
return self.query_with_hmm(hmm)
def query_with_hmm(self, hmm: str) -> str:
"""Queries the database using hmmsearch using a given hmm.""" """Queries the database using hmmsearch using a given hmm."""
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: with utils.tmpdir_manager() as query_tmp_dir:
hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
a3m_out_path = os.path.join(query_tmp_dir, 'output.a3m') out_path = os.path.join(query_tmp_dir, 'output.sto')
with open(hmm_input_path, 'w') as f: with open(hmm_input_path, 'w') as f:
f.write(hmm) f.write(hmm)
...@@ -66,7 +95,7 @@ class Hmmsearch(object): ...@@ -66,7 +95,7 @@ class Hmmsearch(object):
if self.flags: if self.flags:
cmd.extend(self.flags) cmd.extend(self.flags)
cmd.extend([ cmd.extend([
'-A', a3m_out_path, '-A', out_path,
hmm_input_path, hmm_input_path,
self.database_path, self.database_path,
]) ])
...@@ -84,7 +113,19 @@ class Hmmsearch(object): ...@@ -84,7 +113,19 @@ class Hmmsearch(object):
'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr.decode('utf-8'))) stdout.decode('utf-8'), stderr.decode('utf-8')))
with open(a3m_out_path) as f: with open(out_path) as f:
a3m_out = f.read() out_msa = f.read()
return out_msa
return a3m_out def get_template_hits(self,
output_string: str,
input_sequence: str) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
a3m_string = parsers.convert_stockholm_to_a3m(output_string,
remove_first_row_gaps=False)
template_hits = parsers.parse_hmmsearch_a3m(
query_sequence=input_sequence,
a3m_string=a3m_string,
skip_first=False)
return template_hits
...@@ -89,7 +89,7 @@ class Jackhmmer:
  def _query_chunk(self, input_fasta_path: str, database_path: str
                   ) -> Mapping[str, Any]:
    """Queries the database chunk using Jackhmmer."""
    with utils.tmpdir_manager() as query_tmp_dir:
      sto_path = os.path.join(query_tmp_dir, 'output.sto')
      # The F1/F2/F3 are the expected proportion to pass each of the filtering
...@@ -192,7 +192,10 @@ class Jackhmmer:
        # Remove the local copy of the chunk
        os.remove(db_local_chunk(i))
        # Do not set next_future for the last chunk so that this works even for
        # databases with only 1 chunk.
        if i < self.num_streamed_chunks:
          future = next_future
        if self.streaming_callback:
          self.streaming_callback(i)
      return chunked_output
...@@ -70,7 +70,7 @@ class Kalign:
      raise ValueError('Kalign requires all sequences to be at least 6 '
                       'residues long. Got %s (%d residues).' % (s, len(s)))
    with utils.tmpdir_manager() as query_tmp_dir:
      input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
      output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
......