Commit 2f0d89e7 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove duplicated code

parent a1597f3f
.dockerignore
docker/Dockerfile
__pycache__
*.whl
hh-suite
data
jobs
\ No newline at end of file
# How to Contribute
We welcome small patches related to bug fixes and documentation, but we do not
plan to make any major changes to this repository.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution,
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
![header](imgs/header.jpg)
# AlphaFold
This package provides an implementation of the inference pipeline of AlphaFold
v2.0. This is a completely new model that was entered in CASP14 and published in
Nature. For simplicity, we refer to this model as AlphaFold throughout the rest
of this document.
Any publication that discloses findings arising from using this source code or
the model parameters should [cite](#citing-this-work) the
[AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2).
![CASP14 predictions](imgs/casp14_predictions.gif)
## First time setup
The following steps are required in order to run AlphaFold:
1. Install [Docker](https://www.docker.com/).
* Install
[NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)
for GPU support.
* Setup running
[Docker as a non-root user](https://docs.docker.com/engine/install/linux-postinstall/#manage-docker-as-a-non-root-user).
1. Download genetic databases (see below).
1. Download model parameters (see below).
1. Check that AlphaFold will be able to use a GPU by running:
```bash
docker run --rm --gpus all nvidia/cuda:11.0-base nvidia-smi
```
The output of this command should show a list of your GPUs. If it doesn't,
check if you followed all steps correctly when setting up the
[NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)
or take a look at the following
[NVIDIA Docker issue](https://github.com/NVIDIA/nvidia-docker/issues/1447#issuecomment-801479573).
### Genetic databases
This step requires `rsync` and `aria2c` to be installed on your machine.
AlphaFold needs multiple genetic (sequence) databases to run:
* [UniRef90](https://www.uniprot.org/help/uniref),
* [MGnify](https://www.ebi.ac.uk/metagenomics/),
* [BFD](https://bfd.mmseqs.com/),
* [Uniclust30](https://uniclust.mmseqs.com/),
* [PDB70](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/),
* [PDB](https://www.rcsb.org/) (structures in the mmCIF format).
We provide a script `scripts/download_all_data.sh` that can be used to download
and set up all of these databases. This should take 8–12 hours.
:ledger: **Note: The total download size is around 428 GB and the total size
when unzipped is 2.2 TB. Please make sure you have a large enough hard drive
space, bandwidth and time to download.**
This script will also download the model parameter files. Once the script has
finished, you should have the following directory structure:
```
$DOWNLOAD_DIR/ # Total: ~ 2.2 TB (download: 428 GB)
bfd/ # ~ 1.8 TB (download: 271.6 GB)
# 6 files.
mgnify/ # ~ 64 GB (download: 32.9 GB)
mgy_clusters.fa
params/ # ~ 3.5 GB (download: 3.5 GB)
# 5 CASP14 models,
# 5 pTM models,
# LICENSE,
# = 11 files.
pdb70/ # ~ 56 GB (download: 19.5 GB)
# 9 files.
pdb_mmcif/ # ~ 206 GB (download: 46 GB)
mmcif_files/
# About 180,000 .cif files.
obsolete.dat
uniclust30/ # ~ 87 GB (download: 24.9 GB)
uniclust30_2018_08/
# 13 files.
uniref90/ # ~ 59 GB (download: 29.7 GB)
uniref90.fasta
```
### Model parameters
While the AlphaFold code is licensed under the Apache 2.0 License, the AlphaFold
parameters are made available for non-commercial use only under the terms of the
CC BY-NC 4.0 license. Please see the [Disclaimer](#license-and-disclaimer) below
for more detail.
The AlphaFold parameters are available from
https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar, and
are downloaded as part of the `scripts/download_all_data.sh` script. This script
will download parameters for:
* 5 models which were used during CASP14, and were extensively validated for
structure prediction quality (see Jumper et al. 2021, Suppl. Methods 1.12
for details).
* 5 pTM models, which were fine-tuned to produce pTM (predicted TM-score) and
predicted aligned error values alongside their structure predictions (see
Jumper et al. 2021, Suppl. Methods 1.9.7 for details).
## Running AlphaFold
**The simplest way to run AlphaFold is using the provided Docker script.** This
was tested on Google Cloud with a machine using the `nvidia-gpu-cloud-image`
with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional
3 TB disk, and an A100 GPU.
1. Clone this repository and `cd` into it.
```bash
git clone https://github.com/deepmind/alphafold.git
```
1. Modify `DOWNLOAD_DIR` in `docker/run_docker.py` to be the path to the
directory containing the downloaded databases.
1. Build the Docker image:
```bash
docker build -f docker/Dockerfile -t alphafold .
```
1. Install the `run_docker.py` dependencies. Note: You may optionally wish to
create a
[Python Virtual Environment](https://docs.python.org/3/tutorial/venv.html)
to prevent conflicts with your system's Python environment.
```bash
pip3 install -r docker/requirements.txt
```
1. Run `run_docker.py` pointing to a FASTA file containing the protein sequence
for which you wish to predict the structure. If you are predicting the
structure of a protein that is already in PDB and you wish to avoid using it
as a template, then `max_template_date` must be set to be before the release
date of the structure. For example, for the T1050 CASP14 target:
```bash
python3 docker/run_docker.py --fasta_paths=T1050.fasta --max_template_date=2020-05-14
```
By default, Alphafold will attempt to use all visible GPU devices. To use a
subset, specify a comma-separated list of GPU UUID(s) or index(es) using the
`--gpu_devices` flag. See
[GPU enumeration](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#gpu-enumeration)
for more details.
1. You can control AlphaFold speed / quality tradeoff by adding either
`--preset=full_dbs` or `--preset=casp14` to the run command. We provide the
following presets:
* **casp14**: This preset uses the same settings as were used in CASP14.
It runs with all genetic databases and with 8 ensemblings.
* **full_dbs**: The model in this preset is 8 times faster than the
`casp14` preset with a very minor quality drop (-0.1 average GDT drop on
CASP14 domains). It runs with all genetic databases and with no
ensembling.
Running the command above with the `casp14` preset would look like this:
```bash
python3 docker/run_docker.py --fasta_paths=T1050.fasta --max_template_date=2020-05-14 --preset=casp14
```
### AlphaFold output
The outputs will be in a subfolder of `output_dir` in `run_docker.py`. They
include the computed MSAs, unrelaxed structures, relaxed structures, ranked
structures, raw model outputs, prediction metadata, and section timings. The
`output_dir` directory will have the following structure:
```
output_dir/
features.pkl
ranked_{0,1,2,3,4}.pdb
ranking_debug.json
relaxed_model_{1,2,3,4,5}.pdb
result_model_{1,2,3,4,5}.pkl
timings.json
unrelaxed_model_{1,2,3,4,5}.pdb
msas/
bfd_uniclust_hits.a3m
mgnify_hits.sto
uniref90_hits.sto
```
The contents of each output file are as follows:
* `features.pkl` – A `pickle` file containing the input feature Numpy arrays
used by the models to produce the structures.
* `unrelaxed_model_*.pdb` – A PDB format text file containing the predicted
structure, exactly as outputted by the model.
* `relaxed_model_*.pdb` – A PDB format text file containing the predicted
structure, after performing an Amber relaxation procedure on the unrelaxed
structure prediction, see Jumper et al. 2021, Suppl. Methods 1.8.6 for
details.
* `ranked_*.pdb` – A PDB format text file containing the relaxed predicted
structures, after reordering by model confidence. Here `ranked_0.pdb` should
contain the prediction with the highest confidence, and `ranked_4.pdb` the
prediction with the lowest confidence. To rank model confidence, we use
predicted LDDT (pLDDT), see Jumper et al. 2021, Suppl. Methods 1.9.6 for
details.
* `ranking_debug.json` – A JSON format text file containing the pLDDT values
used to perform the model ranking, and a mapping back to the original model
names.
* `timings.json` – A JSON format text file containing the times taken to run
each section of the AlphaFold pipeline.
* `msas/` - A directory containing the files describing the various genetic
tool hits that were used to construct the input MSA.
* `result_model_*.pkl` – A `pickle` file containing a nested dictionary of the
various Numpy arrays directly produced by the model. In addition to the
output of the structure module, this includes auxiliary outputs such as
distograms and pLDDT scores. If using the pTM models then the pTM logits
will also be contained in this file.
This code has been tested to match mean top-1 accuracy on a CASP14 test set with
pLDDT ranking over 5 model predictions (some CASP targets were run with earlier
versions of AlphaFold and some had manual interventions; see our forthcoming
publication for details). Some targets such as T1064 may also have high
individual run variance over random seeds.
## Inferencing many proteins
The provided inference script is optimized for predicting the structure of a
single protein, and it will compile the neural network to be specialized to
exactly the size of the sequence, MSA, and templates. For large proteins, the
compile time is a negligible fraction of the runtime, but it may become more
significant for small proteins or if the multi-sequence alignments are already
precomputed. In the bulk inference case, it may make sense to use our
`make_fixed_size` function to pad the inputs to a uniform size, thereby reducing
the number of compilations required.
We do not provide a bulk inference script, but it should be straightforward to
develop on top of the `RunModel.predict` method with a parallel system for
precomputing multi-sequence alignments. Alternatively, this script can be run
repeatedly with only moderate overhead.
## Note on reproducibility
AlphaFold's output for a small number of proteins has high inter-run variance,
and may be affected by changes in the input data. The CASP14 target T1064 is a
notable example; the large number of SARS-CoV-2-related sequences recently
deposited changes its MSA significantly. This variability is somewhat mitigated
by the model selection process; running 5 models and taking the most confident.
To reproduce the results of our CASP14 system as closely as possible you must
use the same database versions we used in CASP. These may not match the default
versions downloaded by our scripts.
For genetics:
* UniRef90:
[v2020_01](https://ftp.uniprot.org/pub/databases/uniprot/previous_releases/release-2020_01/uniref/)
* MGnify:
[v2018_12](http://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2018_12/)
* Uniclust30: [v2018_08](http://wwwuser.gwdg.de/~compbiol/uniclust/2018_08/)
* BFD: [only version available](https://bfd.mmseqs.com/)
For templates:
* PDB: (downloaded 2020-05-14)
* PDB70: (downloaded 2020-05-13)
An alternative for templates is to use the latest PDB and PDB70, but pass the
flag `--max_template_date=2020-05-14`, which restricts templates only to
structures that were available at the start of CASP14.
## Citing this work
If you use the code or data in this package, please cite:
```tex
@Article{AlphaFold2021,
author = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
journal = {Nature},
title = {Highly accurate protein structure prediction with {AlphaFold}},
year = {2021},
doi = {10.1038/s41586-021-03819-2},
note = {(Accelerated article preview)},
}
```
## Acknowledgements
AlphaFold communicates with and/or references the following separate libraries
and packages:
* [Abseil](https://github.com/abseil/abseil-py)
* [Biopython](https://biopython.org)
* [Chex](https://github.com/deepmind/chex)
* [Docker](https://www.docker.com)
* [HH Suite](https://github.com/soedinglab/hh-suite)
* [HMMER Suite](http://eddylab.org/software/hmmer)
* [Haiku](https://github.com/deepmind/dm-haiku)
* [Immutabledict](https://github.com/corenting/immutabledict)
* [JAX](https://github.com/google/jax/)
* [Kalign](https://msa.sbc.su.se/cgi-bin/msa.cgi)
* [ML Collections](https://github.com/google/ml_collections)
* [NumPy](https://numpy.org)
* [OpenMM](https://github.com/openmm/openmm)
* [OpenStructure](https://openstructure.org)
* [SciPy](https://scipy.org)
* [Sonnet](https://github.com/deepmind/sonnet)
* [TensorFlow](https://github.com/tensorflow/tensorflow)
* [Tree](https://github.com/deepmind/tree)
We thank all their contributors and maintainers!
## License and Disclaimer
This is not an officially supported Google product.
Copyright 2021 DeepMind Technologies Limited.
### AlphaFold Code License
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at https://www.apache.org/licenses/LICENSE-2.0.
Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
### Model Parameters License
The AlphaFold parameters are made available for non-commercial use only, under
the terms of the Creative Commons Attribution-NonCommercial 4.0 International
(CC BY-NC 4.0) license. You can find details at:
https://creativecommons.org/licenses/by-nc/4.0/legalcode
### Third-party software
Use of the third-party software, libraries or code referred to in the
[Acknowledgements](#acknowledgements) section above may be governed by separate
terms and conditions or license provisions. Your use of the third-party
software, libraries or code is subject to any such terms and you should check
that you can comply with any applicable restrictions or terms and conditions
before use.
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An implementation of the inference pipeline of AlphaFold v2.0."""
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common data types and constants used within Alphafold."""
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for processing confidence metrics."""
from typing import Dict, Optional, Tuple
import numpy as np
import scipy.special
def compute_plddt(logits: np.ndarray) -> np.ndarray:
"""Computes per-residue pLDDT from logits.
Args:
logits: [num_res, num_bins] output from the PredictedLDDTHead.
Returns:
plddt: [num_res] per-residue pLDDT.
"""
num_bins = logits.shape[-1]
bin_width = 1.0 / num_bins
bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)
probs = scipy.special.softmax(logits, axis=-1)
predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1)
return predicted_lddt_ca * 100
def _calculate_bin_centers(breaks: np.ndarray):
"""Gets the bin centers from the bin edges.
Args:
breaks: [num_bins - 1] the error bin edges.
Returns:
bin_centers: [num_bins] the error bin centers.
"""
step = (breaks[1] - breaks[0])
# Add half-step to get the center
bin_centers = breaks + step / 2
# Add a catch-all bin at the end.
bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]],
axis=0)
return bin_centers
def _calculate_expected_aligned_error(
alignment_confidence_breaks: np.ndarray,
aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Calculates expected aligned distance errors for every pair of residues.
Args:
alignment_confidence_breaks: [num_bins - 1] the error bin edges.
aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted
probs for each error bin, for each pair of residues.
Returns:
predicted_aligned_error: [num_res, num_res] the expected aligned distance
error for each pair of residues.
max_predicted_aligned_error: The maximum predicted error possible.
"""
bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
# Tuple of expected aligned distance error and max possible error.
return (np.sum(aligned_distance_error_probs * bin_centers, axis=-1),
np.asarray(bin_centers[-1]))
def compute_predicted_aligned_error(
logits: np.ndarray,
breaks: np.ndarray) -> Dict[str, np.ndarray]:
"""Computes aligned confidence metrics from logits.
Args:
logits: [num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
breaks: [num_bins - 1] the error bin edges.
Returns:
aligned_confidence_probs: [num_res, num_res, num_bins] the predicted
aligned error probabilities over bins for each residue pair.
predicted_aligned_error: [num_res, num_res] the expected aligned distance
error for each pair of residues.
max_predicted_aligned_error: The maximum predicted error possible.
"""
aligned_confidence_probs = scipy.special.softmax(
logits,
axis=-1)
predicted_aligned_error, max_predicted_aligned_error = (
_calculate_expected_aligned_error(
alignment_confidence_breaks=breaks,
aligned_distance_error_probs=aligned_confidence_probs))
return {
'aligned_confidence_probs': aligned_confidence_probs,
'predicted_aligned_error': predicted_aligned_error,
'max_predicted_aligned_error': max_predicted_aligned_error,
}
def predicted_tm_score(
logits: np.ndarray,
breaks: np.ndarray,
residue_weights: Optional[np.ndarray] = None,
asym_id: Optional[np.ndarray] = None,
interface: bool = False) -> np.ndarray:
"""Computes predicted TM alignment or predicted interface TM alignment score.
Args:
logits: [num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
breaks: [num_bins] the error bins.
residue_weights: [num_res] the per residue weights to use for the
expectation.
asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
ipTM calculation, i.e. when interface=True.
interface: If True, interface predicted TM score is computed.
Returns:
ptm_score: The predicted TM alignment or the predicted iTM score.
"""
# residue_weights has to be in [0, 1], but can be floating-point, i.e. the
# exp. resolved head's probability.
if residue_weights is None:
residue_weights = np.ones(logits.shape[0])
bin_centers = _calculate_bin_centers(breaks)
num_res = int(np.sum(residue_weights))
# Clip num_res to avoid negative/undefined d0.
clipped_num_res = max(num_res, 19)
# Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
# "Scoring function for automated assessment of protein structure template
# quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8
# Convert logits to probs.
probs = scipy.special.softmax(logits, axis=-1)
# TM-Score term for every bin.
tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0))
# E_distances tm(distance).
predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1)
pair_mask = np.ones(shape=(num_res, num_res), dtype=bool)
if interface:
pair_mask *= asym_id[:, None] != asym_id[None, :]
predicted_tm_term *= pair_mask
pair_residue_weights = pair_mask * (
residue_weights[None, :] * residue_weights[:, None])
normed_residue_mask = pair_residue_weights / (1e-8 + np.sum(
pair_residue_weights, axis=-1, keepdims=True))
per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1)
return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()])
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional
from alphafold.common import residue_constants
from Bio.PDB import PDBParser
import numpy as np
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
# Complete sequence of chain IDs supported by the PDB format.
PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62.
@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# 0-indexed number corresponding to the chain in the protein that this residue
# belongs to.
chain_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
def __post_init__(self):
if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
raise ValueError(
f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
'because these cannot be written to PDB format.')
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.
WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored.
Args:
pdb_str: The contents of the pdb file
chain_id: If chain_id is specified (e.g. A), then only that chain
is parsed. Otherwise all chains are parsed.
Returns:
A new `Protein` parsed from the pdb contents.
"""
pdb_fh = io.StringIO(pdb_str)
parser = PDBParser(QUIET=True)
structure = parser.get_structure('none', pdb_fh)
models = list(structure.get_models())
if len(models) != 1:
raise ValueError(
f'Only single model PDBs are supported. Found {len(models)} models.')
model = models[0]
atom_positions = []
aatype = []
atom_mask = []
residue_index = []
chain_ids = []
b_factors = []
for chain in model:
if chain_id is not None and chain.id != chain_id:
continue
for res in chain:
if res.id[2] != ' ':
raise ValueError(
f'PDB contains an insertion code at chain {chain.id} and residue '
f'index {res.id[1]}. These are not supported.')
res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.
res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
chain_ids.append(chain.id)
b_factors.append(res_b_factors)
# Chain IDs are usually characters so map these to ints.
unique_chain_ids = np.unique(chain_ids)
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
chain_index=chain_index,
b_factors=np.array(b_factors))
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
chain_end = 'TER'
return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
f'{chain_name:>1}{residue_index:>4}')
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
restypes = residue_constants.restypes + ['X']
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
atom_types = residue_constants.atom_types
pdb_lines = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
chain_index = prot.chain_index.astype(np.int32)
b_factors = prot.b_factors
if np.any(aatype > residue_constants.restype_num):
raise ValueError('Invalid aatypes.')
# Construct a mapping from chain integer indices to chain ID strings.
chain_ids = {}
for i in np.unique(chain_index): # np.unique gives sorted output.
if i >= PDB_MAX_CHAINS:
raise ValueError(
f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
chain_ids[i] = PDB_CHAIN_IDS[i]
pdb_lines.append('MODEL 1')
atom_index = 1
last_chain_index = chain_index[0]
# Add all atom sites.
for i in range(aatype.shape[0]):
# Close the previous chain if in a multichain PDB.
if last_chain_index != chain_index[i]:
pdb_lines.append(_chain_end(
atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]],
residue_index[i - 1]))
last_chain_index = chain_index[i]
atom_index += 1 # Atom index increases at the TER symbol.
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
if mask < 0.5:
continue
record_type = 'ATOM'
name = atom_name if len(atom_name) == 4 else f' {atom_name}'
alt_loc = ''
insertion_code = ''
occupancy = 1.00
element = atom_name[0] # Protein supports only C, N, O, S, this works.
charge = ''
# PDB is a columnar format, every space matters here!
atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
f'{residue_index[i]:>4}{insertion_code:>1} '
f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
f'{occupancy:>6.2f}{b_factor:>6.2f} '
f'{element:>2}{charge:>2}')
pdb_lines.append(atom_line)
atom_index += 1
# Close the final chain.
pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
chain_ids[chain_index[-1]], residue_index[-1]))
pdb_lines.append('ENDMDL')
pdb_lines.append('END')
# Pad all lines to 80 characters.
pdb_lines = [line.ljust(80) for line in pdb_lines]
return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
remove_leading_feature_dimension: bool = True) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values.
Returns:
A protein instance.
"""
fold_output = result['structure_module']
dist_per_residue = np.zeros_like(fold_output['final_atom_mask'])
plddt = np.expand_dims(result['plddt'],axis=1)
plddt = np.repeat(plddt, residue_constants.atom_type_num, axis=1)
def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
return arr[0] if remove_leading_feature_dimension else arr
if 'asym_id' in features:
chain_index = _maybe_remove_leading_dim(features['asym_id'])
else:
chain_index = np.zeros_like(_maybe_remove_leading_dim(features['aatype']))
if b_factors is None:
b_factors = np.zeros_like(fold_output['final_atom_mask'])
return Protein(
aatype=_maybe_remove_leading_dim(features['aatype']),
atom_positions=fold_output['final_atom_positions'],
atom_mask=fold_output['final_atom_mask'],
residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1,
chain_index=chain_index,
b_factors=b_factors)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for protein."""
import os
from absl.testing import absltest
from absl.testing import parameterized
from alphafold.common import protein
from alphafold.common import residue_constants
import numpy as np
# Internal import (7716).
TEST_DATA_DIR = 'alphafold/common/testdata/'
class ProteinTest(parameterized.TestCase):
def _check_shapes(self, prot, num_res):
"""Check that the processed shapes are correct."""
num_atoms = residue_constants.atom_type_num
self.assertEqual((num_res, num_atoms, 3), prot.atom_positions.shape)
self.assertEqual((num_res,), prot.aatype.shape)
self.assertEqual((num_res, num_atoms), prot.atom_mask.shape)
self.assertEqual((num_res,), prot.residue_index.shape)
self.assertEqual((num_res,), prot.chain_index.shape)
self.assertEqual((num_res, num_atoms), prot.b_factors.shape)
@parameterized.named_parameters(
dict(testcase_name='chain_A',
pdb_file='2rbg.pdb', chain_id='A', num_res=282, num_chains=1),
dict(testcase_name='chain_B',
pdb_file='2rbg.pdb', chain_id='B', num_res=282, num_chains=1),
dict(testcase_name='multichain',
pdb_file='2rbg.pdb', chain_id=None, num_res=564, num_chains=2))
def test_from_pdb_str(self, pdb_file, chain_id, num_res, num_chains):
pdb_file = os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
pdb_file)
with open(pdb_file) as f:
pdb_string = f.read()
prot = protein.from_pdb_string(pdb_string, chain_id)
self._check_shapes(prot, num_res)
self.assertGreaterEqual(prot.aatype.min(), 0)
# Allow equal since unknown restypes have index equal to restype_num.
self.assertLessEqual(prot.aatype.max(), residue_constants.restype_num)
self.assertLen(np.unique(prot.chain_index), num_chains)
def test_to_pdb(self):
with open(
os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
'2rbg.pdb')) as f:
pdb_string = f.read()
prot = protein.from_pdb_string(pdb_string)
pdb_string_reconstr = protein.to_pdb(prot)
for line in pdb_string_reconstr.splitlines():
self.assertLen(line, 80)
prot_reconstr = protein.from_pdb_string(pdb_string_reconstr)
np.testing.assert_array_equal(prot_reconstr.aatype, prot.aatype)
np.testing.assert_array_almost_equal(
prot_reconstr.atom_positions, prot.atom_positions)
np.testing.assert_array_almost_equal(
prot_reconstr.atom_mask, prot.atom_mask)
np.testing.assert_array_equal(
prot_reconstr.residue_index, prot.residue_index)
np.testing.assert_array_equal(
prot_reconstr.chain_index, prot.chain_index)
np.testing.assert_array_almost_equal(
prot_reconstr.b_factors, prot.b_factors)
def test_ideal_atom_mask(self):
with open(
os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
'2rbg.pdb')) as f:
pdb_string = f.read()
prot = protein.from_pdb_string(pdb_string)
ideal_mask = protein.ideal_atom_mask(prot)
non_ideal_residues = set([102] + list(range(127, 286)))
for i, (res, atom_mask) in enumerate(
zip(prot.residue_index, prot.atom_mask)):
if res in non_ideal_residues:
self.assertFalse(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
else:
self.assertTrue(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
def test_too_many_chains(self):
num_res = protein.PDB_MAX_CHAINS + 1
num_atom_type = residue_constants.atom_type_num
with self.assertRaises(ValueError):
_ = protein.Protein(
atom_positions=np.random.random([num_res, num_atom_type, 3]),
aatype=np.random.randint(0, 21, [num_res]),
atom_mask=np.random.randint(0, 2, [num_res]).astype(np.float32),
residue_index=np.arange(1, num_res+1),
chain_index=np.arange(num_res),
b_factors=np.random.uniform(1, 100, [num_res]))
if __name__ == '__main__':
absltest.main()
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants used in AlphaFold."""
import collections
import functools
import os
from typing import List, Mapping, Tuple
import numpy as np
import tree
# Internal import (35fd).
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096
# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
chi_angles_atoms = {
'ALA': [],
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']],
'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
'CYS': [['N', 'CA', 'CB', 'SG']],
'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'OE1']],
'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'OE1']],
'GLY': [],
'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']],
'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']],
'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']],
'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'],
['CB', 'CG', 'SD', 'CE']],
'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']],
'SER': [['N', 'CA', 'CB', 'OG']],
'THR': [['N', 'CA', 'CB', 'OG1']],
'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'VAL': [['N', 'CA', 'CB', 'CG1']],
}
# If chi angles given in fixed-length array, this matrix determines how to mask
# them for each AA type. The order is as per restype_order (see below).
chi_angles_mask = [
[0.0, 0.0, 0.0, 0.0], # ALA
[1.0, 1.0, 1.0, 1.0], # ARG
[1.0, 1.0, 0.0, 0.0], # ASN
[1.0, 1.0, 0.0, 0.0], # ASP
[1.0, 0.0, 0.0, 0.0], # CYS
[1.0, 1.0, 1.0, 0.0], # GLN
[1.0, 1.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[1.0, 1.0, 0.0, 0.0], # HIS
[1.0, 1.0, 0.0, 0.0], # ILE
[1.0, 1.0, 0.0, 0.0], # LEU
[1.0, 1.0, 1.0, 1.0], # LYS
[1.0, 1.0, 1.0, 0.0], # MET
[1.0, 1.0, 0.0, 0.0], # PHE
[1.0, 1.0, 0.0, 0.0], # PRO
[1.0, 0.0, 0.0, 0.0], # SER
[1.0, 0.0, 0.0, 0.0], # THR
[1.0, 1.0, 0.0, 0.0], # TRP
[1.0, 1.0, 0.0, 0.0], # TYR
[1.0, 0.0, 0.0, 0.0], # VAL
]
# The following chi angles are pi periodic: they can be rotated by a multiple
# of pi without affecting the structure.
chi_pi_periodic = [
[0.0, 0.0, 0.0, 0.0], # ALA
[0.0, 0.0, 0.0, 0.0], # ARG
[0.0, 0.0, 0.0, 0.0], # ASN
[0.0, 1.0, 0.0, 0.0], # ASP
[0.0, 0.0, 0.0, 0.0], # CYS
[0.0, 0.0, 0.0, 0.0], # GLN
[0.0, 0.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[0.0, 0.0, 0.0, 0.0], # HIS
[0.0, 0.0, 0.0, 0.0], # ILE
[0.0, 0.0, 0.0, 0.0], # LEU
[0.0, 0.0, 0.0, 0.0], # LYS
[0.0, 0.0, 0.0, 0.0], # MET
[0.0, 1.0, 0.0, 0.0], # PHE
[0.0, 0.0, 0.0, 0.0], # PRO
[0.0, 0.0, 0.0, 0.0], # SER
[0.0, 0.0, 0.0, 0.0], # THR
[0.0, 0.0, 0.0, 0.0], # TRP
[0.0, 1.0, 0.0, 0.0], # TYR
[0.0, 0.0, 0.0, 0.0], # VAL
[0.0, 0.0, 0.0, 0.0], # UNK
]
# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
# psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
# is defined such that the dihedral-angle-definiting atom (the last entry in
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions = {
'ALA': [
['N', 0, (-0.525, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, -0.000, -0.000)],
['CB', 0, (-0.529, -0.774, -1.205)],
['O', 3, (0.627, 1.062, 0.000)],
],
'ARG': [
['N', 0, (-0.524, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, -0.000)],
['CB', 0, (-0.524, -0.778, -1.209)],
['O', 3, (0.626, 1.062, 0.000)],
['CG', 4, (0.616, 1.390, -0.000)],
['CD', 5, (0.564, 1.414, 0.000)],
['NE', 6, (0.539, 1.357, -0.000)],
['NH1', 7, (0.206, 2.301, 0.000)],
['NH2', 7, (2.078, 0.978, -0.000)],
['CZ', 7, (0.758, 1.093, -0.000)],
],
'ASN': [
['N', 0, (-0.536, 1.357, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, -0.000, -0.000)],
['CB', 0, (-0.531, -0.787, -1.200)],
['O', 3, (0.625, 1.062, 0.000)],
['CG', 4, (0.584, 1.399, 0.000)],
['ND2', 5, (0.593, -1.188, 0.001)],
['OD1', 5, (0.633, 1.059, 0.000)],
],
'ASP': [
['N', 0, (-0.525, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, 0.000, -0.000)],
['CB', 0, (-0.526, -0.778, -1.208)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.593, 1.398, -0.000)],
['OD1', 5, (0.610, 1.091, 0.000)],
['OD2', 5, (0.592, -1.101, -0.003)],
],
'CYS': [
['N', 0, (-0.522, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.524, 0.000, 0.000)],
['CB', 0, (-0.519, -0.773, -1.212)],
['O', 3, (0.625, 1.062, -0.000)],
['SG', 4, (0.728, 1.653, 0.000)],
],
'GLN': [
['N', 0, (-0.526, 1.361, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, 0.000, 0.000)],
['CB', 0, (-0.525, -0.779, -1.207)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.615, 1.393, 0.000)],
['CD', 5, (0.587, 1.399, -0.000)],
['NE2', 6, (0.593, -1.189, -0.001)],
['OE1', 6, (0.634, 1.060, 0.000)],
],
'GLU': [
['N', 0, (-0.528, 1.361, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, -0.000, -0.000)],
['CB', 0, (-0.526, -0.781, -1.207)],
['O', 3, (0.626, 1.062, 0.000)],
['CG', 4, (0.615, 1.392, 0.000)],
['CD', 5, (0.600, 1.397, 0.000)],
['OE1', 6, (0.607, 1.095, -0.000)],
['OE2', 6, (0.589, -1.104, -0.001)],
],
'GLY': [
['N', 0, (-0.572, 1.337, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.517, -0.000, -0.000)],
['O', 3, (0.626, 1.062, -0.000)],
],
'HIS': [
['N', 0, (-0.527, 1.360, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, 0.000, 0.000)],
['CB', 0, (-0.525, -0.778, -1.208)],
['O', 3, (0.625, 1.063, 0.000)],
['CG', 4, (0.600, 1.370, -0.000)],
['CD2', 5, (0.889, -1.021, 0.003)],
['ND1', 5, (0.744, 1.160, -0.000)],
['CE1', 5, (2.030, 0.851, 0.002)],
['NE2', 5, (2.145, -0.466, 0.004)],
],
'ILE': [
['N', 0, (-0.493, 1.373, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, -0.000, -0.000)],
['CB', 0, (-0.536, -0.793, -1.213)],
['O', 3, (0.627, 1.062, -0.000)],
['CG1', 4, (0.534, 1.437, -0.000)],
['CG2', 4, (0.540, -0.785, -1.199)],
['CD1', 5, (0.619, 1.391, 0.000)],
],
'LEU': [
['N', 0, (-0.520, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, -0.000)],
['CB', 0, (-0.522, -0.773, -1.214)],
['O', 3, (0.625, 1.063, -0.000)],
['CG', 4, (0.678, 1.371, 0.000)],
['CD1', 5, (0.530, 1.430, -0.000)],
['CD2', 5, (0.535, -0.774, 1.200)],
],
'LYS': [
['N', 0, (-0.526, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, 0.000, 0.000)],
['CB', 0, (-0.524, -0.778, -1.208)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.619, 1.390, 0.000)],
['CD', 5, (0.559, 1.417, 0.000)],
['CE', 6, (0.560, 1.416, 0.000)],
['NZ', 7, (0.554, 1.387, 0.000)],
],
'MET': [
['N', 0, (-0.521, 1.364, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, 0.000, 0.000)],
['CB', 0, (-0.523, -0.776, -1.210)],
['O', 3, (0.625, 1.062, -0.000)],
['CG', 4, (0.613, 1.391, -0.000)],
['SD', 5, (0.703, 1.695, 0.000)],
['CE', 6, (0.320, 1.786, -0.000)],
],
'PHE': [
['N', 0, (-0.518, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.524, 0.000, -0.000)],
['CB', 0, (-0.525, -0.776, -1.212)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.607, 1.377, 0.000)],
['CD1', 5, (0.709, 1.195, -0.000)],
['CD2', 5, (0.706, -1.196, 0.000)],
['CE1', 5, (2.102, 1.198, -0.000)],
['CE2', 5, (2.098, -1.201, -0.000)],
['CZ', 5, (2.794, -0.003, -0.001)],
],
'PRO': [
['N', 0, (-0.566, 1.351, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, -0.000, 0.000)],
['CB', 0, (-0.546, -0.611, -1.293)],
['O', 3, (0.621, 1.066, 0.000)],
['CG', 4, (0.382, 1.445, 0.0)],
# ['CD', 5, (0.427, 1.440, 0.0)],
['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
],
'SER': [
['N', 0, (-0.529, 1.360, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, -0.000)],
['CB', 0, (-0.518, -0.777, -1.211)],
['O', 3, (0.626, 1.062, -0.000)],
['OG', 4, (0.503, 1.325, 0.000)],
],
'THR': [
['N', 0, (-0.517, 1.364, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, 0.000, -0.000)],
['CB', 0, (-0.516, -0.793, -1.215)],
['O', 3, (0.626, 1.062, 0.000)],
['CG2', 4, (0.550, -0.718, -1.228)],
['OG1', 4, (0.472, 1.353, 0.000)],
],
'TRP': [
['N', 0, (-0.521, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, 0.000)],
['CB', 0, (-0.523, -0.776, -1.212)],
['O', 3, (0.627, 1.062, 0.000)],
['CG', 4, (0.609, 1.370, -0.000)],
['CD1', 5, (0.824, 1.091, 0.000)],
['CD2', 5, (0.854, -1.148, -0.005)],
['CE2', 5, (2.186, -0.678, -0.007)],
['CE3', 5, (0.622, -2.530, -0.007)],
['NE1', 5, (2.140, 0.690, -0.004)],
['CH2', 5, (3.028, -2.890, -0.013)],
['CZ2', 5, (3.283, -1.543, -0.011)],
['CZ3', 5, (1.715, -3.389, -0.011)],
],
'TYR': [
['N', 0, (-0.522, 1.362, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.524, -0.000, -0.000)],
['CB', 0, (-0.522, -0.776, -1.213)],
['O', 3, (0.627, 1.062, -0.000)],
['CG', 4, (0.607, 1.382, -0.000)],
['CD1', 5, (0.716, 1.195, -0.000)],
['CD2', 5, (0.713, -1.194, -0.001)],
['CE1', 5, (2.107, 1.200, -0.002)],
['CE2', 5, (2.104, -1.201, -0.003)],
['OH', 5, (4.168, -0.002, -0.005)],
['CZ', 5, (2.791, -0.001, -0.003)],
],
'VAL': [
['N', 0, (-0.494, 1.373, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, -0.000, -0.000)],
['CB', 0, (-0.533, -0.795, -1.213)],
['O', 3, (0.627, 1.062, -0.000)],
['CG1', 4, (0.540, 1.429, -0.000)],
['CG2', 4, (0.533, -0.776, 1.203)],
],
}
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
residue_atoms = {
'ALA': ['C', 'CA', 'CB', 'N', 'O'],
'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'],
'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'],
'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'],
'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'],
'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'],
'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'],
'GLY': ['C', 'CA', 'N', 'O'],
'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'],
'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'],
'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'],
'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'],
'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'],
'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'],
'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'],
'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'],
'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'],
'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3',
'CH2', 'N', 'NE1', 'O'],
'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O',
'OH'],
'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O']
}
# Naming swaps for ambiguous atom names.
# Due to symmetries in the amino acids the naming of atoms is ambiguous in
# 4 of the 20 amino acids.
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
residue_atom_renaming_swaps = {
'ASP': {'OD1': 'OD2'},
'GLU': {'OE1': 'OE2'},
'PHE': {'CD1': 'CD2', 'CE1': 'CE2'},
'TYR': {'CD1': 'CD2', 'CE1': 'CE2'},
}
# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius = {
'C': 1.7,
'N': 1.55,
'O': 1.52,
'S': 1.8,
}
Bond = collections.namedtuple(
'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev'])
BondAngle = collections.namedtuple(
'BondAngle',
['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev'])
@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]],
Mapping[str, List[Bond]],
Mapping[str, List[BondAngle]]]:
"""Load stereo_chemical_props.txt into a nice structure.
Load literature values for bond lengths and bond angles and translate
bond angles into the length of the opposite edge of the triangle
("residue_virtual_bonds").
Returns:
residue_bonds: Dict that maps resname -> list of Bond tuples.
residue_virtual_bonds: Dict that maps resname -> list of Bond tuples.
residue_bond_angles: Dict that maps resname -> list of BondAngle tuples.
"""
stereo_chemical_props_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt'
)
with open(stereo_chemical_props_path, 'rt') as f:
stereo_chemical_props = f.read()
lines_iter = iter(stereo_chemical_props.splitlines())
# Load bond lengths.
residue_bonds = {}
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == '-':
break
bond, resname, length, stddev = line.split()
atom1, atom2 = bond.split('-')
if resname not in residue_bonds:
residue_bonds[resname] = []
residue_bonds[resname].append(
Bond(atom1, atom2, float(length), float(stddev)))
residue_bonds['UNK'] = []
# Load bond angles.
residue_bond_angles = {}
next(lines_iter) # Skip empty line.
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == '-':
break
bond, resname, angle_degree, stddev_degree = line.split()
atom1, atom2, atom3 = bond.split('-')
if resname not in residue_bond_angles:
residue_bond_angles[resname] = []
residue_bond_angles[resname].append(
BondAngle(atom1, atom2, atom3,
float(angle_degree) / 180. * np.pi,
float(stddev_degree) / 180. * np.pi))
residue_bond_angles['UNK'] = []
def make_bond_key(atom1_name, atom2_name):
"""Unique key to lookup bonds."""
return '-'.join(sorted([atom1_name, atom2_name]))
# Translate bond angles into distances ("virtual bonds").
residue_virtual_bonds = {}
for resname, bond_angles in residue_bond_angles.items():
# Create a fast lookup dict for bond lengths.
bond_cache = {}
for b in residue_bonds[resname]:
bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
residue_virtual_bonds[resname] = []
for ba in bond_angles:
bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
# Compute distance between atom1 and atom3 using the law of cosines
# c^2 = a^2 + b^2 - 2ab*cos(gamma).
gamma = ba.angle_rad
length = np.sqrt(bond1.length**2 + bond2.length**2
- 2 * bond1.length * bond2.length * np.cos(gamma))
# Propagation of uncertainty assuming uncorrelated errors.
dl_outer = 0.5 / length
dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
stddev = np.sqrt((dl_dgamma * ba.stddev)**2 +
(dl_db1 * bond1.stddev)**2 +
(dl_db2 * bond2.stddev)**2)
residue_virtual_bonds[resname].append(
Bond(ba.atom1_name, ba.atom3name, length, stddev))
return (residue_bonds,
residue_virtual_bonds,
residue_bond_angles)
# Between-residue bond lengths for general bonds (first element) and for Proline
# (second element).
between_res_bond_length_c_n = [1.329, 1.341]
between_res_bond_length_stddev_c_n = [0.014, 0.016]
# Between-residue cos_angles.
between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
# This mapping is used when we need to store atom data in a format that requires
# fixed atom data size for every residue (e.g. a numpy array).
atom_types = [
'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
'CZ3', 'NZ', 'OXT'
]
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types) # := 37.
# A compact atom encoding with 14 columns
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
restype_name_to_atom14_names = {
'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''],
'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''],
'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''],
'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''],
'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''],
'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''],
'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''],
'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''],
'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''],
'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''],
'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''],
'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''],
'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''],
'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''],
'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''],
'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''],
'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''],
'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''],
'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''],
'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''],
}
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace
# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = [
'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
'S', 'T', 'W', 'Y', 'V'
]
restype_order = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes) # := 20.
unk_restype_index = restype_num # Catch-all index for unknown restypes.
restypes_with_x = restypes + ['X']
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
def sequence_to_onehot(
sequence: str,
mapping: Mapping[str, int],
map_unknown_to_x: bool = False) -> np.ndarray:
"""Maps the given sequence into a one-hot encoded matrix.
Args:
sequence: An amino acid sequence.
mapping: A dictionary mapping amino acids to integers.
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
mapped to the unknown amino acid 'X'. If the mapping doesn't contain
amino acid 'X', an error will be thrown. If False, any amino acid not in
the mapping will throw an error.
Returns:
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
the sequence.
Raises:
ValueError: If the mapping doesn't contain values from 0 to
num_unique_aas - 1 without any gaps.
"""
num_entries = max(mapping.values()) + 1
if sorted(set(mapping.values())) != list(range(num_entries)):
raise ValueError('The mapping must have values from 0 to num_unique_aas-1 '
'without any gaps. Got: %s' % sorted(mapping.values()))
one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)
for aa_index, aa_type in enumerate(sequence):
if map_unknown_to_x:
if aa_type.isalpha() and aa_type.isupper():
aa_id = mapping.get(aa_type, mapping['X'])
else:
raise ValueError(f'Invalid character in the sequence: {aa_type}')
else:
aa_id = mapping[aa_type]
one_hot_arr[aa_index, aa_id] = 1
return one_hot_arr
restype_1to3 = {
'A': 'ALA',
'R': 'ARG',
'N': 'ASN',
'D': 'ASP',
'C': 'CYS',
'Q': 'GLN',
'E': 'GLU',
'G': 'GLY',
'H': 'HIS',
'I': 'ILE',
'L': 'LEU',
'K': 'LYS',
'M': 'MET',
'F': 'PHE',
'P': 'PRO',
'S': 'SER',
'T': 'THR',
'W': 'TRP',
'Y': 'TYR',
'V': 'VAL',
}
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1 = {v: k for k, v in restype_1to3.items()}
# Define a restype name for all unknown residues.
unk_restype = 'UNK'
resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
# The mapping here uses hhblits convention, so that B is mapped to D, J and O
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
# remaining 20 amino acids are kept in alphabetical order.
# There are 2 non-amino acid codes, X (representing any amino acid) and
# "-" representing a missing amino acid in an alignment. The id for these
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
HHBLITS_AA_TO_ID = {
'A': 0,
'B': 2,
'C': 1,
'D': 2,
'E': 3,
'F': 4,
'G': 5,
'H': 6,
'I': 7,
'J': 20,
'K': 8,
'L': 9,
'M': 10,
'N': 11,
'O': 20,
'P': 12,
'Q': 13,
'R': 14,
'S': 15,
'T': 16,
'U': 1,
'V': 17,
'W': 18,
'X': 20,
'Y': 19,
'Z': 3,
'-': 21,
}
# Partial inversion of HHBLITS_AA_TO_ID.
ID_TO_HHBLITS_AA = {
0: 'A',
1: 'C', # Also U.
2: 'D', # Also B.
3: 'E', # Also Z.
4: 'F',
5: 'G',
6: 'H',
7: 'I',
8: 'K',
9: 'L',
10: 'M',
11: 'N',
12: 'P',
13: 'Q',
14: 'R',
15: 'S',
16: 'T',
17: 'V',
18: 'W',
19: 'Y',
20: 'X', # Includes J and O.
21: '-',
}
restypes_with_x_and_gap = restypes + ['X', '-']
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
for i in range(len(restypes_with_x_and_gap)))
def _make_standard_atom_mask() -> np.ndarray:
"""Returns [num_res_types, num_atom_types] mask array."""
# +1 to account for unknown (all 0s).
mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
for restype, restype_letter in enumerate(restypes):
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = atom_order[atom_name]
mask[restype, atom_type] = 1
return mask
STANDARD_ATOM_MASK = _make_standard_atom_mask()
# A one hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
def chi_angle_atom(atom_index: int) -> np.ndarray:
"""Define chi-angle rigid groups via one-hot representations."""
chi_angles_index = {}
one_hots = []
for k, v in chi_angles_atoms.items():
indices = [atom_types.index(s[atom_index]) for s in v]
indices.extend([-1]*(4-len(indices)))
chi_angles_index[k] = indices
for r in restypes:
res3 = restype_1to3[r]
one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
one_hots.append(one_hot)
one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
one_hot = np.stack(one_hots, axis=0)
one_hot = np.transpose(one_hot, [0, 2, 1])
return one_hot
chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)
# An array like chi_angles_atoms but using indices rather than names.
chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
chi_angles_atom_indices = tree.map_structure(
lambda atom_name: atom_order[atom_name], chi_angles_atom_indices)
chi_angles_atom_indices = np.array([
chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
for chi_atoms in chi_angles_atom_indices])
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group.
chi_groups_for_atom = collections.defaultdict(list)
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
for atom_i, atom in enumerate(chi_group):
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
chi_groups_for_atom = dict(chi_groups_for_atom)
def _make_rigid_transformation_4x4(ex, ey, translation):
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
# Normalize ex.
ex_normalized = ex / np.linalg.norm(ex)
# make ey perpendicular to ex
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
ey_normalized /= np.linalg.norm(ey_normalized)
# compute ez as cross product
eznorm = np.cross(ex_normalized, ey_normalized)
m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0)
return m
# create an array with (restype, atomtype) --> rigid_group_idx
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
def _make_rigid_group_constants():
"""Fill the arrays above."""
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
for atomname, group_idx, atom_position in rigid_group_atom_positions[
resname]:
atomtype = atom_order[atomname]
restype_atom37_to_rigid_group[restype, atomtype] = group_idx
restype_atom37_mask[restype, atomtype] = 1
restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position
atom14idx = restype_name_to_atom14_names[resname].index(atomname)
restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
restype_atom14_mask[restype, atom14idx] = 1
restype_atom14_rigid_group_positions[restype,
atom14idx, :] = atom_position
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_positions = {name: np.array(pos) for name, _, pos
in rigid_group_atom_positions[resname]}
# backbone to backbone is the identity transform
restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
# pre-omega-frame to backbone (currently dummy identity matrix)
restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
# phi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions['N'] - atom_positions['CA'],
ey=np.array([1., 0., 0.]),
translation=atom_positions['N'])
restype_rigid_group_default_frame[restype, 2, :, :] = mat
# psi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions['C'] - atom_positions['CA'],
ey=atom_positions['CA'] - atom_positions['N'],
translation=atom_positions['C'])
restype_rigid_group_default_frame[restype, 3, :, :] = mat
# chi1-frame to backbone
if chi_angles_mask[restype][0]:
base_atom_names = chi_angles_atoms[resname][0]
base_atom_positions = [atom_positions[name] for name in base_atom_names]
mat = _make_rigid_transformation_4x4(
ex=base_atom_positions[2] - base_atom_positions[1],
ey=base_atom_positions[0] - base_atom_positions[1],
translation=base_atom_positions[2])
restype_rigid_group_default_frame[restype, 4, :, :] = mat
# chi2-frame to chi1-frame
# chi3-frame to chi2-frame
# chi4-frame to chi3-frame
# luckily all rotation axes for the next frame start at (0,0,0) of the
# previous frame
for chi_idx in range(1, 4):
if chi_angles_mask[restype][chi_idx]:
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
axis_end_atom_position = atom_positions[axis_end_atom_name]
mat = _make_rigid_transformation_4x4(
ex=axis_end_atom_position,
ey=np.array([-1., 0., 0.]),
translation=axis_end_atom_position)
restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
_make_rigid_group_constants()
def make_atom14_dists_bounds(overlap_tolerance=1.5,
bond_length_tolerance_factor=15):
"""compute upper and lower bounds for bonds to assess violations."""
restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_list = restype_name_to_atom14_names[resname]
# create lower and upper bounds for clashes
for atom1_idx, atom1_name in enumerate(atom_list):
if not atom1_name:
continue
atom1_radius = van_der_waals_radius[atom1_name[0]]
for atom2_idx, atom2_name in enumerate(atom_list):
if (not atom2_name) or atom1_idx == atom2_idx:
continue
atom2_radius = van_der_waals_radius[atom2_name[0]]
lower = atom1_radius + atom2_radius - overlap_tolerance
upper = 1e10
restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
# overwrite lower and upper bounds for bonds and angles
for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
atom1_idx = atom_list.index(b.atom1_name)
atom2_idx = atom_list.index(b.atom2_name)
lower = b.length - bond_length_tolerance_factor * b.stddev
upper = b.length + bond_length_tolerance_factor * b.stddev
restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
return {'lower_bound': restype_atom14_bond_lower_bound, # shape (21,14,14)
'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14)
'stddev': restype_atom14_bond_stddev, # shape (21,14,14)
}
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test that residue_constants generates correct values."""
from absl.testing import absltest
from absl.testing import parameterized
from alphafold.common import residue_constants
import numpy as np
class ResidueConstantsTest(parameterized.TestCase):
@parameterized.parameters(
('ALA', 0),
('CYS', 1),
('HIS', 2),
('MET', 3),
('LYS', 4),
('ARG', 4),
)
def testChiAnglesAtoms(self, residue_name, chi_num):
chi_angles_atoms = residue_constants.chi_angles_atoms[residue_name]
self.assertLen(chi_angles_atoms, chi_num)
for chi_angle_atoms in chi_angles_atoms:
self.assertLen(chi_angle_atoms, 4)
def testChiGroupsForAtom(self):
for k, chi_groups in residue_constants.chi_groups_for_atom.items():
res_name, atom_name = k
for chi_group_i, atom_i in chi_groups:
self.assertEqual(
atom_name,
residue_constants.chi_angles_atoms[res_name][chi_group_i][atom_i])
@parameterized.parameters(
('ALA', 5), ('ARG', 11), ('ASN', 8), ('ASP', 8), ('CYS', 6), ('GLN', 9),
('GLU', 9), ('GLY', 4), ('HIS', 10), ('ILE', 8), ('LEU', 8), ('LYS', 9),
('MET', 8), ('PHE', 11), ('PRO', 7), ('SER', 6), ('THR', 7), ('TRP', 14),
('TYR', 12), ('VAL', 7)
)
def testResidueAtoms(self, atom_name, num_residue_atoms):
residue_atoms = residue_constants.residue_atoms[atom_name]
self.assertLen(residue_atoms, num_residue_atoms)
def testStandardAtomMask(self):
with self.subTest('Check shape'):
self.assertEqual(residue_constants.STANDARD_ATOM_MASK.shape, (21, 37,))
with self.subTest('Check values'):
str_to_row = lambda s: [c == '1' for c in s] # More clear/concise.
np.testing.assert_array_equal(
residue_constants.STANDARD_ATOM_MASK,
np.array([
# NB This was defined by c+p but looks sane.
str_to_row('11111 '), # ALA
str_to_row('111111 1 1 11 1 '), # ARG
str_to_row('111111 11 '), # ASP
str_to_row('111111 11 '), # ASN
str_to_row('11111 1 '), # CYS
str_to_row('111111 1 11 '), # GLU
str_to_row('111111 1 11 '), # GLN
str_to_row('111 1 '), # GLY
str_to_row('111111 11 1 1 '), # HIS
str_to_row('11111 11 1 '), # ILE
str_to_row('111111 11 '), # LEU
str_to_row('111111 1 1 1 '), # LYS
str_to_row('111111 11 '), # MET
str_to_row('111111 11 11 1 '), # PHE
str_to_row('111111 1 '), # PRO
str_to_row('11111 1 '), # SER
str_to_row('11111 1 1 '), # THR
str_to_row('111111 11 11 1 1 11 '), # TRP
str_to_row('111111 11 11 11 '), # TYR
str_to_row('11111 11 '), # VAL
str_to_row(' '), # UNK
]))
with self.subTest('Check row totals'):
# Check each row has the right number of atoms.
for row, restype in enumerate(residue_constants.restypes): # A, R, ...
long_restype = residue_constants.restype_1to3[restype] # ALA, ARG, ...
atoms_names = residue_constants.residue_atoms[
long_restype] # ['C', 'CA', 'CB', 'N', 'O'], ...
self.assertLen(atoms_names,
residue_constants.STANDARD_ATOM_MASK[row, :].sum(),
long_restype)
def testAtomTypes(self):
self.assertEqual(residue_constants.atom_type_num, 37)
self.assertEqual(residue_constants.atom_types[0], 'N')
self.assertEqual(residue_constants.atom_types[1], 'CA')
self.assertEqual(residue_constants.atom_types[2], 'C')
self.assertEqual(residue_constants.atom_types[3], 'CB')
self.assertEqual(residue_constants.atom_types[4], 'O')
self.assertEqual(residue_constants.atom_order['N'], 0)
self.assertEqual(residue_constants.atom_order['CA'], 1)
self.assertEqual(residue_constants.atom_order['C'], 2)
self.assertEqual(residue_constants.atom_order['CB'], 3)
self.assertEqual(residue_constants.atom_order['O'], 4)
self.assertEqual(residue_constants.atom_type_num, 37)
def testRestypes(self):
three_letter_restypes = [
residue_constants.restype_1to3[r] for r in residue_constants.restypes]
for restype, exp_restype in zip(
three_letter_restypes, sorted(residue_constants.restype_1to3.values())):
self.assertEqual(restype, exp_restype)
self.assertEqual(residue_constants.restype_num, 20)
def testSequenceToOneHotHHBlits(self):
one_hot = residue_constants.sequence_to_onehot(
'ABCDEFGHIJKLMNOPQRSTUVWXYZ-', residue_constants.HHBLITS_AA_TO_ID)
exp_one_hot = np.array(
[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])
np.testing.assert_array_equal(one_hot, exp_one_hot)
def testSequenceToOneHotStandard(self):
one_hot = residue_constants.sequence_to_onehot(
'ARNDCQEGHILKMFPSTWYV', residue_constants.restype_order)
np.testing.assert_array_equal(one_hot, np.eye(20))
def testSequenceToOneHotUnknownMapping(self):
seq = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
expected_out = np.zeros([26, 21])
for row, position in enumerate(
[0, 20, 4, 3, 6, 13, 7, 8, 9, 20, 11, 10, 12, 2, 20, 14, 5, 1, 15, 16,
20, 19, 17, 20, 18, 20]):
expected_out[row, position] = 1
aa_types = residue_constants.sequence_to_onehot(
sequence=seq,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True)
self.assertTrue((aa_types == expected_out).all())
@parameterized.named_parameters(
('lowercase', 'aaa'), # Insertions in A3M.
('gaps', '---'), # Gaps in A3M.
('dots', '...'), # Gaps in A3M.
('metadata', '>TEST'), # FASTA metadata line.
)
def testSequenceToOneHotUnknownMappingError(self, seq):
with self.assertRaises(ValueError):
residue_constants.sequence_to_onehot(
sequence=seq,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True)
if __name__ == '__main__':
absltest.main()
Bond Residue Mean StdDev
CA-CB ALA 1.520 0.021
N-CA ALA 1.459 0.020
CA-C ALA 1.525 0.026
C-O ALA 1.229 0.019
CA-CB ARG 1.535 0.022
CB-CG ARG 1.521 0.027
CG-CD ARG 1.515 0.025
CD-NE ARG 1.460 0.017
NE-CZ ARG 1.326 0.013
CZ-NH1 ARG 1.326 0.013
CZ-NH2 ARG 1.326 0.013
N-CA ARG 1.459 0.020
CA-C ARG 1.525 0.026
C-O ARG 1.229 0.019
CA-CB ASN 1.527 0.026
CB-CG ASN 1.506 0.023
CG-OD1 ASN 1.235 0.022
CG-ND2 ASN 1.324 0.025
N-CA ASN 1.459 0.020
CA-C ASN 1.525 0.026
C-O ASN 1.229 0.019
CA-CB ASP 1.535 0.022
CB-CG ASP 1.513 0.021
CG-OD1 ASP 1.249 0.023
CG-OD2 ASP 1.249 0.023
N-CA ASP 1.459 0.020
CA-C ASP 1.525 0.026
C-O ASP 1.229 0.019
CA-CB CYS 1.526 0.013
CB-SG CYS 1.812 0.016
N-CA CYS 1.459 0.020
CA-C CYS 1.525 0.026
C-O CYS 1.229 0.019
CA-CB GLU 1.535 0.022
CB-CG GLU 1.517 0.019
CG-CD GLU 1.515 0.015
CD-OE1 GLU 1.252 0.011
CD-OE2 GLU 1.252 0.011
N-CA GLU 1.459 0.020
CA-C GLU 1.525 0.026
C-O GLU 1.229 0.019
CA-CB GLN 1.535 0.022
CB-CG GLN 1.521 0.027
CG-CD GLN 1.506 0.023
CD-OE1 GLN 1.235 0.022
CD-NE2 GLN 1.324 0.025
N-CA GLN 1.459 0.020
CA-C GLN 1.525 0.026
C-O GLN 1.229 0.019
N-CA GLY 1.456 0.015
CA-C GLY 1.514 0.016
C-O GLY 1.232 0.016
CA-CB HIS 1.535 0.022
CB-CG HIS 1.492 0.016
CG-ND1 HIS 1.369 0.015
CG-CD2 HIS 1.353 0.017
ND1-CE1 HIS 1.343 0.025
CD2-NE2 HIS 1.415 0.021
CE1-NE2 HIS 1.322 0.023
N-CA HIS 1.459 0.020
CA-C HIS 1.525 0.026
C-O HIS 1.229 0.019
CA-CB ILE 1.544 0.023
CB-CG1 ILE 1.536 0.028
CB-CG2 ILE 1.524 0.031
CG1-CD1 ILE 1.500 0.069
N-CA ILE 1.459 0.020
CA-C ILE 1.525 0.026
C-O ILE 1.229 0.019
CA-CB LEU 1.533 0.023
CB-CG LEU 1.521 0.029
CG-CD1 LEU 1.514 0.037
CG-CD2 LEU 1.514 0.037
N-CA LEU 1.459 0.020
CA-C LEU 1.525 0.026
C-O LEU 1.229 0.019
CA-CB LYS 1.535 0.022
CB-CG LYS 1.521 0.027
CG-CD LYS 1.520 0.034
CD-CE LYS 1.508 0.025
CE-NZ LYS 1.486 0.025
N-CA LYS 1.459 0.020
CA-C LYS 1.525 0.026
C-O LYS 1.229 0.019
CA-CB MET 1.535 0.022
CB-CG MET 1.509 0.032
CG-SD MET 1.807 0.026
SD-CE MET 1.774 0.056
N-CA MET 1.459 0.020
CA-C MET 1.525 0.026
C-O MET 1.229 0.019
CA-CB PHE 1.535 0.022
CB-CG PHE 1.509 0.017
CG-CD1 PHE 1.383 0.015
CG-CD2 PHE 1.383 0.015
CD1-CE1 PHE 1.388 0.020
CD2-CE2 PHE 1.388 0.020
CE1-CZ PHE 1.369 0.019
CE2-CZ PHE 1.369 0.019
N-CA PHE 1.459 0.020
CA-C PHE 1.525 0.026
C-O PHE 1.229 0.019
CA-CB PRO 1.531 0.020
CB-CG PRO 1.495 0.050
CG-CD PRO 1.502 0.033
CD-N PRO 1.474 0.014
N-CA PRO 1.468 0.017
CA-C PRO 1.524 0.020
C-O PRO 1.228 0.020
CA-CB SER 1.525 0.015
CB-OG SER 1.418 0.013
N-CA SER 1.459 0.020
CA-C SER 1.525 0.026
C-O SER 1.229 0.019
CA-CB THR 1.529 0.026
CB-OG1 THR 1.428 0.020
CB-CG2 THR 1.519 0.033
N-CA THR 1.459 0.020
CA-C THR 1.525 0.026
C-O THR 1.229 0.019
CA-CB TRP 1.535 0.022
CB-CG TRP 1.498 0.018
CG-CD1 TRP 1.363 0.014
CG-CD2 TRP 1.432 0.017
CD1-NE1 TRP 1.375 0.017
NE1-CE2 TRP 1.371 0.013
CD2-CE2 TRP 1.409 0.012
CD2-CE3 TRP 1.399 0.015
CE2-CZ2 TRP 1.393 0.017
CE3-CZ3 TRP 1.380 0.017
CZ2-CH2 TRP 1.369 0.019
CZ3-CH2 TRP 1.396 0.016
N-CA TRP 1.459 0.020
CA-C TRP 1.525 0.026
C-O TRP 1.229 0.019
CA-CB TYR 1.535 0.022
CB-CG TYR 1.512 0.015
CG-CD1 TYR 1.387 0.013
CG-CD2 TYR 1.387 0.013
CD1-CE1 TYR 1.389 0.015
CD2-CE2 TYR 1.389 0.015
CE1-CZ TYR 1.381 0.013
CE2-CZ TYR 1.381 0.013
CZ-OH TYR 1.374 0.017
N-CA TYR 1.459 0.020
CA-C TYR 1.525 0.026
C-O TYR 1.229 0.019
CA-CB VAL 1.543 0.021
CB-CG1 VAL 1.524 0.021
CB-CG2 VAL 1.524 0.021
N-CA VAL 1.459 0.020
CA-C VAL 1.525 0.026
C-O VAL 1.229 0.019
-
Angle Residue Mean StdDev
N-CA-CB ALA 110.1 1.4
CB-CA-C ALA 110.1 1.5
N-CA-C ALA 111.0 2.7
CA-C-O ALA 120.1 2.1
N-CA-CB ARG 110.6 1.8
CB-CA-C ARG 110.4 2.0
CA-CB-CG ARG 113.4 2.2
CB-CG-CD ARG 111.6 2.6
CG-CD-NE ARG 111.8 2.1
CD-NE-CZ ARG 123.6 1.4
NE-CZ-NH1 ARG 120.3 0.5
NE-CZ-NH2 ARG 120.3 0.5
NH1-CZ-NH2 ARG 119.4 1.1
N-CA-C ARG 111.0 2.7
CA-C-O ARG 120.1 2.1
N-CA-CB ASN 110.6 1.8
CB-CA-C ASN 110.4 2.0
CA-CB-CG ASN 113.4 2.2
CB-CG-ND2 ASN 116.7 2.4
CB-CG-OD1 ASN 121.6 2.0
ND2-CG-OD1 ASN 121.9 2.3
N-CA-C ASN 111.0 2.7
CA-C-O ASN 120.1 2.1
N-CA-CB ASP 110.6 1.8
CB-CA-C ASP 110.4 2.0
CA-CB-CG ASP 113.4 2.2
CB-CG-OD1 ASP 118.3 0.9
CB-CG-OD2 ASP 118.3 0.9
OD1-CG-OD2 ASP 123.3 1.9
N-CA-C ASP 111.0 2.7
CA-C-O ASP 120.1 2.1
N-CA-CB CYS 110.8 1.5
CB-CA-C CYS 111.5 1.2
CA-CB-SG CYS 114.2 1.1
N-CA-C CYS 111.0 2.7
CA-C-O CYS 120.1 2.1
N-CA-CB GLU 110.6 1.8
CB-CA-C GLU 110.4 2.0
CA-CB-CG GLU 113.4 2.2
CB-CG-CD GLU 114.2 2.7
CG-CD-OE1 GLU 118.3 2.0
CG-CD-OE2 GLU 118.3 2.0
OE1-CD-OE2 GLU 123.3 1.2
N-CA-C GLU 111.0 2.7
CA-C-O GLU 120.1 2.1
N-CA-CB GLN 110.6 1.8
CB-CA-C GLN 110.4 2.0
CA-CB-CG GLN 113.4 2.2
CB-CG-CD GLN 111.6 2.6
CG-CD-OE1 GLN 121.6 2.0
CG-CD-NE2 GLN 116.7 2.4
OE1-CD-NE2 GLN 121.9 2.3
N-CA-C GLN 111.0 2.7
CA-C-O GLN 120.1 2.1
N-CA-C GLY 113.1 2.5
CA-C-O GLY 120.6 1.8
N-CA-CB HIS 110.6 1.8
CB-CA-C HIS 110.4 2.0
CA-CB-CG HIS 113.6 1.7
CB-CG-ND1 HIS 123.2 2.5
CB-CG-CD2 HIS 130.8 3.1
CG-ND1-CE1 HIS 108.2 1.4
ND1-CE1-NE2 HIS 109.9 2.2
CE1-NE2-CD2 HIS 106.6 2.5
NE2-CD2-CG HIS 109.2 1.9
CD2-CG-ND1 HIS 106.0 1.4
N-CA-C HIS 111.0 2.7
CA-C-O HIS 120.1 2.1
N-CA-CB ILE 110.8 2.3
CB-CA-C ILE 111.6 2.0
CA-CB-CG1 ILE 111.0 1.9
CB-CG1-CD1 ILE 113.9 2.8
CA-CB-CG2 ILE 110.9 2.0
CG1-CB-CG2 ILE 111.4 2.2
N-CA-C ILE 111.0 2.7
CA-C-O ILE 120.1 2.1
N-CA-CB LEU 110.4 2.0
CB-CA-C LEU 110.2 1.9
CA-CB-CG LEU 115.3 2.3
CB-CG-CD1 LEU 111.0 1.7
CB-CG-CD2 LEU 111.0 1.7
CD1-CG-CD2 LEU 110.5 3.0
N-CA-C LEU 111.0 2.7
CA-C-O LEU 120.1 2.1
N-CA-CB LYS 110.6 1.8
CB-CA-C LYS 110.4 2.0
CA-CB-CG LYS 113.4 2.2
CB-CG-CD LYS 111.6 2.6
CG-CD-CE LYS 111.9 3.0
CD-CE-NZ LYS 111.7 2.3
N-CA-C LYS 111.0 2.7
CA-C-O LYS 120.1 2.1
N-CA-CB MET 110.6 1.8
CB-CA-C MET 110.4 2.0
CA-CB-CG MET 113.3 1.7
CB-CG-SD MET 112.4 3.0
CG-SD-CE MET 100.2 1.6
N-CA-C MET 111.0 2.7
CA-C-O MET 120.1 2.1
N-CA-CB PHE 110.6 1.8
CB-CA-C PHE 110.4 2.0
CA-CB-CG PHE 113.9 2.4
CB-CG-CD1 PHE 120.8 0.7
CB-CG-CD2 PHE 120.8 0.7
CD1-CG-CD2 PHE 118.3 1.3
CG-CD1-CE1 PHE 120.8 1.1
CG-CD2-CE2 PHE 120.8 1.1
CD1-CE1-CZ PHE 120.1 1.2
CD2-CE2-CZ PHE 120.1 1.2
CE1-CZ-CE2 PHE 120.0 1.8
N-CA-C PHE 111.0 2.7
CA-C-O PHE 120.1 2.1
N-CA-CB PRO 103.3 1.2
CB-CA-C PRO 111.7 2.1
CA-CB-CG PRO 104.8 1.9
CB-CG-CD PRO 106.5 3.9
CG-CD-N PRO 103.2 1.5
CA-N-CD PRO 111.7 1.4
N-CA-C PRO 112.1 2.6
CA-C-O PRO 120.2 2.4
N-CA-CB SER 110.5 1.5
CB-CA-C SER 110.1 1.9
CA-CB-OG SER 111.2 2.7
N-CA-C SER 111.0 2.7
CA-C-O SER 120.1 2.1
N-CA-CB THR 110.3 1.9
CB-CA-C THR 111.6 2.7
CA-CB-OG1 THR 109.0 2.1
CA-CB-CG2 THR 112.4 1.4
OG1-CB-CG2 THR 110.0 2.3
N-CA-C THR 111.0 2.7
CA-C-O THR 120.1 2.1
N-CA-CB TRP 110.6 1.8
CB-CA-C TRP 110.4 2.0
CA-CB-CG TRP 113.7 1.9
CB-CG-CD1 TRP 127.0 1.3
CB-CG-CD2 TRP 126.6 1.3
CD1-CG-CD2 TRP 106.3 0.8
CG-CD1-NE1 TRP 110.1 1.0
CD1-NE1-CE2 TRP 109.0 0.9
NE1-CE2-CD2 TRP 107.3 1.0
CE2-CD2-CG TRP 107.3 0.8
CG-CD2-CE3 TRP 133.9 0.9
NE1-CE2-CZ2 TRP 130.4 1.1
CE3-CD2-CE2 TRP 118.7 1.2
CD2-CE2-CZ2 TRP 122.3 1.2
CE2-CZ2-CH2 TRP 117.4 1.0
CZ2-CH2-CZ3 TRP 121.6 1.2
CH2-CZ3-CE3 TRP 121.2 1.1
CZ3-CE3-CD2 TRP 118.8 1.3
N-CA-C TRP 111.0 2.7
CA-C-O TRP 120.1 2.1
N-CA-CB TYR 110.6 1.8
CB-CA-C TYR 110.4 2.0
CA-CB-CG TYR 113.4 1.9
CB-CG-CD1 TYR 121.0 0.6
CB-CG-CD2 TYR 121.0 0.6
CD1-CG-CD2 TYR 117.9 1.1
CG-CD1-CE1 TYR 121.3 0.8
CG-CD2-CE2 TYR 121.3 0.8
CD1-CE1-CZ TYR 119.8 0.9
CD2-CE2-CZ TYR 119.8 0.9
CE1-CZ-CE2 TYR 119.8 1.6
CE1-CZ-OH TYR 120.1 2.7
CE2-CZ-OH TYR 120.1 2.7
N-CA-C TYR 111.0 2.7
CA-C-O TYR 120.1 2.1
N-CA-CB VAL 111.5 2.2
CB-CA-C VAL 111.4 1.9
CA-CB-CG1 VAL 110.9 1.5
CA-CB-CG2 VAL 110.9 1.5
CG1-CB-CG2 VAL 110.9 1.6
N-CA-C VAL 111.0 2.7
CA-C-O VAL 120.1 2.1
-
Non-bonded distance Minimum Dist Tolerance
C-C 3.4 1.5
C-N 3.25 1.5
C-S 3.5 1.5
C-O 3.22 1.5
N-N 3.1 1.5
N-S 3.35 1.5
N-O 3.07 1.5
O-S 3.32 1.5
O-O 3.04 1.5
S-S 2.03 1.0
-
This source diff could not be displayed because it is too large. You can view the blob instead.
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data pipeline for model features."""
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data pipeline."""
from typing import Iterable, MutableMapping, List
from alphafold.common import residue_constants
from alphafold.data import msa_pairing
from alphafold.data import pipeline
import numpy as np
REQUIRED_FEATURES = frozenset({
'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments',
'num_templates', 'queue_size', 'residue_index', 'resolution',
'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
'template_all_atom_mask', 'template_all_atom_positions'
})
MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048
def _is_homomer_or_monomer(chains: Iterable[pipeline.FeatureDict]) -> bool:
"""Checks if a list of chains represents a homomer/monomer example."""
# Note that an entity_id of 0 indicates padding.
num_unique_chains = len(np.unique(np.concatenate(
[np.unique(chain['entity_id'][chain['entity_id'] > 0]) for
chain in chains])))
return num_unique_chains == 1
def pair_and_merge(
all_chain_features: MutableMapping[str, pipeline.FeatureDict]
) -> pipeline.FeatureDict:
"""Runs processing on features to augment, pair and merge.
Args:
all_chain_features: A MutableMap of dictionaries of features for each chain.
Returns:
A dictionary of features.
"""
process_unmerged_features(all_chain_features)
np_chains_list = list(all_chain_features.values())
pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)
if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features(
chains=np_chains_list)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
np_chains_list = crop_chains(
np_chains_list,
msa_crop_size=MSA_CROP_SIZE,
pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES)
np_example = msa_pairing.merge_chain_features(
np_chains_list=np_chains_list, pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES)
np_example = process_final(np_example)
return np_example
def crop_chains(
chains_list: List[pipeline.FeatureDict],
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int) -> List[pipeline.FeatureDict]:
"""Crops the MSAs for a set of chains.
Args:
chains_list: A list of chains to be cropped.
msa_crop_size: The total number of sequences to crop from the MSA.
pair_msa_sequences: Whether we are operating in sequence-pairing mode.
max_templates: The maximum templates to use per chain.
Returns:
The chains cropped.
"""
# Apply the cropping.
cropped_chains = []
for chain in chains_list:
cropped_chain = _crop_single_chain(
chain,
msa_crop_size=msa_crop_size,
pair_msa_sequences=pair_msa_sequences,
max_templates=max_templates)
cropped_chains.append(cropped_chain)
return cropped_chains
def _crop_single_chain(chain: pipeline.FeatureDict,
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int) -> pipeline.FeatureDict:
"""Crops msa sequences to `msa_crop_size`."""
msa_size = chain['num_alignments']
if pair_msa_sequences:
msa_size_all_seq = chain['num_alignments_all_seq']
msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)
# We reduce the number of un-paired sequences, by the number of times a
# sequence from this chain's MSA is included in the paired MSA. This keeps
# the MSA size for each chain roughly constant.
msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
num_non_gapped_pairs = np.sum(
np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1))
num_non_gapped_pairs = np.minimum(num_non_gapped_pairs,
msa_crop_size_all_seq)
# Restrict the unpaired crop size so that paired+unpaired sequences do not
# exceed msa_seqs_per_chain for each chain.
max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
else:
msa_crop_size = np.minimum(msa_size, msa_crop_size)
include_templates = 'template_aatype' in chain and max_templates
if include_templates:
num_templates = chain['template_aatype'].shape[0]
templates_crop_size = np.minimum(num_templates, max_templates)
for k in chain:
k_split = k.split('_all_seq')[0]
if k_split in msa_pairing.TEMPLATE_FEATURES:
chain[k] = chain[k][:templates_crop_size, :]
elif k_split in msa_pairing.MSA_FEATURES:
if '_all_seq' in k and pair_msa_sequences:
chain[k] = chain[k][:msa_crop_size_all_seq, :]
else:
chain[k] = chain[k][:msa_crop_size, :]
chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
if include_templates:
chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
if pair_msa_sequences:
chain['num_alignments_all_seq'] = np.asarray(
msa_crop_size_all_seq, dtype=np.int32)
return chain
def process_final(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict:
"""Final processing steps in data pipeline, after merging and pairing."""
np_example = _correct_msa_restypes(np_example)
np_example = _make_seq_mask(np_example)
np_example = _make_msa_mask(np_example)
np_example = _filter_features(np_example)
return np_example
def _correct_msa_restypes(np_example):
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0)
np_example['msa'] = np_example['msa'].astype(np.int32)
return np_example
def _make_seq_mask(np_example):
np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
return np_example
def _make_msa_mask(np_example):
"""Mask features are all ones, but will later be zero-padded."""
np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)
seq_mask = (np_example['entity_id'] > 0).astype(np.float32)
np_example['msa_mask'] *= seq_mask[None]
return np_example
def _filter_features(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict:
"""Filters features of example to only those requested."""
return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}
def process_unmerged_features(
all_chain_features: MutableMapping[str, pipeline.FeatureDict]):
"""Postprocessing stage for per-chain features before merging."""
num_chains = len(all_chain_features)
for chain_features in all_chain_features.values():
# Convert deletion matrices to float.
chain_features['deletion_matrix'] = np.asarray(
chain_features.pop('deletion_matrix_int'), dtype=np.float32)
if 'deletion_matrix_int_all_seq' in chain_features:
chain_features['deletion_matrix_all_seq'] = np.asarray(
chain_features.pop('deletion_matrix_int_all_seq'), dtype=np.float32)
chain_features['deletion_mean'] = np.mean(
chain_features['deletion_matrix'], axis=0)
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['aatype']]
chain_features['all_atom_mask'] = all_atom_mask
chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])
# Add assembly_num_chains.
chain_features['assembly_num_chains'] = np.asarray(num_chains)
# Add entity_mask.
for chain_features in all_chain_features.values():
chain_features['entity_mask'] = (
chain_features['entity_id'] != 0).astype(np.int32)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parses the mmCIF file format."""
import collections
import dataclasses
import functools
import io
from typing import Any, Mapping, Optional, Sequence, Tuple
from absl import logging
from Bio import PDB
from Bio.Data import SCOPData
# Type aliases:
ChainId = str
PdbHeader = Mapping[str, Any]
PdbStructure = PDB.Structure.Structure
SeqRes = str
MmCIFDict = Mapping[str, Sequence[str]]
@dataclasses.dataclass(frozen=True)
class Monomer:
id: str
num: int
# Note - mmCIF format provides no guarantees on the type of author-assigned
# sequence numbers. They need not be integers.
@dataclasses.dataclass(frozen=True)
class AtomSite:
residue_name: str
author_chain_id: str
mmcif_chain_id: str
author_seq_num: str
mmcif_seq_num: int
insertion_code: str
hetatm_atom: str
model_num: int
# Used to map SEQRES index to a residue in the structure.
@dataclasses.dataclass(frozen=True)
class ResiduePosition:
chain_id: str
residue_number: int
insertion_code: str
@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
position: Optional[ResiduePosition]
name: str
is_missing: bool
hetflag: str
@dataclasses.dataclass(frozen=True)
class MmcifObject:
"""Representation of a parsed mmCIF file.
Contains:
file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
files being processed.
header: Biopython header.
structure: Biopython structure.
chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
{'A': 'ABCDEFG'}
seqres_to_structure: Dict; for each chain_id contains a mapping between
SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
1: ResidueAtPosition,
...}}
raw_string: The raw string used to construct the MmcifObject.
"""
file_id: str
header: PdbHeader
structure: PdbStructure
chain_to_seqres: Mapping[ChainId, SeqRes]
seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
raw_string: Any
@dataclasses.dataclass(frozen=True)
class ParsingResult:
"""Returned by the parse function.
Contains:
mmcif_object: A MmcifObject, may be None if no chain could be successfully
parsed.
errors: A dict mapping (file_id, chain_id) to any exception generated.
"""
mmcif_object: Optional[MmcifObject]
errors: Mapping[Tuple[str, str], Any]
class ParseError(Exception):
"""An error indicating that an mmCIF file could not be parsed."""
def mmcif_loop_to_list(prefix: str,
parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a list.
Reference for loop_ in mmCIF:
http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html
Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.
Returns:
Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
"""
cols = []
data = []
for key, value in parsed_info.items():
if key.startswith(prefix):
cols.append(key)
data.append(value)
assert all([len(xs) == len(data[0]) for xs in data]), (
'mmCIF error: Not all loops are the same length: %s' % cols)
return [dict(zip(cols, xs)) for xs in zip(*data)]
def mmcif_loop_to_dict(prefix: str,
index: str,
parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a dictionary.
Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
index: Which item of loop data should serve as the key.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.
Returns:
Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
indexed by the index column.
"""
entries = mmcif_loop_to_list(prefix, parsed_info)
return {entry[index]: entry for entry in entries}
@functools.lru_cache(16, typed=False)
def parse(*,
file_id: str,
mmcif_string: str,
catch_all_errors: bool = True) -> ParsingResult:
"""Entry point, parses an mmcif_string.
Args:
file_id: A string identifier for this file. Should be unique within the
collection of files being processed.
mmcif_string: Contents of an mmCIF file.
catch_all_errors: If True, all exceptions are caught and error messages are
returned as part of the ParsingResult. If False exceptions will be allowed
to propagate.
Returns:
A ParsingResult.
"""
errors = {}
try:
parser = PDB.MMCIFParser(QUIET=True)
handle = io.StringIO(mmcif_string)
full_structure = parser.get_structure('', handle)
first_model_structure = _get_first_model(full_structure)
# Extract the _mmcif_dict from the parser, which contains useful fields not
# reflected in the Biopython structure.
parsed_info = parser._mmcif_dict # pylint:disable=protected-access
# Ensure all values are lists, even if singletons.
for key, value in parsed_info.items():
if not isinstance(value, list):
parsed_info[key] = [value]
header = _get_header(parsed_info)
# Determine the protein chains, and their start numbers according to the
# internal mmCIF numbering scheme (likely but not guaranteed to be 1).
valid_chains = _get_protein_chains(parsed_info=parsed_info)
if not valid_chains:
return ParsingResult(
None, {(file_id, ''): 'No protein chains found in this file.'})
seq_start_num = {chain_id: min([monomer.num for monomer in seq])
for chain_id, seq in valid_chains.items()}
# Loop over the atoms for which we have coordinates. Populate two mappings:
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
# the authors / Biopython).
# -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
mmcif_to_author_chain_id = {}
seq_to_structure_mappings = {}
for atom in _get_atom_site_list(parsed_info):
if atom.model_num != '1':
# We only process the first model at the moment.
continue
mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id
if atom.mmcif_chain_id in valid_chains:
hetflag = ' '
if atom.hetatm_atom == 'HETATM':
# Water atoms are assigned a special hetflag of W in Biopython. We
# need to do the same, so that this hetflag can be used to fetch
# a residue from the Biopython structure by id.
if atom.residue_name in ('HOH', 'WAT'):
hetflag = 'W'
else:
hetflag = 'H_' + atom.residue_name
insertion_code = atom.insertion_code
if not _is_set(atom.insertion_code):
insertion_code = ' '
position = ResiduePosition(chain_id=atom.author_chain_id,
residue_number=int(atom.author_seq_num),
insertion_code=insertion_code)
seq_idx = int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
current = seq_to_structure_mappings.get(atom.author_chain_id, {})
current[seq_idx] = ResidueAtPosition(position=position,
name=atom.residue_name,
is_missing=False,
hetflag=hetflag)
seq_to_structure_mappings[atom.author_chain_id] = current
# Add missing residue information to seq_to_structure_mappings.
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
current_mapping = seq_to_structure_mappings[author_chain]
for idx, monomer in enumerate(seq_info):
if idx not in current_mapping:
current_mapping[idx] = ResidueAtPosition(position=None,
name=monomer.id,
is_missing=True,
hetflag=' ')
author_chain_to_sequence = {}
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
seq = []
for monomer in seq_info:
code = SCOPData.protein_letters_3to1.get(monomer.id, 'X')
seq.append(code if len(code) == 1 else 'X')
seq = ''.join(seq)
author_chain_to_sequence[author_chain] = seq
mmcif_object = MmcifObject(
file_id=file_id,
header=header,
structure=first_model_structure,
chain_to_seqres=author_chain_to_sequence,
seqres_to_structure=seq_to_structure_mappings,
raw_string=parsed_info)
return ParsingResult(mmcif_object=mmcif_object, errors=errors)
except Exception as e: # pylint:disable=broad-except
errors[(file_id, '')] = e
if not catch_all_errors:
raise
return ParsingResult(mmcif_object=None, errors=errors)
def _get_first_model(structure: PdbStructure) -> PdbStructure:
"""Returns the first model in a Biopython structure."""
return next(structure.get_models())
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
def get_release_date(parsed_info: MmCIFDict) -> str:
"""Returns the oldest revision date."""
revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date']
return min(revision_dates)
def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
"""Returns a basic header containing method, release date and resolution."""
header = {}
experiments = mmcif_loop_to_list('_exptl.', parsed_info)
header['structure_method'] = ','.join([
experiment['_exptl.method'].lower() for experiment in experiments])
# Note: The release_date here corresponds to the oldest revision. We prefer to
# use this for dataset filtering over the deposition_date.
if '_pdbx_audit_revision_history.revision_date' in parsed_info:
header['release_date'] = get_release_date(parsed_info)
else:
logging.warning('Could not determine release_date: %s',
parsed_info['_entry.id'])
header['resolution'] = 0.00
for res_key in ('_refine.ls_d_res_high', '_em_3d_reconstruction.resolution',
'_reflns.d_resolution_high'):
if res_key in parsed_info:
try:
raw_resolution = parsed_info[res_key][0]
header['resolution'] = float(raw_resolution)
except ValueError:
logging.debug('Invalid resolution format: %s', parsed_info[res_key])
return header
def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
"""Returns list of atom sites; contains data not present in the structure."""
return [AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension
parsed_info['_atom_site.label_comp_id'],
parsed_info['_atom_site.auth_asym_id'],
parsed_info['_atom_site.label_asym_id'],
parsed_info['_atom_site.auth_seq_id'],
parsed_info['_atom_site.label_seq_id'],
parsed_info['_atom_site.pdbx_PDB_ins_code'],
parsed_info['_atom_site.group_PDB'],
parsed_info['_atom_site.pdbx_PDB_model_num'],
)]
def _get_protein_chains(
*, parsed_info: Mapping[str, Any]) -> Mapping[ChainId, Sequence[Monomer]]:
"""Extracts polymer information for protein chains only.
Args:
parsed_info: _mmcif_dict produced by the Biopython parser.
Returns:
A dict mapping mmcif chain id to a list of Monomers.
"""
# Get polymer information for each entity in the structure.
entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info)
polymers = collections.defaultdict(list)
for entity_poly_seq in entity_poly_seqs:
polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append(
Monomer(id=entity_poly_seq['_entity_poly_seq.mon_id'],
num=int(entity_poly_seq['_entity_poly_seq.num'])))
# Get chemical compositions. Will allow us to identify which of these polymers
# are proteins.
chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', parsed_info)
# Get chains information for each entity. Necessary so that we can return a
# dict keyed on chain id rather than entity.
struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info)
entity_to_mmcif_chains = collections.defaultdict(list)
for struct_asym in struct_asyms:
chain_id = struct_asym['_struct_asym.id']
entity_id = struct_asym['_struct_asym.entity_id']
entity_to_mmcif_chains[entity_id].append(chain_id)
# Identify and return the valid protein chains.
valid_chains = {}
for entity_id, seq_info in polymers.items():
chain_ids = entity_to_mmcif_chains[entity_id]
# Reject polymers without any peptide-like components, such as DNA/RNA.
if any(['peptide' in chem_comps[monomer.id]['_chem_comp.type']
for monomer in seq_info]):
for chain_id in chain_ids:
valid_chains[chain_id] = seq_info
return valid_chains
def _is_set(data: str) -> bool:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return data not in ('.', '?')
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for extracting identifiers from MSA sequence descriptions."""
import dataclasses
import re
from typing import Optional
# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
_UNIPROT_PATTERN = re.compile(
r"""
^
# UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
(?:tr|sp)
\|
# A primary accession number of the UniProtKB entry.
(?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
# Occasionally there is a _0 or _1 isoform suffix, which we ignore.
(?:_\d)?
\|
# TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
# protein ID code.
(?:[A-Za-z0-9]+)
_
# A mnemonic species identification code.
(?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
# Small BFD uses a final value after an underscore, which we ignore.
(?:_\d+)?
$
""",
re.VERBOSE)
@dataclasses.dataclass(frozen=True)
class Identifiers:
species_id: str = ''
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
"""Gets species from an msa sequence identifier.
The sequence identifier has the format specified by
_UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`
Args:
msa_sequence_identifier: a sequence identifier.
Returns:
An `Identifiers` instance with species_id. These
can be empty in the case where no identifier was found.
"""
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches:
return Identifiers(
species_id=matches.group('SpeciesIdentifier'))
return Identifiers()
def _extract_sequence_identifier(description: str) -> Optional[str]:
"""Extracts sequence identifier from description. Returns None if no match."""
split_description = description.split()
if split_description:
return split_description[0].partition('/')[0]
else:
return None
def get_identifiers(description: str) -> Identifiers:
"""Computes extra MSA features from the description."""
sequence_identifier = _extract_sequence_identifier(description)
if sequence_identifier is None:
return Identifiers()
else:
return _parse_sequence_identifier(sequence_identifier)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pairing logic for multimer data pipeline."""
import collections
import functools
import string
from typing import Any, Dict, Iterable, List, Sequence
from alphafold.common import residue_constants
from alphafold.data import pipeline
import numpy as np
import pandas as pd
import scipy.linalg
MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9
MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
'msa_mask_all_seq': 1,
'deletion_matrix_all_seq': 0,
'deletion_matrix_int_all_seq': 0,
'msa': MSA_GAP_IDX,
'msa_mask': 1,
'deletion_matrix': 0,
'deletion_matrix_int': 0}
MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
'all_atom_mask', 'seq_mask', 'between_segment_residues',
'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
'sym_id', 'entity_mask', 'deletion_mean',
'prediction_atom_mask',
'literature_positions', 'atom_indices_to_group_indices',
'rigid_group_default_frame')
TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
'template_all_atom_mask')
CHAIN_FEATURES = ('num_alignments', 'seq_length')
def create_paired_features(
chains: Iterable[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
"""Returns the original chains with paired NUM_SEQ features.
Args:
chains: A list of feature dictionaries for each chain.
Returns:
A list of feature dictionaries with sequence features including only
rows to be paired.
"""
chains = list(chains)
chain_keys = chains[0].keys()
if len(chains) < 2:
return chains
else:
updated_chains = []
paired_chains_to_paired_row_indices = pair_sequences(chains)
paired_rows = reorder_paired_rows(
paired_chains_to_paired_row_indices)
for chain_num, chain in enumerate(chains):
new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
for feature_name in chain_keys:
if feature_name.endswith('_all_seq'):
feats_padded = pad_features(chain[feature_name], feature_name)
new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
new_chain['num_alignments_all_seq'] = np.asarray(
len(paired_rows[:, chain_num]))
updated_chains.append(new_chain)
return updated_chains
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
"""Add a 'padding' row at the end of the features list.
The padding row will be selected as a 'paired' row in the case of partial
alignment - for the chain that doesn't have paired alignment.
Args:
feature: The feature to be padded.
feature_name: The name of the feature to be padded.
Returns:
The feature with an additional padding row.
"""
assert feature.dtype != np.dtype(np.string_)
if feature_name in ('msa_all_seq', 'msa_mask_all_seq',
'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'):
num_res = feature.shape[1]
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
feature.dtype)
elif feature_name == 'msa_species_identifiers_all_seq':
padding = [b'']
else:
return feature
feats_padded = np.concatenate([feature, padding], axis=0)
return feats_padded
def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame:
"""Makes dataframe with msa features needed for msa pairing."""
chain_msa = chain_features['msa_all_seq']
query_seq = chain_msa[0]
per_seq_similarity = np.sum(
query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
msa_df = pd.DataFrame({
'msa_species_identifiers':
chain_features['msa_species_identifiers_all_seq'],
'msa_row':
np.arange(len(
chain_features['msa_species_identifiers_all_seq'])),
'msa_similarity': per_seq_similarity,
'gap': per_seq_gap
})
return msa_df
def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
"""Creates mapping from species to msa dataframe of that species."""
species_lookup = {}
for species, species_df in msa_df.groupby('msa_species_identifiers'):
species_lookup[species] = species_df
return species_lookup
def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
) -> List[List[int]]:
"""Finds MSA sequence pairings across chains based on sequence similarity.
Each chain's MSA sequences are first sorted by their sequence similarity to
their respective target sequence. The sequences are then paired, starting
from the sequences most similar to their target sequence.
Args:
this_species_msa_dfs: a list of dataframes containing MSA features for
sequences for a specific species.
Returns:
A list of lists, each containing M indices corresponding to paired MSA rows,
where M is the number of chains.
"""
all_paired_msa_rows = []
num_seqs = [len(species_df) for species_df in this_species_msa_dfs
if species_df is not None]
take_num_seqs = np.min(num_seqs)
sort_by_similarity = (
lambda x: x.sort_values('msa_similarity', axis=0, ascending=False))
for species_df in this_species_msa_dfs:
if species_df is not None:
species_df_sorted = sort_by_similarity(species_df)
msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
else:
msa_rows = [-1] * take_num_seqs # take the last 'padding' row
all_paired_msa_rows.append(msa_rows)
all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
return all_paired_msa_rows
def pair_sequences(examples: List[pipeline.FeatureDict]
) -> Dict[int, np.ndarray]:
"""Returns indices for paired MSA sequences across chains."""
num_examples = len(examples)
all_chain_species_dict = []
common_species = set()
for chain_features in examples:
msa_df = _make_msa_df(chain_features)
species_dict = _create_species_dict(msa_df)
all_chain_species_dict.append(species_dict)
common_species.update(set(species_dict))
common_species = sorted(common_species)
common_species.remove(b'') # Remove target sequence species.
all_paired_msa_rows = [np.zeros(len(examples), int)]
all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]
for species in common_species:
if not species:
continue
this_species_msa_dfs = []
species_dfs_present = 0
for species_dict in all_chain_species_dict:
if species in species_dict:
this_species_msa_dfs.append(species_dict[species])
species_dfs_present += 1
else:
this_species_msa_dfs.append(None)
# Skip species that are present in only one chain.
if species_dfs_present <= 1:
continue
if np.any(
np.array([len(species_df) for species_df in
this_species_msa_dfs if
isinstance(species_df, pd.DataFrame)]) > 600):
continue
paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
all_paired_msa_rows.extend(paired_msa_rows)
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
all_paired_msa_rows_dict = {
num_examples: np.array(paired_msa_rows) for
num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
}
return all_paired_msa_rows_dict
def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]
) -> np.ndarray:
"""Creates a list of indices of paired MSA rows across chains.
Args:
all_paired_msa_rows_dict: a mapping from the number of paired chains to the
paired indices.
Returns:
a list of lists, each containing indices of paired MSA rows across chains.
The paired-index lists are ordered by:
1) the number of chains in the paired alignment, i.e, all-chain pairings
will come first.
2) e-values
"""
all_paired_msa_rows = []
for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
paired_rows = all_paired_msa_rows_dict[num_pairings]
paired_rows_product = abs(np.array([np.prod(rows) for rows in paired_rows]))
paired_rows_sort_index = np.argsort(paired_rows_product)
all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])
return np.array(all_paired_msa_rows)
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
"""Like scipy.linalg.block_diag but with an optional padding value."""
ones_arrs = [np.ones_like(x) for x in arrs]
off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs)
diag = scipy.linalg.block_diag(*arrs)
diag += (off_diag_mask * pad_value).astype(diag.dtype)
return diag
def _correct_post_merged_feats(
np_example: pipeline.FeatureDict,
np_chains_list: Sequence[pipeline.FeatureDict],
pair_msa_sequences: bool) -> pipeline.FeatureDict:
"""Adds features that need to be computed/recomputed post merging."""
np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0],
dtype=np.int32)
np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0],
dtype=np.int32)
if not pair_msa_sequences:
# Generate a bias that is 1 for the first row of every block in the
# block diagonal MSA - i.e. make sure the cluster stack always includes
# the query sequences for each chain (since the first row is the query
# sequence).
cluster_bias_masks = []
for chain in np_chains_list:
mask = np.zeros(chain['msa'].shape[0])
mask[0] = 1
cluster_bias_masks.append(mask)
np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)
# Initialize Bert mask with masked out off diagonals.
msa_masks = [np.ones(x['msa'].shape, dtype=np.float32)
for x in np_chains_list]
np_example['bert_mask'] = block_diag(
*msa_masks, pad_value=0)
else:
np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
np_example['cluster_bias_mask'][0] = 1
# Initialize Bert mask with masked out off diagonals.
msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for
x in np_chains_list]
msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
x in np_chains_list]
msa_mask_block_diag = block_diag(
*msa_masks, pad_value=0)
msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
np_example['bert_mask'] = np.concatenate(
[msa_mask_all_seq, msa_mask_block_diag], axis=0)
return np_example
def _pad_templates(chains: Sequence[pipeline.FeatureDict],
max_templates: int) -> Sequence[pipeline.FeatureDict]:
"""For each chain pad the number of templates to a fixed size.
Args:
chains: A list of protein chains.
max_templates: Each chain will be padded to have this many templates.
Returns:
The list of chains, updated to have template features padded to
max_templates.
"""
for chain in chains:
for k, v in chain.items():
if k in TEMPLATE_FEATURES:
padding = np.zeros_like(v.shape)
padding[0] = max_templates - v.shape[0]
padding = [(0, p) for p in padding]
chain[k] = np.pad(v, padding, mode='constant')
return chains
def _merge_features_from_multiple_chains(
chains: Sequence[pipeline.FeatureDict],
pair_msa_sequences: bool) -> pipeline.FeatureDict:
"""Merge features from multiple chains.
Args:
chains: A list of feature dictionaries that we want to merge.
pair_msa_sequences: Whether to concatenate MSA features along the
num_res dimension (if True), or to block diagonalize them (if False).
Returns:
A feature dictionary for the merged example.
"""
merged_example = {}
for feature_name in chains[0]:
feats = [x[feature_name] for x in chains]
feature_name_split = feature_name.split('_all_seq')[0]
if feature_name_split in MSA_FEATURES:
if pair_msa_sequences or '_all_seq' in feature_name:
merged_example[feature_name] = np.concatenate(feats, axis=1)
else:
merged_example[feature_name] = block_diag(
*feats, pad_value=MSA_PAD_VALUES[feature_name])
elif feature_name_split in SEQ_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=0)
elif feature_name_split in TEMPLATE_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=1)
elif feature_name_split in CHAIN_FEATURES:
merged_example[feature_name] = np.sum(x for x in feats).astype(np.int32)
else:
merged_example[feature_name] = feats[0]
return merged_example
def _merge_homomers_dense_msa(
chains: Iterable[pipeline.FeatureDict]) -> Sequence[pipeline.FeatureDict]:
"""Merge all identical chains, making the resulting MSA dense.
Args:
chains: An iterable of features for each chain.
Returns:
A list of feature dictionaries. All features with the same entity_id
will be merged - MSA features will be concatenated along the num_res
dimension - making them dense.
"""
entity_chains = collections.defaultdict(list)
for chain in chains:
entity_id = chain['entity_id'][0]
entity_chains[entity_id].append(chain)
grouped_chains = []
for entity_id in sorted(entity_chains):
chains = entity_chains[entity_id]
grouped_chains.append(chains)
chains = [
_merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
for chains in grouped_chains]
return chains
def _concatenate_paired_and_unpaired_features(
example: pipeline.FeatureDict) -> pipeline.FeatureDict:
"""Merges paired and block-diagonalised features."""
features = MSA_FEATURES
for feature_name in features:
if feature_name in example:
feat = example[feature_name]
feat_all_seq = example[feature_name + '_all_seq']
merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
example[feature_name] = merged_feat
example['num_alignments'] = np.array(example['msa'].shape[0],
dtype=np.int32)
return example
def merge_chain_features(np_chains_list: List[pipeline.FeatureDict],
pair_msa_sequences: bool,
max_templates: int) -> pipeline.FeatureDict:
"""Merges features for multiple chains to single FeatureDict.
Args:
np_chains_list: List of FeatureDicts for each chain.
pair_msa_sequences: Whether to merge paired MSAs.
max_templates: The maximum number of templates to include.
Returns:
Single FeatureDict for entire complex.
"""
np_chains_list = _pad_templates(
np_chains_list, max_templates=max_templates)
np_chains_list = _merge_homomers_dense_msa(np_chains_list)
# Unpaired MSA features will be always block-diagonalised; paired MSA
# features will be concatenated.
np_example = _merge_features_from_multiple_chains(
np_chains_list, pair_msa_sequences=False)
if pair_msa_sequences:
np_example = _concatenate_paired_and_unpaired_features(np_example)
np_example = _correct_post_merged_feats(
np_example=np_example,
np_chains_list=np_chains_list,
pair_msa_sequences=pair_msa_sequences)
return np_example
def deduplicate_unpaired_sequences(
np_chains: List[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
"""Removes unpaired sequences which duplicate a paired sequence."""
feature_names = np_chains[0].keys()
msa_features = MSA_FEATURES
for chain in np_chains:
# Convert the msa_all_seq numpy array to a tuple for hashing.
sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
keep_rows = []
# Go through unpaired MSA seqs and remove any rows that correspond to the
# sequences that are already present in the paired MSA.
for row_num, seq in enumerate(chain['msa']):
if tuple(seq) not in sequence_set:
keep_rows.append(row_num)
for feature_name in feature_names:
if feature_name in msa_features:
chain[feature_name] = chain[feature_name][keep_rows]
chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
return np_chains
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for parsing various file formats."""
import collections
import dataclasses
import itertools
import re
import string
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
# Internal import (7716).
DeletionMatrix = Sequence[Sequence[int]]
@dataclasses.dataclass(frozen=True)
class Msa:
"""Class representing a parsed MSA file."""
sequences: Sequence[str]
deletion_matrix: DeletionMatrix
descriptions: Sequence[str]
def __post_init__(self):
if not (len(self.sequences) ==
len(self.deletion_matrix) ==
len(self.descriptions)):
raise ValueError(
'All fields for an MSA must have the same length. '
f'Got {len(self.sequences)} sequences, '
f'{len(self.deletion_matrix)} rows in the deletion matrix and '
f'{len(self.descriptions)} descriptions.')
def __len__(self):
return len(self.sequences)
def truncate(self, max_seqs: int):
return Msa(sequences=self.sequences[:max_seqs],
deletion_matrix=self.deletion_matrix[:max_seqs],
descriptions=self.descriptions[:max_seqs])
@dataclasses.dataclass(frozen=True)
class TemplateHit:
"""Class representing a template hit."""
index: int
name: str
aligned_cols: int
sum_probs: Optional[float]
query: str
hit_sequence: str
indices_query: List[int]
indices_hit: List[int]
def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
"""Parses FASTA string and returns list of strings with amino-acid sequences.
Arguments:
fasta_string: The string contents of a FASTA file.
Returns:
A tuple of two lists:
* A list of sequences.
* A list of sequence descriptions taken from the comment lines. In the
same order as the sequences.
"""
sequences = []
descriptions = []
index = -1
for line in fasta_string.splitlines():
line = line.strip()
if line.startswith('>'):
index += 1
descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append('')
continue
elif not line:
continue # Skip blank lines.
sequences[index] += line
return sequences, descriptions
def parse_stockholm(stockholm_string: str) -> Msa:
"""Parses sequences and deletion matrix from stockholm format alignment.
Args:
stockholm_string: The string contents of a stockholm file. The first
sequence in the file should be the query sequence.
Returns:
A tuple of:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
* The names of the targets matched, including the jackhmmer subsequence
suffix.
"""
name_to_sequence = collections.OrderedDict()
for line in stockholm_string.splitlines():
line = line.strip()
if not line or line.startswith(('#', '//')):
continue
name, sequence = line.split()
if name not in name_to_sequence:
name_to_sequence[name] = ''
name_to_sequence[name] += sequence
msa = []
deletion_matrix = []
query = ''
keep_columns = []
for seq_index, sequence in enumerate(name_to_sequence.values()):
if seq_index == 0:
# Gather the columns with gaps from the query
query = sequence
keep_columns = [i for i, res in enumerate(query) if res != '-']
# Remove the columns with gaps in the query from all sequences.
aligned_sequence = ''.join([sequence[c] for c in keep_columns])
msa.append(aligned_sequence)
# Count the number of deletions w.r.t. query.
deletion_vec = []
deletion_count = 0
for seq_res, query_res in zip(sequence, query):
if seq_res != '-' or query_res != '-':
if query_res == '-':
deletion_count += 1
else:
deletion_vec.append(deletion_count)
deletion_count = 0
deletion_matrix.append(deletion_vec)
return Msa(sequences=msa,
deletion_matrix=deletion_matrix,
descriptions=list(name_to_sequence.keys()))
def parse_a3m(a3m_string: str) -> Msa:
"""Parses sequences and deletion matrix from a3m format alignment.
Args:
a3m_string: The string contents of a a3m file. The first sequence in the
file should be the query sequence.
Returns:
A tuple of:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
* A list of descriptions, one per sequence, from the a3m file.
"""
sequences, descriptions = parse_fasta(a3m_string)
deletion_matrix = []
for msa_sequence in sequences:
deletion_vec = []
deletion_count = 0
for j in msa_sequence:
if j.islower():
deletion_count += 1
else:
deletion_vec.append(deletion_count)
deletion_count = 0
deletion_matrix.append(deletion_vec)
# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans('', '', string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences]
return Msa(sequences=aligned_sequences,
deletion_matrix=deletion_matrix,
descriptions=descriptions)
def _convert_sto_seq_to_a3m(
query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]:
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
if is_query_res_non_gap:
yield sequence_res
elif sequence_res != '-':
yield sequence_res.lower()
def convert_stockholm_to_a3m(stockholm_format: str,
max_sequences: Optional[int] = None,
remove_first_row_gaps: bool = True) -> str:
"""Converts MSA in Stockholm format to the A3M format."""
descriptions = {}
sequences = {}
reached_max_sequences = False
for line in stockholm_format.splitlines():
reached_max_sequences = max_sequences and len(sequences) >= max_sequences
if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname, aligned_seq = line.split(maxsplit=1)
if seqname not in sequences:
if reached_max_sequences:
continue
sequences[seqname] = ''
sequences[seqname] += aligned_seq
for line in stockholm_format.splitlines():
if line[:4] == '#=GS':
# Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns = line.split(maxsplit=3)
seqname, feature = columns[1:3]
value = columns[3] if len(columns) == 4 else ''
if feature != 'DE':
continue
if reached_max_sequences and seqname not in sequences:
continue
descriptions[seqname] = value
if len(descriptions) == len(sequences):
break
# Convert sto format to a3m line by line
a3m_sequences = {}
if remove_first_row_gaps:
# query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values()))
query_non_gaps = [res != '-' for res in query_sequence]
for seqname, sto_sequence in sequences.items():
# Dots are optional in a3m format and are commonly removed.
out_sequence = sto_sequence.replace('.', '')
if remove_first_row_gaps:
out_sequence = ''.join(
_convert_sto_seq_to_a3m(query_non_gaps, out_sequence))
a3m_sequences[seqname] = out_sequence
fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
for k in a3m_sequences)
return '\n'.join(fasta_chunks) + '\n' # Include terminating newline.
def _keep_line(line: str, seqnames: Set[str]) -> bool:
"""Function to decide which lines to keep."""
if not line.strip():
return True
if line.strip() == '//': # End tag
return True
if line.startswith('# STOCKHOLM'): # Start tag
return True
if line.startswith('#=GC RF'): # Reference Annotation Line
return True
if line[:4] == '#=GS': # Description lines - keep if sequence in list.
_, seqname, _ = line.split(maxsplit=2)
return seqname in seqnames
elif line.startswith('#'): # Other markup - filter out
return False
else: # Alignment data - keep if sequence in list.
seqname = line.partition(' ')[0]
return seqname in seqnames
def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
"""Reads + truncates a Stockholm file while preventing excessive RAM usage."""
seqnames = set()
filtered_lines = []
with open(stockholm_msa_path) as f:
for line in f:
if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname = line.partition(' ')[0]
seqnames.add(seqname)
if len(seqnames) >= max_sequences:
break
f.seek(0)
for line in f:
if _keep_line(line, seqnames):
filtered_lines.append(line)
return ''.join(filtered_lines)
def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
"""Removes empty columns (dashes-only) from a Stockholm MSA."""
processed_lines = {}
unprocessed_lines = {}
for i, line in enumerate(stockholm_msa.splitlines()):
if line.startswith('#=GC RF'):
reference_annotation_i = i
reference_annotation_line = line
# Reached the end of this chunk of the alignment. Process chunk.
_, _, first_alignment = line.rpartition(' ')
mask = []
for j in range(len(first_alignment)):
for _, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
if alignment[j] != '-':
mask.append(True)
break
else: # Every row contained a hyphen - empty column.
mask.append(False)
# Add reference annotation for processing with mask.
unprocessed_lines[reference_annotation_i] = reference_annotation_line
if not any(mask): # All columns were empty. Output empty lines for chunk.
for line_index in unprocessed_lines:
processed_lines[line_index] = ''
else:
for line_index, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
masked_alignment = ''.join(itertools.compress(alignment, mask))
processed_lines[line_index] = f'{prefix} {masked_alignment}'
# Clear raw_alignments.
unprocessed_lines = {}
elif line.strip() and not line.startswith(('#', '//')):
unprocessed_lines[i] = line
else:
processed_lines[i] = line
return '\n'.join((processed_lines[i] for i in range(len(processed_lines))))
def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
"""Remove duplicate sequences (ignoring insertions wrt query)."""
sequence_dict = collections.defaultdict(str)
# First we must extract all sequences from the MSA.
for line in stockholm_msa.splitlines():
# Only consider the alignments - ignore reference annotation, empty lines,
# descriptions or markup.
if line.strip() and not line.startswith(('#', '//')):
line = line.strip()
seqname, alignment = line.split()
sequence_dict[seqname] += alignment
seen_sequences = set()
seqnames = set()
# First alignment is the query.
query_align = next(iter(sequence_dict.values()))
mask = [c != '-' for c in query_align] # Mask is False for insertions.
for seqname, alignment in sequence_dict.items():
# Apply mask to remove all insertions from the string.
masked_alignment = ''.join(itertools.compress(alignment, mask))
if masked_alignment in seen_sequences:
continue
else:
seen_sequences.add(masked_alignment)
seqnames.add(seqname)
filtered_lines = []
for line in stockholm_msa.splitlines():
if _keep_line(line, seqnames):
filtered_lines.append(line)
return '\n'.join(filtered_lines) + '\n'
def _get_hhr_line_regex_groups(
regex_pattern: str, line: str) -> Sequence[Optional[str]]:
match = re.match(regex_pattern, line)
if match is None:
raise RuntimeError(f'Could not parse query line {line}')
return match.groups()
def _update_hhr_residue_indices_list(
sequence: str, start_index: int, indices_list: List[int]):
"""Computes the relative indices for each residue with respect to the original sequence."""
counter = start_index
for symbol in sequence:
if symbol == '-':
indices_list.append(-1)
else:
indices_list.append(counter)
counter += 1
def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
"""Parses the detailed HMM HMM comparison section for a single Hit.
This works on .hhr files generated from both HHBlits and HHSearch.
Args:
detailed_lines: A list of lines from a single comparison section between 2
sequences (which each have their own HMM's)
Returns:
A dictionary with the information from that detailed comparison section
Raises:
RuntimeError: If a certain line cannot be processed
"""
# Parse first 2 lines.
number_of_hit = int(detailed_lines[0].split()[-1])
name_hit = detailed_lines[1][1:]
# Parse the summary line.
pattern = (
'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t'
' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t '
']*Template_Neff=(.*)')
match = re.match(pattern, detailed_lines[2])
if match is None:
raise RuntimeError(
'Could not parse section: %s. Expected this: \n%s to contain summary.' %
(detailed_lines, detailed_lines[2]))
(_, _, _, aligned_cols, _, _, sum_probs, _) = [float(x)
for x in match.groups()]
# The next section reads the detailed comparisons. These are in a 'human
# readable' format which has a fixed length. The strategy employed is to
# assume that each block starts with the query sequence line, and to parse
# that with a regexp in order to deduce the fixed length used for that block.
query = ''
hit_sequence = ''
indices_query = []
indices_hit = []
length_block = None
for line in detailed_lines[3:]:
# Parse the query sequence line
if (line.startswith('Q ') and not line.startswith('Q ss_dssp') and
not line.startswith('Q ss_pred') and
not line.startswith('Q Consensus')):
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse
# everything after that.
# start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)'
groups = _get_hhr_line_regex_groups(patt, line[17:])
# Get the length of the parsed block using the start and finish indices,
# and ensure it is the same as the actual block length.
start = int(groups[0]) - 1 # Make index zero based.
delta_query = groups[1]
end = int(groups[2])
num_insertions = len([x for x in delta_query if x == '-'])
length_block = end - start + num_insertions
assert length_block == len(delta_query)
# Update the query sequence and indices list.
query += delta_query
_update_hhr_residue_indices_list(delta_query, start, indices_query)
elif line.startswith('T '):
# Parse the hit sequence.
if (not line.startswith('T ss_dssp') and
not line.startswith('T ss_pred') and
not line.startswith('T Consensus')):
# Thus the first 17 characters must be 'T <hit_name> ', and we can
# parse everything after that.
# start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)'
groups = _get_hhr_line_regex_groups(patt, line[17:])
start = int(groups[0]) - 1 # Make index zero based.
delta_hit_sequence = groups[1]
assert length_block == len(delta_hit_sequence)
# Update the hit sequence and indices list.
hit_sequence += delta_hit_sequence
_update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit)
return TemplateHit(
index=number_of_hit,
name=name_hit,
aligned_cols=int(aligned_cols),
sum_probs=sum_probs,
query=query,
hit_sequence=hit_sequence,
indices_query=indices_query,
indices_hit=indices_hit,
)
def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
"""Parses the content of an entire HHR file."""
lines = hhr_string.splitlines()
# Each .hhr file starts with a results table, then has a sequence of hit
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# iterate through each paragraph to parse each hit.
block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')]
hits = []
if block_starts:
block_starts.append(len(lines)) # Add the end of the final block.
for i in range(len(block_starts) - 1):
hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]]))
return hits
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
"""Parse target to e-value mapping parsed from Jackhmmer tblout string."""
e_values = {'query': 0}
lines = [line for line in tblout.splitlines() if line[0] != '#']
# As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
# space-delimited. Relevant fields are (1) target name: and
# (5) E-value (full sequence) (numbering from 1).
for line in lines:
fields = line.split()
e_value = fields[4]
target_name = fields[0]
e_values[target_name] = float(e_value)
return e_values
def _get_indices(sequence: str, start: int) -> List[int]:
"""Returns indices for non-gap/insert residues starting at the given index."""
indices = []
counter = start
for symbol in sequence:
# Skip gaps but add a placeholder so that the alignment is preserved.
if symbol == '-':
indices.append(-1)
# Skip deleted residues, but increase the counter.
elif symbol.islower():
counter += 1
# Normal aligned residue. Increase the counter and append to indices.
else:
indices.append(counter)
counter += 1
return indices
@dataclasses.dataclass(frozen=True)
class HitMetadata:
pdb_id: str
chain: str
start: int
end: int
length: int
text: str
def _parse_hmmsearch_description(description: str) -> HitMetadata:
"""Parses the hmmsearch A3M sequence description line."""
# Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text
# Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
match = re.match(
r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
description.strip())
if not match:
raise ValueError(f'Could not parse description: "{description}".')
return HitMetadata(
pdb_id=match[1],
chain=match[2],
start=int(match[3]),
end=int(match[4]),
length=int(match[5]),
text=match[6])
def parse_hmmsearch_a3m(query_sequence: str,
a3m_string: str,
skip_first: bool = True) -> Sequence[TemplateHit]:
"""Parses an a3m string produced by hmmsearch.
Args:
query_sequence: The query sequence.
a3m_string: The a3m string produced by hmmsearch.
skip_first: Whether to skip the first sequence in the a3m string.
Returns:
A sequence of `TemplateHit` results.
"""
# Zip the descriptions and MSAs together, skip the first query sequence.
parsed_a3m = list(zip(*parse_fasta(a3m_string)))
if skip_first:
parsed_a3m = parsed_a3m[1:]
indices_query = _get_indices(query_sequence, start=0)
hits = []
for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
if 'mol:protein' not in hit_description:
continue # Skip non-protein chains.
metadata = _parse_hmmsearch_description(hit_description)
# Aligned columns are only the match states.
aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence])
indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)
hit = TemplateHit(
index=i,
name=f'{metadata.pdb_id}_{metadata.chain}',
aligned_cols=aligned_cols,
sum_probs=None,
query=query_sequence,
hit_sequence=hit_sequence.upper(),
indices_query=indices_query,
indices_hit=indices_hit,
)
hits.append(hit)
return hits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment