Commit 0bab1bf8 authored by Saran Tunyasuvunakool

Add a Colab notebook, add reduced BFD, and various other fixes and improvements.

PiperOrigin-RevId: 386228948
parent d26287ea
...@@ -9,7 +9,15 @@ of this document.

Any publication that discloses findings arising from using this source code or
the model parameters should [cite](#citing-this-work) the
[AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2). Please also refer
to the
[Supplementary Information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf)
for a detailed description of the method.

**You can use a slightly simplified version of AlphaFold with
[this Colab
notebook](https://colab.research.google.com/github/deepmind/alphafold/blob/main/notebooks/AlphaFold.ipynb)**
or community-supported versions (see below).
![CASP14 predictions](imgs/casp14_predictions.gif)
...@@ -39,7 +47,7 @@ The following steps are required in order to run AlphaFold:

### Genetic databases

This step requires `aria2c` to be installed on your machine.

AlphaFold needs multiple genetic (sequence) databases to run:
...@@ -51,21 +59,43 @@ AlphaFold needs multiple genetic (sequence) databases to run:
*   [PDB](https://www.rcsb.org/) (structures in the mmCIF format).

We provide a script `scripts/download_all_data.sh` that can be used to download
and set up all of these databases:

*   Default:

    ```bash
    scripts/download_all_data.sh <DOWNLOAD_DIR>
    ```

    will download the full databases.

*   With `reduced_dbs`:

    ```bash
    scripts/download_all_data.sh <DOWNLOAD_DIR> reduced_dbs
    ```

    will download a reduced version of the databases to be used with the
    `reduced_dbs` preset.

We don't provide exactly the versions used in CASP14 -- see the [note on
reproducibility](#note-on-reproducibility). Some of the databases are mirrored
for speed, see [mirrored databases](#mirrored-databases).

:ledger: **Note: The total download size for the full databases is around 415 GB
and the total size when unzipped is 2.2 TB. Please make sure you have a large
enough hard drive space, bandwidth and time to download. We recommend using an
SSD for better genetic search performance.**
This script will also download the model parameter files. Once the script has
finished, you should have the following directory structure:

```
$DOWNLOAD_DIR/                             # Total: ~ 2.2 TB (download: 438 GB)
    bfd/                                   # ~ 1.7 TB (download: 271.6 GB)
        # 6 files.
    mgnify/                                # ~ 64 GB (download: 32.9 GB)
        mgy_clusters_2018_08.fa
    params/                                # ~ 3.5 GB (download: 3.5 GB)
        # 5 CASP14 models,
        # 5 pTM models,
...@@ -77,13 +107,18 @@ $DOWNLOAD_DIR/  # Total: ~ 2.2 TB (download: 428 GB)
        mmcif_files/
            # About 180,000 .cif files.
        obsolete.dat
    small_bfd/                             # ~ 17 GB (download: 9.6 GB)
        bfd-first_non_consensus_sequences.fasta
    uniclust30/                            # ~ 86 GB (download: 24.9 GB)
        uniclust30_2018_08/
            # 13 files.
    uniref90/                              # ~ 58 GB (download: 29.7 GB)
        uniref90.fasta
```

`bfd/` is only downloaded if you download the full databases, and `small_bfd/`
is only downloaded if you download the reduced databases.
### Model parameters

While the AlphaFold code is licensed under the Apache 2.0 License, the AlphaFold
...@@ -149,16 +184,20 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
    [GPU enumeration](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#gpu-enumeration)
    for more details.

1.  You can control AlphaFold speed / quality tradeoff by adding
    `--preset=reduced_dbs`, `--preset=full_dbs` or `--preset=casp14` to the run
    command. We provide the following presets:

    *   **reduced_dbs**: This preset is optimized for speed and lower hardware
        requirements. It runs with a reduced version of the BFD database and
        with no ensembling. It requires 8 CPU cores (vCPUs), 8 GB of RAM, and
        600 GB of disk space.
    *   **full_dbs**: The model in this preset is 8 times faster than the
        `casp14` preset with a very minor quality drop (-0.1 average GDT drop on
        CASP14 domains). It runs with all genetic databases and with no
        ensembling.
    *   **casp14**: This preset uses the same settings as were used in CASP14.
        It runs with all genetic databases and with 8 ensemblings.

    Running the command above with the `casp14` preset would look like this:
...@@ -174,7 +213,7 @@ structures, raw model outputs, prediction metadata, and section timings. The
`output_dir` directory will have the following structure:

```
<target_name>/
    features.pkl
    ranked_{0,1,2,3,4}.pdb
    ranking_debug.json
...@@ -190,20 +229,20 @@ output_dir/

The contents of each output file are as follows:

*   `features.pkl` – A `pickle` file containing the input feature NumPy arrays
    used by the models to produce the structures.
*   `unrelaxed_model_*.pdb` – A PDB format text file containing the predicted
    structure, exactly as outputted by the model.
*   `relaxed_model_*.pdb` – A PDB format text file containing the predicted
    structure, after performing an Amber relaxation procedure on the unrelaxed
    structure prediction (see Jumper et al. 2021, Suppl. Methods 1.8.6 for
    details).
*   `ranked_*.pdb` – A PDB format text file containing the relaxed predicted
    structures, after reordering by model confidence. Here `ranked_0.pdb` should
    contain the prediction with the highest confidence, and `ranked_4.pdb` the
    prediction with the lowest confidence. To rank model confidence, we use
    predicted LDDT (pLDDT) scores (see Jumper et al. 2021, Suppl. Methods 1.9.6
    for details).
*   `ranking_debug.json` – A JSON format text file containing the pLDDT values
    used to perform the model ranking, and a mapping back to the original model
    names.
...@@ -212,10 +251,27 @@ The contents of each output file are as follows:

*   `msas/` - A directory containing the files describing the various genetic
    tool hits that were used to construct the input MSA.
*   `result_model_*.pkl` – A `pickle` file containing a nested dictionary of the
    various NumPy arrays directly produced by the model. In addition to the
    output of the structure module, this includes auxiliary outputs such as
    (see the loading sketch after this list):

    *   Distograms (`distogram/logits` contains a NumPy array of shape [N_res,
        N_res, N_bins] and `distogram/bin_edges` contains the definition of the
        bins).
    *   Per-residue pLDDT scores (`plddt` contains a NumPy array of shape
        [N_res] with the range of possible values from `0` to `100`, where `100`
        means most confident). This can serve to identify sequence regions
        predicted with high confidence or as an overall per-target confidence
        score when averaged across residues.
    *   Present only if using pTM models: predicted TM-score (`ptm` field
        contains a scalar). As a predictor of a global superposition metric,
        this score is designed to also assess whether the model is confident in
        the overall domain packing.
    *   Present only if using pTM models: predicted pairwise aligned errors
        (`predicted_aligned_error` contains a NumPy array of shape [N_res,
        N_res] with the range of possible values from `0` to
        `max_predicted_aligned_error`, where `0` means most confident). This can
        serve for a visualisation of domain packing confidence within the
        structure.
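A minimal sketch for inspecting one of these pickle files. The output path and
target name below are placeholders, and the key layout is taken from the list
above; adjust to your own run.

```python
import pickle
import numpy as np

# Hypothetical path; substitute your own output directory and model name.
result_path = 'output_dir/T1050/result_model_1.pkl'

with open(result_path, 'rb') as f:
    result = pickle.load(f)

# Per-residue confidence: shape [N_res], values in [0, 100].
plddt = result['plddt']
print('mean pLDDT:', float(np.mean(plddt)))

# Distogram logits and bin definitions.
dist_logits = result['distogram']['logits']    # [N_res, N_res, N_bins]
bin_edges = result['distogram']['bin_edges']

# Only present when a pTM model was used.
if 'predicted_aligned_error' in result:
    pae = result['predicted_aligned_error']    # [N_res, N_res]
    print('pTM:', float(result['ptm']), 'max PAE:', float(pae.max()))
```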
This code has been tested to match mean top-1 accuracy on a CASP14 test set with
pLDDT ranking over 5 model predictions (some CASP targets were run with earlier
...@@ -284,6 +340,17 @@ If you use the code or data in this package, please cite:
}
```

## Community contributions

Colab notebooks provided by the community (please note that these notebooks may
vary from our full AlphaFold system and we did not validate their accuracy):

*   The [ColabFold AlphaFold2 notebook](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb)
    by Martin Steinegger, Sergey Ovchinnikov and Milot Mirdita, which uses an
    API hosted at the Södinglab based on the MMseqs2 server [(Mirdita et al.
    2019, Bioinformatics)](https://academic.oup.com/bioinformatics/article/35/16/2856/5280135)
    for the multiple sequence alignment creation.
## Acknowledgements

AlphaFold communicates with and/or references the following separate libraries
...@@ -292,6 +359,7 @@ and packages:

*   [Abseil](https://github.com/abseil/abseil-py)
*   [Biopython](https://biopython.org)
*   [Chex](https://github.com/deepmind/chex)
*   [Colab](https://research.google.com/colaboratory/)
*   [Docker](https://www.docker.com)
*   [HH Suite](https://github.com/soedinglab/hh-suite)
*   [HMMER Suite](http://eddylab.org/software/hmmer)
...@@ -299,18 +367,20 @@ and packages:
*   [Immutabledict](https://github.com/corenting/immutabledict)
*   [JAX](https://github.com/google/jax/)
*   [Kalign](https://msa.sbc.su.se/cgi-bin/msa.cgi)
*   [matplotlib](https://matplotlib.org/)
*   [ML Collections](https://github.com/google/ml_collections)
*   [NumPy](https://numpy.org)
*   [OpenMM](https://github.com/openmm/openmm)
*   [OpenStructure](https://openstructure.org)
*   [pymol3d](https://github.com/avirshup/py3dmol)
*   [SciPy](https://scipy.org)
*   [Sonnet](https://github.com/deepmind/sonnet)
*   [TensorFlow](https://github.com/tensorflow/tensorflow)
*   [Tree](https://github.com/deepmind/tree)
*   [tqdm](https://github.com/tqdm/tqdm)

We thank all their contributors and maintainers!

## License and Disclaimer

This is not an officially supported Google product.
...@@ -349,3 +419,10 @@ before use.

The following databases have been mirrored by DeepMind, and are available with reference to the following:

*   [BFD](https://bfd.mmseqs.com/) (unmodified), by Steinegger M. and Söding J., available under a [Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).

*   [BFD](https://bfd.mmseqs.com/) (modified), by Steinegger M. and Söding J., modified by DeepMind, available under a [Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/). See the Methods section of the [AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1) for details.

*   [Uniclust30: v2018_08](http://wwwuser.gwdg.de/~compbiol/uniclust/2018_08/) (unmodified), by Mirdita M. et al., available under a [Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).

*   [MGnify: v2018_12](http://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/current_release/README.txt) (unmodified), by Mitchell AL et al., available free of all copyright restrictions and made fully and freely available for both non-commercial and commercial use under [CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).
...@@ -67,7 +67,7 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    A new `Protein` parsed from the pdb contents.
  """
  pdb_fh = io.StringIO(pdb_str)
  parser = PDBParser(QUIET=True)
  structure = parser.get_structure('none', pdb_fh)
  models = list(structure.get_models())
  if len(models) != 1:
...@@ -207,22 +207,25 @@ def ideal_atom_mask(prot: Protein) -> np.ndarray:
  return residue_constants.STANDARD_ATOM_MASK[prot.aatype]


def from_prediction(features: FeatureDict, result: ModelOutput,
                    b_factors: Optional[np.ndarray] = None) -> Protein:
  """Assembles a protein from a prediction.

  Args:
    features: Dictionary holding model inputs.
    result: Dictionary holding model outputs.
    b_factors: (Optional) B-factors to use for the protein.

  Returns:
    A protein instance.
  """
  fold_output = result['structure_module']
  if b_factors is None:
    b_factors = np.zeros_like(fold_output['final_atom_mask'])

  return Protein(
      aatype=features['aatype'][0],
      atom_positions=fold_output['final_atom_positions'],
      atom_mask=fold_output['final_atom_mask'],
      residue_index=features['residue_index'][0] + 1,
      b_factors=b_factors)
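The new `b_factors` argument lets callers write per-residue confidence into the
output PDB. A minimal sketch of one way to use it; broadcasting pLDDT across the
37 atom slots is an assumption for illustration and is not part of this diff,
and `features`/`result` are assumed to come from a data pipeline and model run.

```python
import numpy as np
from alphafold.common import protein

# Repeat per-residue pLDDT across the 37 atom positions so it lands in the
# B-factor column of the written PDB file.
plddt = result['plddt']                                    # [N_res]
b_factors = np.repeat(plddt[:, None], repeats=37, axis=1)  # [N_res, 37]

prot = protein.from_prediction(features, result, b_factors=b_factors)
pdb_string = protein.to_pdb(prot)
```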
...@@ -16,7 +16,7 @@
import collections
import re
import string
from typing import Iterable, List, Optional, Sequence, Tuple, Dict

import dataclasses
...@@ -24,23 +24,14 @@ DeletionMatrix = Sequence[Sequence[int]]


@dataclasses.dataclass(frozen=True)
class TemplateHit:
  """Class representing a template hit."""
  index: int
  name: str
  aligned_cols: int
  sum_probs: float
  query: str
  hit_sequence: str
  indices_query: List[int]
  indices_hit: List[int]
...@@ -75,7 +66,8 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:

def parse_stockholm(
    stockholm_string: str
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
  """Parses sequences and deletion matrix from stockholm format alignment.

  Args:
...@@ -89,6 +81,8 @@ def parse_stockholm(
      * The deletion matrix for the alignment as a list of lists. The element
        at `deletion_matrix[i][j]` is the number of residues deleted from
        the aligned sequence i at residue position j.
      * The names of the targets matched, including the jackhmmer subsequence
        suffix.
  """
  name_to_sequence = collections.OrderedDict()
  for line in stockholm_string.splitlines():
...@@ -128,7 +122,7 @@ def parse_stockholm(
        deletion_count = 0
    deletion_matrix.append(deletion_vec)

  return msa, deletion_matrix, list(name_to_sequence.keys())
def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
...@@ -242,7 +236,7 @@ def _update_hhr_residue_indices_list(
      counter += 1


def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
  """Parses the detailed HMM HMM comparison section for a single Hit.

  This works on .hhr files generated from both HHBlits and HHSearch.
...@@ -271,7 +265,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> HhrHit:
    raise RuntimeError(
        'Could not parse section: %s. Expected this: \n%s to contain summary.' %
        (detailed_lines, detailed_lines[2]))
  (prob_true, e_value, _, aligned_cols, _, _, sum_probs,
   neff) = [float(x) for x in match.groups()]

  # The next section reads the detailed comparisons. These are in a 'human
...@@ -280,9 +274,6 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> HhrHit:
  # that with a regexp in order to deduce the fixed length used for that block.
  query = ''
  hit_sequence = ''
  indices_query = []
  indices_hit = []
  length_block = None
...@@ -312,16 +303,9 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> HhrHit:
      _update_hhr_residue_indices_list(delta_query, start, indices_query)

    elif line.startswith('T '):
      # Parse the hit sequence.
      if (not line.startswith('T ss_dssp') and
          not line.startswith('T ss_pred') and
          not line.startswith('T Consensus')):
        # Thus the first 17 characters must be 'T <hit_name> ', and we can
        # parse everything after that.
...@@ -336,38 +320,19 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> HhrHit:
        hit_sequence += delta_hit_sequence
        _update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit)

  return TemplateHit(
      index=number_of_hit,
      name=name_hit,
      aligned_cols=int(aligned_cols),
      sum_probs=sum_probs,
      query=query,
      hit_sequence=hit_sequence,
      indices_query=indices_query,
      indices_hit=indices_hit,
  )


def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
  """Parses the content of an entire HHR file."""
  lines = hhr_string.splitlines()
...@@ -383,3 +348,18 @@ def parse_hhr(hhr_string: str) -> Sequence[HhrHit]:
  for i in range(len(block_starts) - 1):
    hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]]))
  return hits


def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
  """Parses a target-to-E-value mapping from a Jackhmmer tblout string."""
  e_values = {'query': 0}
  lines = [line for line in tblout.splitlines() if line[0] != '#']
  # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
  # space-delimited. Relevant fields are (1) target name: and
  # (5) E-value (full sequence) (numbering from 1).
  for line in lines:
    fields = line.split()
    e_value = fields[4]
    target_name = fields[0]
    e_values[target_name] = float(e_value)
  return e_values
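A small usage sketch for the new tblout parser. The tblout fragment below is
made up for illustration; it only mimics the whitespace-delimited layout with
the target name in column 1 and the full-sequence E-value in column 5.

```python
from alphafold.data import parsers

tblout = """# target name   accession  query name  accession   E-value  score  bias
UniRef90_A0A000   -          query       -           4.5e-60  210.1  0.0
UniRef90_B1B111   -          query       -           1.2e-08   35.7  0.1
"""

e_values = parsers.parse_e_values_from_tblout(tblout)
print(e_values['UniRef90_A0A000'])  # 4.5e-60
```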
...@@ -15,7 +15,7 @@
"""Functions for building the input features for the AlphaFold model."""

import os
from typing import Mapping, Optional, Sequence

import numpy as np
...@@ -88,16 +88,24 @@ class DataPipeline:
               hhsearch_binary_path: str,
               uniref90_database_path: str,
               mgnify_database_path: str,
               bfd_database_path: Optional[str],
               uniclust30_database_path: Optional[str],
               small_bfd_database_path: Optional[str],
               pdb70_database_path: str,
               template_featurizer: templates.TemplateHitFeaturizer,
               use_small_bfd: bool,
               mgnify_max_hits: int = 501,
               uniref_max_hits: int = 10000):
    """Constructs a feature dict for a given FASTA file."""
    self._use_small_bfd = use_small_bfd
    self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
        binary_path=jackhmmer_binary_path,
        database_path=uniref90_database_path)
    if use_small_bfd:
      self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
          binary_path=jackhmmer_binary_path,
          database_path=small_bfd_database_path)
    else:
      self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
          binary_path=hhblits_binary_path,
          databases=[bfd_database_path, uniclust30_database_path])
...@@ -124,9 +132,9 @@ class DataPipeline:
    num_res = len(input_sequence)

    jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
        input_fasta_path)[0]
    jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
        input_fasta_path)[0]

    uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
        jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits)
...@@ -140,14 +148,25 @@ class DataPipeline:
    with open(mgnify_out_path, 'w') as f:
      f.write(jackhmmer_mgnify_result['sto'])

    uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm(
        jackhmmer_uniref90_result['sto'])
    mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm(
        jackhmmer_mgnify_result['sto'])
    hhsearch_hits = parsers.parse_hhr(hhsearch_result)
    mgnify_msa = mgnify_msa[:self.mgnify_max_hits]
    mgnify_deletion_matrix = mgnify_deletion_matrix[:self.mgnify_max_hits]

    if self._use_small_bfd:
      jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
          input_fasta_path)[0]

      bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.a3m')
      with open(bfd_out_path, 'w') as f:
        f.write(jackhmmer_small_bfd_result['sto'])

      bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
          jackhmmer_small_bfd_result['sto'])
    else:
      hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(
          input_fasta_path)
...@@ -162,7 +181,7 @@ class DataPipeline:
        query_sequence=input_sequence,
        query_pdb_code=None,
        query_release_date=None,
        hits=hhsearch_hits)

    sequence_features = make_sequence_features(
        sequence=input_sequence,
...
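A heavily hedged sketch of constructing the pipeline in reduced-database mode.
The paths are placeholders, the binary-path arguments and the `process` method
name are assumptions inferred from how the runners are used in the hunks above,
and the template featurizer is left as a placeholder rather than constructed.

```python
from alphafold.data import pipeline

DOWNLOAD_DIR = '/data/alphafold_databases'  # placeholder

data_pipeline = pipeline.DataPipeline(
    jackhmmer_binary_path='/usr/bin/jackhmmer',
    hhblits_binary_path='/usr/bin/hhblits',
    hhsearch_binary_path='/usr/bin/hhsearch',
    uniref90_database_path=f'{DOWNLOAD_DIR}/uniref90/uniref90.fasta',
    mgnify_database_path=f'{DOWNLOAD_DIR}/mgnify/mgy_clusters_2018_08.fa',
    # With the reduced databases, BFD and Uniclust30 are not needed:
    bfd_database_path=None,
    uniclust30_database_path=None,
    small_bfd_database_path=(
        f'{DOWNLOAD_DIR}/small_bfd/bfd-first_non_consensus_sequences.fasta'),
    pdb70_database_path=f'{DOWNLOAD_DIR}/pdb70/pdb70',
    template_featurizer=None,  # in practice: a configured TemplateHitFeaturizer
    use_small_bfd=True)

feature_dict = data_pipeline.process(
    input_fasta_path='query.fasta', msa_output_dir='msa_output')
```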
...@@ -93,19 +93,12 @@ TEMPLATE_FEATURES = {
    'template_all_atom_masks': np.float32,
    'template_all_atom_positions': np.float32,
    'template_domain_names': np.object,
    'template_sequence': np.object,
    'template_sum_probs': np.float32,
}


def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
  """Returns PDB id and chain id for an HHSearch Hit."""
  # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
  id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name)
...@@ -175,7 +168,7 @@ def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:

def _assess_hhsearch_hit(
    hit: parsers.TemplateHit,
    hit_pdb_code: str,
    query_sequence: str,
    query_pdb_code: Optional[str],
...@@ -487,7 +480,6 @@ def _extract_template_features(
    template_sequence: str,
    query_sequence: str,
    template_chain_id: str,
    kalign_binary_path: str) -> Tuple[Dict[str, Any], Optional[str]]:
  """Parses atom positions in the target structure and aligns with the query.
...@@ -495,21 +487,6 @@ def _extract_template_features(
  with their corresponding residue in the query sequence, according to the
  alignment mapping provided.

  Args:
    mmcif_object: mmcif_parsing.MmcifObject representing the template.
    pdb_id: PDB code for the template.
...@@ -521,11 +498,6 @@ def _extract_template_features(
      protein.
    template_chain_id: String ID describing which chain in the structure proto
      should be used.
    kalign_binary_path: The path to a kalign executable used for template
      realignment.
...@@ -577,8 +549,6 @@ def _extract_template_features(
    template_sequence = seqres
    # No mapping offset, the query is aligned to the actual sequence.
    mapping_offset = 0

  try:
    # Essentially set to infinity - we don't want to reject templates unless
...@@ -594,7 +564,6 @@ def _extract_template_features(
  all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])

  output_templates_sequence = []
  templates_all_atom_positions = []
  templates_all_atom_masks = []
...@@ -604,15 +573,12 @@ def _extract_template_features(
        np.zeros((residue_constants.atom_type_num, 3)))
    templates_all_atom_masks.append(np.zeros(residue_constants.atom_type_num))
    output_templates_sequence.append('-')

  for k, v in mapping.items():
    template_index = v + mapping_offset
    templates_all_atom_positions[k] = all_atom_positions[template_index][0]
    templates_all_atom_masks[k] = all_atom_masks[template_index][0]
    output_templates_sequence[k] = template_sequence[v]

  # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
  if np.sum(templates_all_atom_masks) < 5:
...@@ -627,13 +593,13 @@ def _extract_template_features(
      output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID)

  return (
      {
          'template_all_atom_positions': np.array(templates_all_atom_positions),
          'template_all_atom_masks': np.array(templates_all_atom_masks),
          'template_sequence': output_templates_sequence.encode(),
          'template_aatype': np.array(templates_aatype),
          'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(),
      },
      warning)
...@@ -704,7 +670,7 @@ class SingleHitResult:

def _process_single_hit(
    query_sequence: str,
    query_pdb_code: Optional[str],
    hit: parsers.TemplateHit,
    mmcif_dir: str,
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
...@@ -745,9 +711,6 @@ def _process_single_hit(
  # The mapping is from the query to the actual hit sequence, so we need to
  # remove gaps (which regardless have a missing confidence score).
  template_sequence = hit.hit_sequence.replace('-', '')

  cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif')
  logging.info('Reading PDB entry from %s. Query: %s, template: %s',
...@@ -779,14 +742,8 @@ def _process_single_hit(
        template_sequence=template_sequence,
        query_sequence=query_sequence,
        template_chain_id=hit_chain_id,
        kalign_binary_path=kalign_binary_path)
    features['template_sum_probs'] = [hit.sum_probs]

    # It is possible there were some errors when parsing the other chains in the
    # mmCIF file, but the template features for the chain we want were still
...@@ -887,7 +844,7 @@ class TemplateHitFeaturizer:
      query_sequence: str,
      query_pdb_code: Optional[str],
      query_release_date: Optional[datetime.datetime],
      hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
    """Computes the templates for given query sequence (more details above)."""
    logging.info('Searching for template for: %s', query_pdb_code)
...@@ -909,8 +866,8 @@ class TemplateHitFeaturizer:
    errors = []
    warnings = []
    for hit in sorted(hits, key=lambda x: x.sum_probs, reverse=True):
      # We got all the templates we wanted, stop processing hits.
      if num_hits >= self._max_hits:
        break
...
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python wrappers for third party tools."""
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
import os
import re
import subprocess
from absl import logging
# Internal import (7716).
from alphafold.data.tools import utils
class Hmmbuild(object):
  """Python wrapper of the hmmbuild binary."""

  def __init__(self,
               *,
               binary_path: str,
               singlemx: bool = False):
    """Initializes the Python hmmbuild wrapper.

    Args:
      binary_path: The path to the hmmbuild executable.
      singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
        just use a common substitution score matrix.

    Raises:
      RuntimeError: If hmmbuild binary not found within the path.
    """
    self.binary_path = binary_path
    self.singlemx = singlemx

  def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
    """Builds an HMM for the aligned sequences given as a Stockholm string.

    Args:
      sto: A string with the aligned sequences in the Stockholm format.
      model_construction: Whether to use reference annotation in the msa to
        determine consensus columns ('hand') or default ('fast').

    Returns:
      A string with the profile in the HMM format.

    Raises:
      RuntimeError: If hmmbuild fails.
    """
    return self._build_profile(sto, model_construction=model_construction)

  def build_profile_from_a3m(self, a3m: str) -> str:
    """Builds an HMM for the aligned sequences given as an A3M string.

    Args:
      a3m: A string with the aligned sequences in the A3M format.

    Returns:
      A string with the profile in the HMM format.

    Raises:
      RuntimeError: If hmmbuild fails.
    """
    lines = []
    for line in a3m.splitlines():
      if not line.startswith('>'):
        line = re.sub('[a-z]+', '', line)  # Remove inserted residues.
      lines.append(line + '\n')
    msa = ''.join(lines)
    return self._build_profile(msa, model_construction='fast')

  def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
    """Builds an HMM for the aligned sequences given as an MSA string.

    Args:
      msa: A string with the aligned sequences, in A3M or STO format.
      model_construction: Whether to use reference annotation in the msa to
        determine consensus columns ('hand') or default ('fast').

    Returns:
      A string with the profile in the HMM format.

    Raises:
      RuntimeError: If hmmbuild fails.
      ValueError: If unspecified arguments are provided.
    """
    if model_construction not in {'hand', 'fast'}:
      raise ValueError(f'Invalid model_construction {model_construction} - only '
                       'hand and fast supported.')

    with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
      input_query = os.path.join(query_tmp_dir, 'query.msa')
      output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')

      with open(input_query, 'w') as f:
        f.write(msa)

      cmd = [self.binary_path]
      # If adding flags, we have to do so before the output and input:
      if model_construction == 'hand':
        cmd.append(f'--{model_construction}')
      if self.singlemx:
        cmd.append('--singlemx')
      cmd.extend([
          '--amino',
          output_hmm_path,
          input_query,
      ])

      logging.info('Launching subprocess %s', cmd)
      process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
      with utils.timing('hmmbuild query'):
        stdout, stderr = process.communicate()
        retcode = process.wait()
        logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
                     stdout.decode('utf-8'), stderr.decode('utf-8'))

      if retcode:
        raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
                           % (stdout.decode('utf-8'), stderr.decode('utf-8')))

      with open(output_hmm_path, encoding='utf-8') as f:
        hmm = f.read()

    return hmm
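A short usage sketch for the wrapper above. The binary path is an assumption
(point it at your local HMMER installation) and the A3M alignment is made up.

```python
from alphafold.data.tools import hmmbuild

builder = hmmbuild.Hmmbuild(binary_path='/usr/bin/hmmbuild')

# A tiny, made-up A3M alignment; lowercase insertions are stripped internally.
a3m = """>query
MKTAYIAKQR
>hit_1
MKTAYIgGKQR
"""
profile = builder.build_profile_from_a3m(a3m)
print(profile.splitlines()[0])  # HMMER3 profile header line
```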
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import os
import subprocess
from typing import Optional, Sequence
from absl import logging
# Internal import (7716).
from alphafold.data.tools import utils
class Hmmsearch(object):
  """Python wrapper of the hmmsearch binary."""

  def __init__(self,
               *,
               binary_path: str,
               database_path: str,
               flags: Optional[Sequence[str]] = None):
    """Initializes the Python hmmsearch wrapper.

    Args:
      binary_path: The path to the hmmsearch executable.
      database_path: The path to the hmmsearch database (FASTA format).
      flags: List of flags to be used by hmmsearch.

    Raises:
      RuntimeError: If hmmsearch binary not found within the path.
    """
    self.binary_path = binary_path
    self.database_path = database_path
    self.flags = flags

    if not os.path.exists(self.database_path):
      logging.error('Could not find hmmsearch database %s', database_path)
      raise ValueError(f'Could not find hmmsearch database {database_path}')

  def query(self, hmm: str) -> str:
    """Queries the database with hmmsearch using a given HMM."""
    with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
      hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
      a3m_out_path = os.path.join(query_tmp_dir, 'output.a3m')
      with open(hmm_input_path, 'w') as f:
        f.write(hmm)

      cmd = [
          self.binary_path,
          '--noali',  # Don't include the alignment in stdout.
          '--cpu', '8'
      ]
      # If adding flags, we have to do so before the output and input:
      if self.flags:
        cmd.extend(self.flags)
      cmd.extend([
          '-A', a3m_out_path,
          hmm_input_path,
          self.database_path,
      ])

      logging.info('Launching sub-process %s', cmd)
      process = subprocess.Popen(
          cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
      with utils.timing(
          f'hmmsearch ({os.path.basename(self.database_path)}) query'):
        stdout, stderr = process.communicate()
        retcode = process.wait()

      if retcode:
        raise RuntimeError(
            'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
                stdout.decode('utf-8'), stderr.decode('utf-8')))

      with open(a3m_out_path) as f:
        a3m_out = f.read()

    return a3m_out
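A sketch of chaining the two new wrappers: build a profile from a Stockholm
alignment, then search a sequence database with it. All paths and the extra
`-E` reporting threshold flag are placeholders for illustration.

```python
from alphafold.data.tools import hmmbuild, hmmsearch

builder = hmmbuild.Hmmbuild(binary_path='/usr/bin/hmmbuild')
searcher = hmmsearch.Hmmsearch(
    binary_path='/usr/bin/hmmsearch',
    database_path='/data/uniref90/uniref90.fasta',
    flags=['-E', '1e-4'])  # passed through to the hmmsearch command line

with open('query.sto') as f:
    profile = builder.build_profile_from_sto(f.read())

# Returns the hit alignment written via `-A` as a single string.
alignment = searcher.query(profile)
```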
...@@ -14,9 +14,12 @@
"""Library to run Jackhmmer from Python."""

from concurrent import futures
import glob
import os
import subprocess
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request

from absl import logging
...@@ -40,7 +43,9 @@ class Jackhmmer:
               filter_f2: float = 0.00005,
               filter_f3: float = 0.0000005,
               incdom_e: Optional[float] = None,
               dom_e: Optional[float] = None,
               num_streamed_chunks: Optional[int] = None,
               streaming_callback: Optional[Callable[[int], None]] = None):
    """Initializes the Python Jackhmmer wrapper.

    Args:
...@@ -57,11 +62,15 @@ class Jackhmmer:
      incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
        round.
      dom_e: Domain e-value criteria for inclusion in tblout.
      num_streamed_chunks: Number of database chunks to stream over.
      streaming_callback: Callback function run after each chunk iteration with
        the iteration number as argument.
    """
    self.binary_path = binary_path
    self.database_path = database_path
    self.num_streamed_chunks = num_streamed_chunks

    if not os.path.exists(self.database_path) and num_streamed_chunks is None:
      logging.error('Could not find Jackhmmer database %s', database_path)
      raise ValueError(f'Could not find Jackhmmer database {database_path}')
...@@ -75,9 +84,11 @@ class Jackhmmer:
    self.incdom_e = incdom_e
    self.dom_e = dom_e
    self.get_tblout = get_tblout
    self.streaming_callback = streaming_callback

  def _query_chunk(self, input_fasta_path: str, database_path: str
                   ) -> Mapping[str, Any]:
    """Queries the database chunk using Jackhmmer."""
    with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
      sto_path = os.path.join(query_tmp_dir, 'output.sto')
...@@ -114,13 +125,13 @@ class Jackhmmer:
        cmd_flags.extend(['--incdomE', str(self.incdom_e)])

      cmd = [self.binary_path] + cmd_flags + [input_fasta_path,
                                              database_path]

      logging.info('Launching subprocess "%s"', ' '.join(cmd))
      process = subprocess.Popen(
          cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
      with utils.timing(
          f'Jackhmmer ({os.path.basename(database_path)}) query'):
        _, stderr = process.communicate()
        retcode = process.wait()
...@@ -145,3 +156,43 @@ class Jackhmmer:
          e_value=self.e_value)

      return raw_output
  def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
    """Queries the database using Jackhmmer."""
    if self.num_streamed_chunks is None:
      return [self._query_chunk(input_fasta_path, self.database_path)]

    db_basename = os.path.basename(self.database_path)
    db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
    db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}'

    # Remove existing files to prevent OOM
    for f in glob.glob(db_local_chunk('[0-9]*')):
      try:
        os.remove(f)
      except OSError:
        print(f'OSError while deleting {f}')

    # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
    with futures.ThreadPoolExecutor(max_workers=2) as executor:
      chunked_output = []
      for i in range(1, self.num_streamed_chunks + 1):
        # Copy the chunk locally
        if i == 1:
          future = executor.submit(
              request.urlretrieve, db_remote_chunk(i), db_local_chunk(i))
        if i < self.num_streamed_chunks:
          next_future = executor.submit(
              request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1))

        # Run Jackhmmer with the chunk
        future.result()
        chunked_output.append(
            self._query_chunk(input_fasta_path, db_local_chunk(i)))

        # Remove the local copy of the chunk
        os.remove(db_local_chunk(i))
        future = next_future
        if self.streaming_callback:
          self.streaming_callback(i)
    return chunked_output
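A usage sketch for the new streaming mode. The database URL and chunk count are
placeholders; as in the method above, the remote database is expected to be
pre-split into files named `<database_path>.1`, `<database_path>.2`, and so on.

```python
from alphafold.data.tools import jackhmmer

runner = jackhmmer.Jackhmmer(
    binary_path='/usr/bin/jackhmmer',
    database_path='https://example-bucket.storage.googleapis.com/uniref90.fasta',
    num_streamed_chunks=59,
    streaming_callback=lambda i: print(f'finished chunk {i}'))

# One result mapping (including the 'sto' alignment) per streamed chunk.
results = runner.query('query.fasta')
print(len(results))
```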
...@@ -492,7 +492,7 @@ class StructureModule(hk.Module):
        is_training=is_training,
        safe_key=safe_key)

    ret['representations'] = {'structure_module': output['act']}
    ret['traj'] = output['affine'] * jnp.array([1.] * 4 +
                                               [c.position_scale] * 3)
...@@ -514,7 +514,8 @@ class StructureModule(hk.Module):
    if self.compute_loss:
      return ret
    else:
      no_loss_features = ['final_atom_positions', 'final_atom_mask',
                          'representations']
      no_loss_ret = {k: ret[k] for k in no_loss_features}
      return no_loss_ret
...
...@@ -237,6 +237,10 @@ class AlphaFoldIteration(hk.Module):
        continue
      else:
        ret[name] = module(representations, batch, is_training)
        if 'representations' in ret[name]:
          # Extra representations from the head. Used by the structure module
          # to provide activations for the PredictedLDDTHead.
          representations.update(ret[name].pop('representations'))

      if compute_loss:
        total_loss += loss(module, head_config, ret, name)
...
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alphafold model TensorFlow code."""
...
@@ -146,13 +146,13 @@ def process_tensors_from_config(tensors, data_config):
     num_ensemble *= data_config.common.num_recycle + 1
   if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1:
-    dtype = tree.map_structure(lambda x: x.dtype,
-                               tensors_0)
+    fn_output_signature = tree.map_structure(
+        tf.TensorSpec.from_tensor, tensors_0)
     tensors = tf.map_fn(
         lambda x: wrap_ensemble_fn(tensors, x),
         tf.range(num_ensemble),
         parallel_iterations=1,
-        dtype=dtype)
+        fn_output_signature=fn_output_signature)
   else:
     tensors = tree.map_structure(lambda x: x[None],
                                  tensors_0)
...
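The hunk above swaps `tf.map_fn`'s deprecated `dtype` argument for `fn_output_signature`, built from the first ensemble's tensors with `tf.TensorSpec.from_tensor`. A standalone sketch of the same pattern, with made-up tensors and a stand-in for the ensemble wrapper:

```python
# Standalone illustration of fn_output_signature; the tensors and the mapped
# function are placeholders, not AlphaFold's ensemble wrapper.
import tensorflow as tf

tensors_0 = {'aatype': tf.zeros([8], tf.int32),
             'msa': tf.zeros([4, 8], tf.float32)}

def per_ensemble(i):
  # Stand-in for wrap_ensemble_fn: returns one ensemble's worth of features.
  return {k: v + tf.cast(i, v.dtype) for k, v in tensors_0.items()}

# Describe the structure, dtype and shape of each per-call output up front.
fn_output_signature = tf.nest.map_structure(tf.TensorSpec.from_tensor, tensors_0)

ensembled = tf.map_fn(
    per_ensemble,
    tf.range(3),
    parallel_iterations=1,
    fn_output_signature=fn_output_signature)

# Each feature gains a leading ensemble dimension: (3, 8) and (3, 4, 8).
print({k: v.shape for k, v in ensembled.items()})
```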
@@ -52,7 +52,7 @@ def _add_restraints(
     stiffness: unit.Unit,
     rset: str,
     exclude_residues: Sequence[int]):
-  """Adds a harmonic potential that restrains the end-to-end distance."""
+  """Adds a harmonic potential that restrains the system to a structure."""
   assert rset in ["non_hydrogen", "c_alpha"]
   force = openmm.CustomExternalForce(
...
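For context, the restraint that `_add_restraints` sets up pins selected atoms to their coordinates in a reference structure, which is what the corrected docstring now says. A generic OpenMM sketch of such a harmonic positional restraint, independent of AlphaFold's helper; the `system`, `positions` and `restrained_atom_indices` variables are assumed to exist, and the stiffness value is arbitrary.

```python
# Generic harmonic positional restraint in OpenMM; a sketch, not AlphaFold's
# exact helper. `system`, `positions` and `restrained_atom_indices` are
# assumed to exist already.
from simtk import openmm, unit

stiffness = 10.0 * unit.kilocalories_per_mole / unit.angstroms**2

# Energy grows quadratically with the displacement from the reference coords.
force = openmm.CustomExternalForce(
    '0.5 * k * ((x - x0)^2 + (y - y0)^2 + (z - z0)^2)')
force.addGlobalParameter('k', stiffness)
for parameter in ('x0', 'y0', 'z0'):
  force.addPerParticleParameter(parameter)

for atom_index in restrained_atom_indices:
  # Pin each restrained atom to its position in the reference structure.
  force.addParticle(atom_index, positions[atom_index])

system.addForce(force)
```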
@@ -54,7 +54,6 @@ class AmberMinimizeTest(absltest.TestCase):
         max_attempts=1)
   def test_iterative_relax(self):
-    # This test can occasionally fail because of nondeterminism in OpenMM.
     prot = _load_test_protein(
         'alphafold/relax/testdata/with_violations.pdb'
     )
...
@@ -48,7 +48,7 @@ def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
     raise ValueError(
         f'Invalid final dimension size for bfactors: {bfactors.shape[-1]}.')
-  parser = PDB.PDBParser()
+  parser = PDB.PDBParser(QUIET=True)
   handle = io.StringIO(pdb_str)
   structure = parser.get_structure('', handle)
...
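`QUIET=True` silences Biopython's `PDBConstructionWarning` messages, which otherwise flood the log when parsing predicted structures with minor irregularities. A small self-contained sketch of the same call; the one-atom PDB record below is a made-up placeholder.

```python
# Parse a PDB string without Biopython construction warnings; the ATOM record
# below is a minimal made-up placeholder.
import io
from Bio import PDB

pdb_str = (
    'ATOM      1  CA  ALA A   1      11.104   6.134  -6.504  1.00  0.00'
    '           C\n'
    'END\n')

parser = PDB.PDBParser(QUIET=True)  # suppress PDBConstructionWarning output
structure = parser.get_structure('', io.StringIO(pdb_str))
print(sum(1 for _ in structure.get_atoms()))  # -> 1
```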
@@ -54,7 +54,8 @@ RUN conda update -qy conda \
       openmm=7.5.1 \
       cudatoolkit==${CUDA}.3 \
       pdbfixer \
-      pip
+      pip \
+      python=3.7
 COPY . /app/alphafold
 RUN wget -q -P /app/alphafold/alphafold/common/ \
@@ -67,7 +68,7 @@ RUN pip3 install --upgrade pip \
       https://storage.googleapis.com/jax-releases/jax_releases.html
 # Apply OpenMM patch.
-WORKDIR /opt/conda/lib/python3.8/site-packages
+WORKDIR /opt/conda/lib/python3.7/site-packages
 RUN patch -p0 < /app/alphafold/docker/openmm.patch
 # We need to run `ldconfig` first to ensure GPUs are visible, due to some quirk
...
@@ -57,13 +57,17 @@ uniref90_database_path = os.path.join(
 # Path to the MGnify database for use by JackHMMER.
 mgnify_database_path = os.path.join(
-    DOWNLOAD_DIR, 'mgnify', 'mgy_clusters.fa')
+    DOWNLOAD_DIR, 'mgnify', 'mgy_clusters_2018_08.fa')
 # Path to the BFD database for use by HHblits.
 bfd_database_path = os.path.join(
     DOWNLOAD_DIR, 'bfd',
     'bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt')
+# Path to the Small BFD database for use by JackHMMER.
+small_bfd_database_path = os.path.join(
+    DOWNLOAD_DIR, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta')
 # Path to the Uniclust30 database for use by HHblits.
 uniclust30_database_path = os.path.join(
     DOWNLOAD_DIR, 'uniclust30', 'uniclust30_2018_08', 'uniclust30_2018_08')
@@ -92,10 +96,11 @@ flags.DEFINE_string('max_template_date', None, 'Maximum template release date '
                     'to consider (ISO-8601 format - i.e. YYYY-MM-DD). '
                     'Important if folding historical test sets.')
 flags.DEFINE_enum('preset', 'full_dbs',
-                  ['full_dbs', 'casp14'],
-                  'Choose preset model configuration - no ensembling with '
-                  'uniref90 + bfd + uniclust30 (full_dbs), or '
-                  '8 model ensemblings with uniref90 + bfd + uniclust30 '
+                  ['reduced_dbs', 'full_dbs', 'casp14'],
+                  'Choose preset model configuration - no ensembling and '
+                  'smaller genetic database config (reduced_dbs), no '
+                  'ensembling and full genetic database config (full_dbs) or '
+                  'full genetic database config and 8 model ensemblings '
                   '(casp14).')
 flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations '
                      'to obtain a timing that excludes the compilation time, '
@@ -131,14 +136,22 @@ def main(argv):
     target_fasta_paths.append(target_path)
   command_args.append(f'--fasta_paths={",".join(target_fasta_paths)}')
-  for name, path in [('uniref90_database_path', uniref90_database_path),
-                     ('mgnify_database_path', mgnify_database_path),
-                     ('uniclust30_database_path', uniclust30_database_path),
-                     ('bfd_database_path', bfd_database_path),
-                     ('pdb70_database_path', pdb70_database_path),
-                     ('data_dir', data_dir),
-                     ('template_mmcif_dir', template_mmcif_dir),
-                     ('obsolete_pdbs_path', obsolete_pdbs_path)]:
+  database_paths = [
+      ('uniref90_database_path', uniref90_database_path),
+      ('mgnify_database_path', mgnify_database_path),
+      ('pdb70_database_path', pdb70_database_path),
+      ('data_dir', data_dir),
+      ('template_mmcif_dir', template_mmcif_dir),
+      ('obsolete_pdbs_path', obsolete_pdbs_path),
+  ]
+  if FLAGS.preset == 'reduced_dbs':
+    database_paths.append(('small_bfd_database_path', small_bfd_database_path))
+  else:
+    database_paths.extend([
+        ('uniclust30_database_path', uniclust30_database_path),
+        ('bfd_database_path', bfd_database_path),
+    ])
+  for name, path in database_paths:
     if path:
       mount, target_path = _create_mount(name, path)
       mounts.append(mount)
...
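With the hunk above, the set of mounted databases now depends on the chosen preset: `reduced_dbs` swaps the full BFD and Uniclust30 (HHblits) for the single small BFD file (Jackhmmer), while `full_dbs` and `casp14` keep the full set. A toy sketch of that selection logic; the paths are placeholders, not the values defined in `run_docker.py`.

```python
# Toy illustration of the preset-dependent database selection; paths are
# placeholders, not the values defined in run_docker.py.
def select_databases(preset: str) -> dict:
  databases = {
      'uniref90_database_path': '<uniref90>',
      'mgnify_database_path': '<mgnify>',
      'pdb70_database_path': '<pdb70>',
  }
  if preset == 'reduced_dbs':
    # One small sequence database, searched with Jackhmmer.
    databases['small_bfd_database_path'] = '<small_bfd>'
  else:  # 'full_dbs' and 'casp14' use the full HHblits databases.
    databases['uniclust30_database_path'] = '<uniclust30>'
    databases['bfd_database_path'] = '<bfd>'
  return databases

assert 'small_bfd_database_path' in select_databases('reduced_dbs')
assert 'bfd_database_path' in select_databases('full_dbs')
```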