Commit b14e47f4 authored by zhuwenwen

Merge branch 'main' of https://github.com/hpcaitech/FastFold

parents 490cb6f5 05681304
name: Build
on:
pull_request:
types: [synchronize, labeled]
jobs:
build:
name: Build and Test FastFold
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/FastFold' &&
contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/fastfold:/data/scratch/fastfold
timeout-minutes: 40
steps:
- uses: actions/checkout@v2
with:
repository: hpcaitech/FastFold
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Install FastFold
run: |
[ ! -z "$(ls -A /github/home/fastfold_cache/)" ] && cp -r /github/home/fastfold_cache/* /__w/FastFold/FastFold/
pip install -r requirements/requirements.txt
pip install -e .
pip install -r requirements/test_requirements.txt
cp -r /__w/FastFold/FastFold/build /github/home/fastfold_cache/
cp /__w/FastFold/FastFold/*.so /github/home/fastfold_cache/
- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest tests
env:
NCCL_SHM_DISABLE: 1
name: Release bdist wheel
on:
workflow_dispatch:
inputs:
torch_version:
type: string
description: torch version, separated by comma
required: true
default: "all"
cuda_version:
type: string
description: cuda version, separated by comma
required: true
github_ref:
type: string
description: Branch or Tag
default: 'main'
required: true
jobs:
matrix_preparation:
name: Prepare Container List
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- id: set-matrix
env:
TORCH_VERSIONS: ${{ inputs.torch_version }}
CUDA_VERSIONS: ${{ inputs.cuda_version }}
run: |
echo $TORCH_VERSIONS
echo $CUDA_VERSIONS
IFS=','
DOCKER_IMAGE=()
for cv in $CUDA_VERSIONS
do
DOCKER_IMAGE+=("\"hpcaitech/cuda-conda:${cv}\"")
done
container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" )
container="[${container}]"
echo "$container"
echo "::set-output name=matrix::{\"container\":$(echo "$container")}"
build:
name: Release bdist wheels
needs: matrix_preparation
if: github.repository == 'hpcaitech/FastFold' && contains(fromJson('["FrankLeeeee", "feifeibear", "Shenggan", "Gy-Lu"]'), github.actor)
runs-on: [self-hosted, gpu]
strategy:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Copy scripts and checkout
run: |
cp -r ./.github/workflows/scripts/build* ./
ln -s /github/home/pip_wheels ./pip_wheels
git checkout $git_ref
env:
git_ref: ${{ github.event.inputs.github_ref }}
- name: Build bdist wheel
run: |
pip install beautifulsoup4 requests packaging
python ./build_fastfold_wheel.py --torch_version $TORCH_VERSIONS
env:
TORCH_VERSIONS: ${{ inputs.torch_version }}
- name: 🚀 Deploy
uses: garygrossgarten/github-action-scp@release
with:
local: all_dist
remote: ${{ secrets.PRIVATE_PYPI_DIR }}
host: ${{ secrets.PRIVATE_PYPI_HOST }}
username: ${{ secrets.PRIVATE_PYPI_USER }}
password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
import requests
from bs4 import BeautifulSoup
import argparse
import os
import subprocess
from packaging import version
from functools import cmp_to_key
WHEEL_TEXT_ROOT_URL = 'https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels'
RAW_TEXT_FILE_PREFIX = 'https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/torch_build/torch_wheels'
CUDA_HOME = os.environ['CUDA_HOME']
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--torch_version', type=str)
return parser.parse_args()
def get_cuda_bare_metal_version():
raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return bare_metal_major, bare_metal_minor
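# Example: nvcc -V output containing "release 11.3, V11.3.109" makes this return ('11', '3').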
def all_wheel_info():
page_text = requests.get(WHEEL_TEXT_ROOT_URL).text
soup = BeautifulSoup(page_text, 'html.parser')
all_a_links = soup.find_all('a')
wheel_info = dict()
for a_link in all_a_links:
if 'cuda' in a_link.text and '.txt' in a_link.text:
filename = a_link.text
torch_version, cuda_version = filename.rstrip('.txt').split('-')
cuda_version = cuda_version.lstrip('cuda')
if float(cuda_version) < 11.1:
continue
if torch_version not in wheel_info:
wheel_info[torch_version] = dict()
wheel_info[torch_version][cuda_version] = dict()
file_text = requests.get(f'{RAW_TEXT_FILE_PREFIX}/{filename}').text
lines = file_text.strip().split('\n')
for line in lines:
parts = line.split('\t')
method, url, python_version = parts[:3]
if version.parse(python_version) < version.parse("3.8") or method == "conda":
continue
wheel_info[torch_version][cuda_version][python_version] = dict(url=url)
return wheel_info
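# wheel_info is nested as {torch_version: {cuda_version: {python_version: {'url': ...}}}},
# e.g. wheel_info['1.12.1']['11.3']['3.8']['url'] gives the matching torch wheel URL.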
def build_fastfold(wheel_info):
cuda_version_major, cuda_version_minor = get_cuda_bare_metal_version()
cuda_version_on_host = f'{cuda_version_major}.{cuda_version_minor}'
for torch_version, cuda_versioned_wheel_info in wheel_info.items():
for cuda_version, python_versioned_wheel_info in cuda_versioned_wheel_info.items():
if cuda_version_on_host == cuda_version:
for python_version, per_python_wheel_info in python_versioned_wheel_info.items():
url = per_python_wheel_info['url']
filename = url.split('/')[-1].replace('%2B', '+')
cmd = f'bash ./build_fastfold_wheel.sh {url} {filename} {cuda_version} {python_version}'
os.system(cmd)
def main():
args = parse_args()
wheel_info = all_wheel_info()
# filter wheels on condition
all_torch_versions = list(wheel_info.keys())
def _compare_version(a, b):
if version.parse(a) > version.parse(b):
return 1
else:
return -1
all_torch_versions.sort(key=cmp_to_key(_compare_version))
if args.torch_version != 'all':
torch_versions = args.torch_version.split(',')
# only keep the torch versions specified
for key in all_torch_versions:
if key not in torch_versions:
wheel_info.pop(key)
build_fastfold(wheel_info)
if __name__ == '__main__':
main()
#!/usr/bin/env bash
url=${1}
filename=${2}
cuda_version=${3}
python_version=${4}
git reset --hard HEAD
mkdir -p ./all_dist
source activate base
conda create -n $python_version -y python=$python_version
source activate $python_version
wget -nc -q -O ./$filename $url
pip install ./$filename
pip install numpy
python setup.py bdist_wheel
mv ./dist/* ./all_dist
python setup.py clean
conda deactivate
conda env remove -n $python_version
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# vscode
.vscode/
# setup
dist/
build/
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2021- HPC-AI Technology Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
## Some of FastFold's code is derived from other projects, which are subject to the following copyright notices:
---------------- LICENSE FOR NVIDIA Apex ----------------
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
---------------- LICENSE FOR Aqlaboratory Openfold ----------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
![](/assets/fold.jpg)
# FastFold
[![](https://img.shields.io/badge/Paper-PDF-green?style=flat&logo=arXiv&logoColor=green)](https://arxiv.org/abs/2203.00854)
![](https://img.shields.io/badge/Made%20with-ColossalAI-blueviolet?style=flat)
![](https://img.shields.io/badge/Habana-support-blue?style=flat&logo=intel&logoColor=blue)
![](https://img.shields.io/github/v/release/hpcaitech/FastFold)
[![GitHub license](https://img.shields.io/github/license/hpcaitech/FastFold)](https://github.com/hpcaitech/FastFold/blob/main/LICENSE)
## News :triangular_flag_on_post:
- [2023/01] Compatible with AlphaFold v2.3
- [2023/01] Added support for inference and training of AlphaFold on the [Intel Habana](https://habana.ai/) platform. For usage instructions, see [here](#Inference-or-Training-on-Intel-Habana).
<br>
Optimizing Protein Structure Prediction Model Training and Inference on Heterogeneous Clusters
FastFold provides a **high-performance implementation of Evoformer** with the following characteristics.
1. Excellent kernel performance on GPU platforms
2. Support for Dynamic Axial Parallelism (DAP)
* Breaks the memory limit of a single GPU and reduces the overall training time
* DAP can significantly speed up inference and make ultra-long-sequence inference possible
3. Ease of use
* Huge performance gains with only a few lines of code changed
* You don't need to care about how the parallel part is implemented
4. Faster data processing: about 3x faster on monomers and about 3Nx faster on multimers with N sequences
5. Greatly reduced GPU memory usage, enabling inference on sequences with more than **10000** residues
## Installation
To install FastFold, you will need:
+ Python 3.8 or 3.9
+ [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.3 or above
+ PyTorch 1.12 or above
You can install FastFold as follows:
### Using Conda (Recommended)
We highly recommend installing an Anaconda or Miniconda environment and installing PyTorch with conda.
The commands below create a new conda environment called "fastfold":
```shell
git clone https://github.com/hpcaitech/FastFold
cd FastFold
conda env create --name=fastfold -f environment.yml
conda activate fastfold
python setup.py install
```
#### Advanced
To fully leverage FastFold, we recommend installing [Triton](https://github.com/openai/triton).
**NOTE: Triton needs CUDA 11.4 to run.**
```bash
pip install -U --pre triton
```
## Use Docker
### Build On Your Own
Run the following command to build a docker image from the Dockerfile provided.
> Building FastFold from scratch requires GPU support; you need to set the NVIDIA Docker runtime as the default when running `docker build`. More details can be found [here](https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime).
```shell
cd FastFold
docker build -t fastfold ./docker
```
Run the following command to start the docker container in interactive mode.
```shell
docker run -ti --gpus all --rm --ipc=host fastfold bash
```
## Usage
You can use `Evoformer` as an `nn.Module` in your project after importing it with `from fastfold.model.fastnn import Evoformer`:
```python
from fastfold.model.fastnn import Evoformer
evoformer_layer = Evoformer()
```
If you want to use Dynamic Axial Parallelism, add a line to initialize it with `fastfold.distributed.init_dap`.
```python
from fastfold.distributed import init_dap
init_dap(args.dap_size)
```
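For a fuller picture, here is a minimal sketch of one `Evoformer` layer running under DAP, adapted from `./benchmark/perf.py`; the dimensions, dtype, and `dap_size=2` are illustrative, and it assumes launch via `torchrun --nproc_per_node=2` on CUDA devices:
```python
import torch
from fastfold.distributed import init_dap
from fastfold.model.fastnn import Evoformer

init_dap(2)  # dap_size = 2: each rank holds one shard of the MSA/pair representations
layer = Evoformer(d_node=256, d_pair=128).cuda().to(dtype=torch.float16)  # fp16, since All2All does not support bf16

# Node (MSA) and pair tensors are sharded along their second dimension by dap_size,
# while the masks keep their full size (shapes follow ./benchmark/perf.py).
node = torch.randn(1, 128 // 2, 256, 256, dtype=torch.float16, device="cuda")
pair = torch.randn(1, 256 // 2, 256, 128, dtype=torch.float16, device="cuda")
node_mask = torch.ones(1, 128, 256, dtype=torch.float16, device="cuda")
pair_mask = torch.ones(1, 256, 256, dtype=torch.float16, device="cuda")

node, pair = layer(node, pair, node_mask, pair_mask)
```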
### Download the dataset
You can download the dataset used to train FastFold with the script `download_all_data.sh`:
```shell
./scripts/download_all_data.sh data/
```
### Inference
You can use FastFold with `inject_fastnn`. This replaces the Evoformer from OpenFold with the high-performance Evoformer from FastFold.
```python
from fastfold.config import model_config
from fastfold.model.hub import AlphaFold
from fastfold.utils import inject_fastnn
config = model_config(args.model_name)
model = AlphaFold(config)
import_jax_weights_(model, args.param_path, version=args.model_name)  # import_jax_weights_ loads the official AlphaFold parameters (weight-import utility, as in OpenFold)
model = inject_fastnn(model)
```
For Dynamic Axial Parallelism, you can refer to `./inference.py`. Here is an example of parallel inference on 2 GPUs:
```shell
python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
--output_dir .outputs/ \
--gpus 2 \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path `which jackhmmer` \
--hhblits_binary_path `which hhblits` \
--hhsearch_binary_path `which hhsearch` \
--kalign_binary_path `which kalign` \
--enable_workflow \
--inplace
```
Alternatively, run the script `./inference.sh`; you can change the parameters in the script, especially the data paths.
```shell
./inference.sh
```
AlphaFold's data pre-processing takes a lot of time, so we speed it up with a [ray](https://docs.ray.io/en/latest/workflows/concepts.html) workflow, which is about 3x faster. To run inference with the ray workflow, add the parameter `--enable_workflow`, included by default in the examples.
To reduce the memory usage of embedding representations, we also add the parameter `--inplace` to share memory, likewise included by default.
#### Inference with lower memory usage
AlphaFold's embedding representations take up a lot of memory as the sequence length increases. To reduce memory usage,
add the parameter `--chunk_size [N]` to the command line or to the shell script `./inference.sh`.
The smaller you set N, the less memory is used, but inference becomes slower. We can run inference on
a sequence of length 10000 in bf16 with 61GB of memory on an NVIDIA A100 (80GB). For fp32, the maximum length is 8000.
> You need to set `PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:15000` to run inference on such extremely long sequences.
```shell
python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
--output_dir .outputs/ \
--gpus 2 \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path `which jackhmmer` \
--hhblits_binary_path `which hhblits` \
--hhsearch_binary_path `which hhsearch` \
--kalign_binary_path `which kalign` \
--enable_workflow \
--inplace \
--chunk_size N
```
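If you drive the model from Python rather than the shell scripts, the same chunking and in-place options are exposed through the model config and `set_chunk_size`, as in `inference.py`; here is a minimal sketch with an illustrative chunk size:
```python
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn

config = model_config("model_1")
config.globals.chunk_size = 64   # smaller chunks use less GPU memory but run slower
config.globals.inplace = True    # share memory between embedding representations

model = inject_fastnn(AlphaFold(config))
model = model.eval().cuda()
set_chunk_size(model.globals.chunk_size)
```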
#### Inference with multimer sequences
AlphaFold-Multimer is supported. You can use the following command or the shell script `./inference_multimer.sh`.
The workflow and memory parameters mentioned above can also be used.
```shell
python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
--output_dir ./ \
--gpus 2 \
--model_preset multimer \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--uniprot_database_path data/uniprot/uniprot.fasta \
--pdb_seqres_database_path data/pdb_seqres/pdb_seqres.txt \
--param_path data/params/params_model_1_multimer.npz \
--model_name model_1_multimer \
--jackhmmer_binary_path `which jackhmmer` \
--hhblits_binary_path `which hhblits` \
--hhsearch_binary_path `which hhsearch` \
--kalign_binary_path `which kalign`
```
### Inference or Training on Intel Habana
To run AlphaFold inference or training on Intel Habana, follow the instructions in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/) to set up your environment on Amazon EC2 DL1 instances or on-premise machines. Please test with SynapseAI R1.7.1, as that is the version verified internally.
Once you have prepared your dataset and installed FastFold, you can use the following scripts:
```shell
cd fastfold/habana/fastnn/custom_op/ && python setup.py build && cd -  # for Gaudi; for Gaudi2, use setup2.py instead
bash habana/inference.sh
bash habana/train.sh
```
## Performance Benchmark
We have included a performance benchmark script in `./benchmark`. You can benchmark the performance of Evoformer using different settings.
```shell
cd ./benchmark
torchrun --nproc_per_node=1 perf.py --msa-length 128 --res-length 256
```
Benchmark Dynamic Axial Parallelism with 2 GPUs:
```shell
cd ./benchmark
torchrun --nproc_per_node=2 perf.py --msa-length 128 --res-length 256 --dap-size 2
```
If you want to benchmark with [OpenFold](https://github.com/aqlaboratory/openfold), you need to install OpenFold first and benchmark with option `--openfold`:
```shell
torchrun --nproc_per_node=1 perf.py --msa-length 128 --res-length 256 --openfold
```
## Cite us
Please cite this paper if you use FastFold in your research publications.
```
@misc{cheng2022fastfold,
title={FastFold: Reducing AlphaFold Training Time from 11 Days to 67 Hours},
author={Shenggan Cheng and Ruidong Wu and Zhongming Yu and Binrui Li and Xiwen Zhang and Jian Peng and Yang You},
year={2022},
eprint={2203.00854},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
## Acknowledgments
We would like to extend our special thanks to the Intel Habana team for their support in providing us with technology and resources on the Habana platform.
import argparse
import os
import torch
import torch.nn as nn
from fastfold.distributed import init_dap
from fastfold.model.fastnn import Evoformer
def main():
parser = argparse.ArgumentParser(description='Evoformer Standalone Perf Benchmark')
parser.add_argument("--dap-size", default=1, type=int, help='batch size')
parser.add_argument('--batch-size', default=1, type=int, help='batch size')
parser.add_argument('--msa-length', default=132, type=int, help='Sequence Length of MSA')
parser.add_argument('--res-length',
default=256,
type=int,
help='Sequence Length of Residues')
parser.add_argument('--trials', default=50, type=int, help='Number of Trials to Execute')
parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
parser.add_argument('--layers',
default=12,
type=int,
help='Evoformer Layers to Execute')
parser.add_argument('--cm', default=256, type=int, help='MSA hidden dimension')
parser.add_argument('--cz', default=128, type=int, help='Pair hidden dimension')
parser.add_argument('--heads', default=8, type=int, help='Number of Multihead Attention heads')
parser.add_argument('--openfold',
action='store_true',
help='Benchmark with Evoformer Implementation from OpenFold.')
parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
parser.add_argument('--prof', action='store_true', help='run with profiler.')
args = parser.parse_args()
init_dap(args.dap_size)
precision = torch.bfloat16
if args.dap_size > 1:
# (PyTorch issue) Currently All2All communication does not support the Bfloat16 datatype in PyTorch
precision = torch.float16
if not torch.cuda.is_available():
raise NotImplementedError('Running on CPU is not supported')
torch.manual_seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(42)
if args.openfold:
from openfold.model.evoformer import EvoformerBlock
class OpenFoldEvoformer(nn.Module):
def __init__(self, d_node, d_pair):
super(OpenFoldEvoformer, self).__init__()
self.d_node = d_node
self.d_pair = d_pair
self.c_hidden_msa_att = int(d_node / 8)
self.c_hidden_pair_att = int(d_pair / 8)
self.EvoformerBlock = EvoformerBlock(c_m=d_node,
c_z=d_pair,
c_hidden_msa_att=self.c_hidden_msa_att,
c_hidden_opm=self.c_hidden_msa_att,
c_hidden_mul=self.d_pair,
c_hidden_pair_att=self.c_hidden_pair_att,
no_heads_msa=8,
no_heads_pair=4,
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.25,
inf=1e9,
eps=1e-10)
def forward(self, node, pair, node_mask, pair_mask):
node, pair = self.EvoformerBlock(node, pair, node_mask, pair_mask)
return node, pair
attn_layers = []
for idx in range(0, args.layers):
if args.openfold:
attn_layers.append(OpenFoldEvoformer(d_node=args.cm, d_pair=args.cz))
else:
attn_layers.append(Evoformer(d_node=args.cm, d_pair=args.cz))
attn_layers[idx].cuda()
attn_layers[idx].to(dtype=precision)
start_evt_fwd = []
start_evt_bwd = []
stop_evt_bwd = []
for recorded_trial in range(0, args.trials):
start_evt_fwd.append(torch.cuda.Event(enable_timing=True))
start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))
inputs_node = torch.randn(args.batch_size,
args.msa_length // args.dap_size,
args.res_length,
args.cm,
dtype=precision,
device=torch.device("cuda")).requires_grad_(True)
inputs_pair = torch.randn(args.batch_size,
args.res_length // args.dap_size,
args.res_length,
args.cz,
dtype=precision,
device=torch.device("cuda")).requires_grad_(True)
node_mask = torch.ones((args.batch_size, args.msa_length, args.res_length),
dtype=precision,
device=torch.device("cuda")).requires_grad_(False)
pair_mask = torch.ones((args.batch_size, args.res_length, args.res_length),
dtype=precision,
device=torch.device("cuda")).requires_grad_(False)
grads_pair = torch.randn_like(inputs_pair)  # gradient for the pair output, used in the backward pass
if args.prof:
prof = torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1,
warmup=args.warmup_trials,
active=args.trials,
repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/fastfold'),
profile_memory=False,
record_shapes=False,
with_stack=False)
prof.start()
for trial in range(0, args.trials + args.warmup_trials):
layer_inputs = inputs_node, inputs_pair
evt_idx = trial - args.warmup_trials
torch.distributed.barrier()
torch.cuda.synchronize()
if evt_idx >= 0:
start_evt_fwd[evt_idx].record()
for lyr_idx in range(0, args.layers):
layer_inputs = attn_layers[lyr_idx].forward(*layer_inputs, node_mask, pair_mask)
torch.cuda.synchronize()
if evt_idx >= 0:
start_evt_bwd[evt_idx].record()
if not args.fwd:
layer_inputs[1].backward(grads_pair)
if evt_idx >= 0:
stop_evt_bwd[evt_idx].record()
if args.prof:
prof.step()
if args.prof:
prof.stop()
torch.distributed.barrier()
torch.cuda.synchronize()
elapsed_time_fwd = 0.0
elapsed_time_bwd = 0.0
for evt_idx in range(0, args.trials):
elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])
elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])
print("[ MSA Attn ] Input: {:4d}, {:4d}, {:4d}, ({:4d} {:4d}) Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
args.batch_size, args.msa_length, args.res_length, \
args.cm, args.cz, \
elapsed_time_fwd / ( args.trials * args.layers ), \
elapsed_time_bwd / ( args.trials * args.layers )))
if __name__ == '__main__':
main()
# Copyright 2023 HPC-AI Tech Inc.
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import fastfold
import numpy as np
import torch
import torch.multiprocessing as mp
from fastfold.config import model_config
from fastfold.data import data_transforms
from fastfold.model.fastnn import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.tensor_utils import tensor_tree_map
torch_major, torch_minor = (int(v) for v in torch.__version__.split(".")[:2])
if torch_major > 1 or (torch_major == 1 and torch_minor > 11):
torch.backends.cuda.matmul.allow_tf32 = True
def random_template_feats(n_templ, n):
b = []
batch = {
"template_mask": np.random.randint(0, 2, (*b, n_templ)),
"template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)),
"template_pseudo_beta": np.random.rand(*b, n_templ, n, 3),
"template_aatype": np.random.randint(0, 22, (*b, n_templ, n)),
"template_all_atom_mask": np.random.randint(0, 2, (*b, n_templ, n, 37)),
"template_all_atom_positions": np.random.rand(*b, n_templ, n, 37, 3) * 10,
"template_torsion_angles_sin_cos": np.random.rand(*b, n_templ, n, 7, 2),
"template_alt_torsion_angles_sin_cos": np.random.rand(*b, n_templ, n, 7, 2),
"template_torsion_angles_mask": np.random.rand(*b, n_templ, n, 7),
}
batch = {k: v.astype(np.float32) for k, v in batch.items()}
batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
return batch
def random_extra_msa_feats(n_extra, n):
b = []
batch = {
"extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(np.int64),
"extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
"extra_deletion_value": np.random.rand(*b, n_extra, n).astype(np.float32),
"extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
}
return batch
def generate_batch(n_res):
batch = {}
tf = torch.randint(21, size=(n_res,))
batch["target_feat"] = torch.nn.functional.one_hot(tf, 22).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((128, n_res, 49))
t_feats = random_template_feats(4, n_res)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
extra_feats = random_extra_msa_feats(5120, n_res)
batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
batch["msa_mask"] = torch.randint(low=0, high=2, size=(128, n_res)).float()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
batch.update(data_transforms.make_atom14_masks(batch))
batch["no_recycling_iters"] = torch.tensor(2.)
add_recycling_dims = lambda t: (t.unsqueeze(-1).expand(*t.shape, 3))
batch = tensor_tree_map(add_recycling_dims, batch)
return batch
def inference_model(rank, world_size, result_q, batch, args):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
torch.cuda.set_device(rank)
config = model_config(args.model_name)
if args.chunk_size:
config.globals.chunk_size = args.chunk_size
config.globals.inplace = args.inplace
config.globals.is_multimer = False
model = AlphaFold(config)
model = inject_fastnn(model)
model = model.eval()
model = model.cuda()
set_chunk_size(model.globals.chunk_size)
with torch.no_grad():
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
t = time.perf_counter()
out = model(batch)
print(f"Inference time: {time.perf_counter() - t}")
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
result_q.put(out)
torch.distributed.barrier()
torch.cuda.synchronize()
def inference_monomer_model(args):
batch = generate_batch(args.n_res)
manager = mp.Manager()
result_q = manager.Queue()
torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
out = result_q.get()
# get unrelaxed pdb and save
# batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
# plddt = out["plddt"]
# plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
# unrelaxed_protein = protein.from_prediction(features=batch,
# result=out,
# b_factors=plddt_b_factors)
# with open('demo_unrelex.pdb', 'w+') as fp:
# fp.write(unrelaxed_protein)
def main(args):
inference_monomer_model(args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=1, help="""Number of GPUs with which to run inference""")
parser.add_argument("--n_res", type=int, default=50, help="virtual residue number of random data")
parser.add_argument("--model_name", type=str, default="model_1", help="model name of alphafold")
parser.add_argument('--chunk_size', type=int, default=None)
parser.add_argument('--inplace', default=False, action='store_true')
args = parser.parse_args()
main(args)
FROM hpcaitech/pytorch-cuda:1.12.0-11.3.0
RUN conda install openmm=7.7.0 pdbfixer -c conda-forge -y \
&& conda install hmmer==3.3.2 hhsuite=3.3.0 kalign2=2.04 -c bioconda -y
RUN pip install biopython==1.79 dm-tree==0.1.6 ml-collections==0.1.0 \
scipy==1.7.1 ray pyarrow pandas einops
RUN pip install colossalai
RUN git clone https://github.com/hpcaitech/FastFold.git \
&& cd ./FastFold \
&& python setup.py install
name: fastfold
channels:
- conda-forge
- bioconda
- pytorch
dependencies:
- pip:
- biopython==1.79
- dm-tree==0.1.6
- ml-collections==0.1.0
- PyYAML==5.4.1
- requests==2.26.0
- scipy==1.7.1
- tqdm==4.62.2
- typing-extensions==4.3.0
- einops
- ray==2.0.0
- pyarrow
- pandas
- colossalai==0.2.7
- pytorch::pytorch=1.12
- pytorch::torchvision
- pytorch::torchaudio
- conda-forge::cudatoolkit=11.3
- conda-forge::python=3.8
- conda-forge::setuptools=59.5.0
- conda-forge::pip
- conda-forge::openmm=7.7.0
- conda-forge::pdbfixer
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
VERSION = "0.1.0-beta"
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional
import re
from fastfold.common import residue_constants
from Bio.PDB import PDBParser
import numpy as np
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
PICO_TO_ANGSTROM = 0.01
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)
assert(PDB_MAX_CHAINS == 62)
@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# 0-indexed number corresponding to the chain in the protein that this
# residue belongs to
chain_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
def __post_init__(self):
if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
raise ValueError(
f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
"chains because these cannot be written to PDB format"
)
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.
WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored.
Args:
pdb_str: The contents of the pdb file
chain_id: If chain_id is specified (e.g. A), then only that chain is
parsed. Else, all chains are parsed.
Returns:
A new `Protein` parsed from the pdb contents.
"""
pdb_fh = io.StringIO(pdb_str)
parser = PDBParser(QUIET=True)
structure = parser.get_structure("none", pdb_fh)
models = list(structure.get_models())
if len(models) != 1:
raise ValueError(
f"Only single model PDBs are supported. Found {len(models)} models."
)
model = models[0]
atom_positions = []
aatype = []
atom_mask = []
residue_index = []
chain_ids = []
b_factors = []
for chain in model:
if(chain_id is not None and chain.id != chain_id):
continue
for res in chain:
if res.id[2] != " ":
raise ValueError(
f"PDB contains an insertion code at chain {chain.id} and residue "
f"index {res.id[1]}. These are not supported."
)
res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num
)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.0
res_b_factors[
residue_constants.atom_order[atom.name]
] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
chain_ids.append(chain.id)
b_factors.append(res_b_factors)
# Chain IDs are usually characters so map these to ints
unique_chain_ids = np.unique(chain_ids)
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
chain_index=chain_index,
b_factors=np.array(b_factors),
)
def from_proteinnet_string(proteinnet_str: str) -> Protein:
tag_re = r'(\[[A-Z]+\]\n)'
tags = [
tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
]
groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])
atoms = ['N', 'CA', 'C']
aatype = None
atom_positions = None
atom_mask = None
for g in groups:
if("[PRIMARY]" == g[0]):
seq = list(g[1][0].strip())  # list, so non-standard residues can be replaced below
for i in range(len(seq)):
if(seq[i] not in residue_constants.restypes):
seq[i] = 'X'
aatype = np.array([
residue_constants.restype_order.get(
res_symbol, residue_constants.restype_num
) for res_symbol in seq
])
elif("[TERTIARY]" == g[0]):
tertiary = []
for axis in range(3):
tertiary.append(list(map(float, g[1][axis].split())))
tertiary_np = np.array(tertiary)
atom_positions = np.zeros(
(len(tertiary[0])//3, residue_constants.atom_type_num, 3)
).astype(np.float32)
for i, atom in enumerate(atoms):
atom_positions[:, residue_constants.atom_order[atom], :] = (
np.transpose(tertiary_np[:, i::3])
)
atom_positions *= PICO_TO_ANGSTROM
elif("[MASK]" == g[0]):
mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
atom_mask = np.zeros(
(len(mask), residue_constants.atom_type_num,)
).astype(np.float32)
for i, atom in enumerate(atoms):
atom_mask[:, residue_constants.atom_order[atom]] = 1
atom_mask *= mask[..., None]
return Protein(
atom_positions=atom_positions,
atom_mask=atom_mask,
aatype=aatype,
residue_index=np.arange(len(aatype)),
b_factors=None,
)
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
chain_end = 'TER'
return(
        f'{chain_end:<6}{atom_index:>5}      {end_resname:>3} '
f'{chain_name:>1}{residue_index:>4}'
)
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
restypes = residue_constants.restypes + ["X"]
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
atom_types = residue_constants.atom_types
pdb_lines = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
chain_index = prot.chain_index.astype(np.int32)
b_factors = prot.b_factors
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
# Construct a mapping from chain integer indices to chain ID strings.
chain_ids = {}
for i in np.unique(chain_index): # np.unique gives sorted output.
if i >= PDB_MAX_CHAINS:
raise ValueError(
f"The PDB format supports at most {PDB_MAX_CHAINS} chains."
)
chain_ids[i] = PDB_CHAIN_IDS[i]
pdb_lines.append("MODEL 1")
atom_index = 1
last_chain_index = chain_index[0]
# Add all atom sites.
for i in range(aatype.shape[0]):
# Close the previous chain if in a multichain PDB.
if last_chain_index != chain_index[i]:
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[i - 1]),
chain_ids[chain_index[i - 1]],
residue_index[i - 1]
)
)
last_chain_index = chain_index[i]
atom_index += 1 # Atom index increases at the TER symbol.
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]
):
if mask < 0.5:
continue
record_type = "ATOM"
name = atom_name if len(atom_name) == 4 else f" {atom_name}"
alt_loc = ""
insertion_code = ""
occupancy = 1.00
element = atom_name[
0
] # Protein supports only C, N, O, S, this works.
charge = ""
# PDB is a columnar format, every space matters here!
atom_line = (
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
f"{occupancy:>6.2f}{b_factor:>6.2f} "
f"{element:>2}{charge:>2}"
)
pdb_lines.append(atom_line)
atom_index += 1
# Close the final chain.
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[-1]),
chain_ids[chain_index[-1]],
residue_index[-1]
)
)
pdb_lines.append("ENDMDL")
pdb_lines.append("END")
# Pad all lines to 80 characters
pdb_lines = [line.ljust(80) for line in pdb_lines]
return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
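
# Column layout reference for the ATOM records emitted above (PDB format v3.3,
# 1-indexed columns): record name 1-6, serial 7-11, atom name 13-16, altLoc 17,
# resName 18-20, chainID 22, resSeq 23-26, iCode 27, x/y/z 31-54,
# occupancy 55-60, tempFactor 61-66, element 77-78, charge 79-80. This is why
# the space runs inside the f-strings above must be preserved exactly.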
def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
remove_leading_feature_dimension: bool = False,
) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values
Returns:
A protein instance.
"""
def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
return arr[0] if remove_leading_feature_dimension else arr
if 'asym_id' in features:
chain_index = _maybe_remove_leading_dim(features["asym_id"])
else:
chain_index = np.zeros_like(
_maybe_remove_leading_dim(features["aatype"])
)
if b_factors is None:
b_factors = np.zeros_like(result["final_atom_mask"])
return Protein(
aatype=_maybe_remove_leading_dim(features["aatype"]),
atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"],
residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
chain_index=chain_index,
b_factors=b_factors,
)
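
# Illustrative usage sketch (not part of the original module): serializing a
# prediction. `features` and `result` are assumed to already be unbatched
# numpy dicts; if they carry a leading batch dimension, pass
# remove_leading_feature_dimension=True to from_prediction instead.
def _example_prediction_to_pdb(features: FeatureDict, result: ModelOutput) -> str:
    """Assemble a Protein from model outputs and write it as PDB (sketch)."""
    prot = from_prediction(features, result)
    return to_pdb(prot)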
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants used in AlphaFold."""
import os
import urllib.request
import collections
import functools
from typing import Mapping, List, Tuple
import numpy as np
import tree
# Internal import (35fd).
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096
# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
chi_angles_atoms = {
"ALA": [],
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
"ARG": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "NE"],
["CG", "CD", "NE", "CZ"],
],
"ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"CYS": [["N", "CA", "CB", "SG"]],
"GLN": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "OE1"],
],
"GLU": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "OE1"],
],
"GLY": [],
"HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
"ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
"LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"LYS": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "CE"],
["CG", "CD", "CE", "NZ"],
],
"MET": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "SD"],
["CB", "CG", "SD", "CE"],
],
"PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
"SER": [["N", "CA", "CB", "OG"]],
"THR": [["N", "CA", "CB", "OG1"]],
"TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"VAL": [["N", "CA", "CB", "CG1"]],
}
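# For example, ARG above lists four dihedrals (chi1-chi4), which corresponds
# to the row [1.0, 1.0, 1.0, 1.0] for ARG in chi_angles_mask below, while ALA
# and GLY list none, so their rows are all zeros.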
# If chi angles given in fixed-length array, this matrix determines how to mask
# them for each AA type. The order is as per restype_order (see below).
chi_angles_mask = [
[0.0, 0.0, 0.0, 0.0], # ALA
[1.0, 1.0, 1.0, 1.0], # ARG
[1.0, 1.0, 0.0, 0.0], # ASN
[1.0, 1.0, 0.0, 0.0], # ASP
[1.0, 0.0, 0.0, 0.0], # CYS
[1.0, 1.0, 1.0, 0.0], # GLN
[1.0, 1.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[1.0, 1.0, 0.0, 0.0], # HIS
[1.0, 1.0, 0.0, 0.0], # ILE
[1.0, 1.0, 0.0, 0.0], # LEU
[1.0, 1.0, 1.0, 1.0], # LYS
[1.0, 1.0, 1.0, 0.0], # MET
[1.0, 1.0, 0.0, 0.0], # PHE
[1.0, 1.0, 0.0, 0.0], # PRO
[1.0, 0.0, 0.0, 0.0], # SER
[1.0, 0.0, 0.0, 0.0], # THR
[1.0, 1.0, 0.0, 0.0], # TRP
[1.0, 1.0, 0.0, 0.0], # TYR
[1.0, 0.0, 0.0, 0.0], # VAL
]
# The following chi angles are pi periodic: they can be rotated by a multiple
# of pi without affecting the structure.
chi_pi_periodic = [
[0.0, 0.0, 0.0, 0.0], # ALA
[0.0, 0.0, 0.0, 0.0], # ARG
[0.0, 0.0, 0.0, 0.0], # ASN
[0.0, 1.0, 0.0, 0.0], # ASP
[0.0, 0.0, 0.0, 0.0], # CYS
[0.0, 0.0, 0.0, 0.0], # GLN
[0.0, 0.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[0.0, 0.0, 0.0, 0.0], # HIS
[0.0, 0.0, 0.0, 0.0], # ILE
[0.0, 0.0, 0.0, 0.0], # LEU
[0.0, 0.0, 0.0, 0.0], # LYS
[0.0, 0.0, 0.0, 0.0], # MET
[0.0, 1.0, 0.0, 0.0], # PHE
[0.0, 0.0, 0.0, 0.0], # PRO
[0.0, 0.0, 0.0, 0.0], # SER
[0.0, 0.0, 0.0, 0.0], # THR
[0.0, 0.0, 0.0, 0.0], # TRP
[0.0, 1.0, 0.0, 0.0], # TYR
[0.0, 0.0, 0.0, 0.0], # VAL
[0.0, 0.0, 0.0, 0.0], # UNK
]
# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
# psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
# is defined such that the dihedral-angle-defining atom (the last entry in
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions = {
"ALA": [
["N", 0, (-0.525, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.529, -0.774, -1.205)],
["O", 3, (0.627, 1.062, 0.000)],
],
"ARG": [
["N", 0, (-0.524, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.524, -0.778, -1.209)],
["O", 3, (0.626, 1.062, 0.000)],
["CG", 4, (0.616, 1.390, -0.000)],
["CD", 5, (0.564, 1.414, 0.000)],
["NE", 6, (0.539, 1.357, -0.000)],
["NH1", 7, (0.206, 2.301, 0.000)],
["NH2", 7, (2.078, 0.978, -0.000)],
["CZ", 7, (0.758, 1.093, -0.000)],
],
"ASN": [
["N", 0, (-0.536, 1.357, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.531, -0.787, -1.200)],
["O", 3, (0.625, 1.062, 0.000)],
["CG", 4, (0.584, 1.399, 0.000)],
["ND2", 5, (0.593, -1.188, 0.001)],
["OD1", 5, (0.633, 1.059, 0.000)],
],
"ASP": [
["N", 0, (-0.525, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, 0.000, -0.000)],
["CB", 0, (-0.526, -0.778, -1.208)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.593, 1.398, -0.000)],
["OD1", 5, (0.610, 1.091, 0.000)],
["OD2", 5, (0.592, -1.101, -0.003)],
],
"CYS": [
["N", 0, (-0.522, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, 0.000, 0.000)],
["CB", 0, (-0.519, -0.773, -1.212)],
["O", 3, (0.625, 1.062, -0.000)],
["SG", 4, (0.728, 1.653, 0.000)],
],
"GLN": [
["N", 0, (-0.526, 1.361, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, 0.000)],
["CB", 0, (-0.525, -0.779, -1.207)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.615, 1.393, 0.000)],
["CD", 5, (0.587, 1.399, -0.000)],
["NE2", 6, (0.593, -1.189, -0.001)],
["OE1", 6, (0.634, 1.060, 0.000)],
],
"GLU": [
["N", 0, (-0.528, 1.361, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.526, -0.781, -1.207)],
["O", 3, (0.626, 1.062, 0.000)],
["CG", 4, (0.615, 1.392, 0.000)],
["CD", 5, (0.600, 1.397, 0.000)],
["OE1", 6, (0.607, 1.095, -0.000)],
["OE2", 6, (0.589, -1.104, -0.001)],
],
"GLY": [
["N", 0, (-0.572, 1.337, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.517, -0.000, -0.000)],
["O", 3, (0.626, 1.062, -0.000)],
],
"HIS": [
["N", 0, (-0.527, 1.360, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, 0.000, 0.000)],
["CB", 0, (-0.525, -0.778, -1.208)],
["O", 3, (0.625, 1.063, 0.000)],
["CG", 4, (0.600, 1.370, -0.000)],
["CD2", 5, (0.889, -1.021, 0.003)],
["ND1", 5, (0.744, 1.160, -0.000)],
["CE1", 5, (2.030, 0.851, 0.002)],
["NE2", 5, (2.145, -0.466, 0.004)],
],
"ILE": [
["N", 0, (-0.493, 1.373, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, -0.000)],
["CB", 0, (-0.536, -0.793, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG1", 4, (0.534, 1.437, -0.000)],
["CG2", 4, (0.540, -0.785, -1.199)],
["CD1", 5, (0.619, 1.391, 0.000)],
],
"LEU": [
["N", 0, (-0.520, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.522, -0.773, -1.214)],
["O", 3, (0.625, 1.063, -0.000)],
["CG", 4, (0.678, 1.371, 0.000)],
["CD1", 5, (0.530, 1.430, -0.000)],
["CD2", 5, (0.535, -0.774, 1.200)],
],
"LYS": [
["N", 0, (-0.526, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, 0.000)],
["CB", 0, (-0.524, -0.778, -1.208)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.619, 1.390, 0.000)],
["CD", 5, (0.559, 1.417, 0.000)],
["CE", 6, (0.560, 1.416, 0.000)],
["NZ", 7, (0.554, 1.387, 0.000)],
],
"MET": [
["N", 0, (-0.521, 1.364, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, 0.000, 0.000)],
["CB", 0, (-0.523, -0.776, -1.210)],
["O", 3, (0.625, 1.062, -0.000)],
["CG", 4, (0.613, 1.391, -0.000)],
["SD", 5, (0.703, 1.695, 0.000)],
["CE", 6, (0.320, 1.786, -0.000)],
],
"PHE": [
["N", 0, (-0.518, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, 0.000, -0.000)],
["CB", 0, (-0.525, -0.776, -1.212)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.607, 1.377, 0.000)],
["CD1", 5, (0.709, 1.195, -0.000)],
["CD2", 5, (0.706, -1.196, 0.000)],
["CE1", 5, (2.102, 1.198, -0.000)],
["CE2", 5, (2.098, -1.201, -0.000)],
["CZ", 5, (2.794, -0.003, -0.001)],
],
"PRO": [
["N", 0, (-0.566, 1.351, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, 0.000)],
["CB", 0, (-0.546, -0.611, -1.293)],
["O", 3, (0.621, 1.066, 0.000)],
["CG", 4, (0.382, 1.445, 0.0)],
# ['CD', 5, (0.427, 1.440, 0.0)],
["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
],
"SER": [
["N", 0, (-0.529, 1.360, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.518, -0.777, -1.211)],
["O", 3, (0.626, 1.062, -0.000)],
["OG", 4, (0.503, 1.325, 0.000)],
],
"THR": [
["N", 0, (-0.517, 1.364, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, -0.000)],
["CB", 0, (-0.516, -0.793, -1.215)],
["O", 3, (0.626, 1.062, 0.000)],
["CG2", 4, (0.550, -0.718, -1.228)],
["OG1", 4, (0.472, 1.353, 0.000)],
],
"TRP": [
["N", 0, (-0.521, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, 0.000)],
["CB", 0, (-0.523, -0.776, -1.212)],
["O", 3, (0.627, 1.062, 0.000)],
["CG", 4, (0.609, 1.370, -0.000)],
["CD1", 5, (0.824, 1.091, 0.000)],
["CD2", 5, (0.854, -1.148, -0.005)],
["CE2", 5, (2.186, -0.678, -0.007)],
["CE3", 5, (0.622, -2.530, -0.007)],
["NE1", 5, (2.140, 0.690, -0.004)],
["CH2", 5, (3.028, -2.890, -0.013)],
["CZ2", 5, (3.283, -1.543, -0.011)],
["CZ3", 5, (1.715, -3.389, -0.011)],
],
"TYR": [
["N", 0, (-0.522, 1.362, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, -0.000, -0.000)],
["CB", 0, (-0.522, -0.776, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG", 4, (0.607, 1.382, -0.000)],
["CD1", 5, (0.716, 1.195, -0.000)],
["CD2", 5, (0.713, -1.194, -0.001)],
["CE1", 5, (2.107, 1.200, -0.002)],
["CE2", 5, (2.104, -1.201, -0.003)],
["OH", 5, (4.168, -0.002, -0.005)],
["CZ", 5, (2.791, -0.001, -0.003)],
],
"VAL": [
["N", 0, (-0.494, 1.373, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, -0.000)],
["CB", 0, (-0.533, -0.795, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG1", 4, (0.540, 1.429, -0.000)],
["CG2", 4, (0.533, -0.776, 1.203)],
],
}
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
residue_atoms = {
"ALA": ["C", "CA", "CB", "N", "O"],
"ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
"ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
"ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
"CYS": ["C", "CA", "CB", "N", "O", "SG"],
"GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
"GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
"GLY": ["C", "CA", "N", "O"],
"HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
"ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
"LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
"LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
"MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
"PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
"PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
"SER": ["C", "CA", "CB", "N", "O", "OG"],
"THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
"TRP": [
"C",
"CA",
"CB",
"CG",
"CD1",
"CD2",
"CE2",
"CE3",
"CZ2",
"CZ3",
"CH2",
"N",
"NE1",
"O",
],
"TYR": [
"C",
"CA",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"N",
"O",
"OH",
],
"VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
}
# Naming swaps for ambiguous atom names.
# Due to symmetries in the amino acids the naming of atoms is ambiguous in
# 4 of the 20 amino acids.
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
# TODO: ^ interpret this
residue_atom_renaming_swaps = {
"ASP": {"OD1": "OD2"},
"GLU": {"OE1": "OE2"},
"PHE": {"CD1": "CD2", "CE1": "CE2"},
"TYR": {"CD1": "CD2", "CE1": "CE2"},
}
# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius = {
"C": 1.7,
"N": 1.55,
"O": 1.52,
"S": 1.8,
}
Bond = collections.namedtuple(
"Bond", ["atom1_name", "atom2_name", "length", "stddev"]
)
BondAngle = collections.namedtuple(
"BondAngle",
["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
)
def get_cache_path():
cache_path = os.path.join(os.path.expanduser("~"), '.fastfold')
if not os.path.exists(cache_path):
os.makedirs(cache_path, exist_ok=True)
return cache_path
@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[
Mapping[str, List[Bond]],
Mapping[str, List[Bond]],
Mapping[str, List[BondAngle]],
]:
"""Load stereo_chemical_props.txt into a nice structure.
Load literature values for bond lengths and bond angles and translate
bond angles into the length of the opposite edge of the triangle
("residue_virtual_bonds").
Returns:
residue_bonds: dict that maps resname --> list of Bond tuples
residue_virtual_bonds: dict that maps resname --> list of Bond tuples
residue_bond_angles: dict that maps resname --> list of BondAngle tuples
"""
stereo_chemical_props_path = os.path.join(get_cache_path(), 'stereo_chemical_props.txt')
if not os.path.exists(stereo_chemical_props_path):
url = "https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt"
urllib.request.urlretrieve(url, stereo_chemical_props_path)
with open(stereo_chemical_props_path, 'rt') as f:
stereo_chemical_props = f.read()
lines_iter = iter(stereo_chemical_props.splitlines())
# Load bond lengths.
residue_bonds = {}
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == "-":
break
bond, resname, length, stddev = line.split()
atom1, atom2 = bond.split("-")
if resname not in residue_bonds:
residue_bonds[resname] = []
residue_bonds[resname].append(
Bond(atom1, atom2, float(length), float(stddev))
)
residue_bonds["UNK"] = []
# Load bond angles.
residue_bond_angles = {}
next(lines_iter) # Skip empty line.
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == "-":
break
bond, resname, angle_degree, stddev_degree = line.split()
atom1, atom2, atom3 = bond.split("-")
if resname not in residue_bond_angles:
residue_bond_angles[resname] = []
residue_bond_angles[resname].append(
BondAngle(
atom1,
atom2,
atom3,
float(angle_degree) / 180.0 * np.pi,
float(stddev_degree) / 180.0 * np.pi,
)
)
residue_bond_angles["UNK"] = []
def make_bond_key(atom1_name, atom2_name):
"""Unique key to lookup bonds."""
return "-".join(sorted([atom1_name, atom2_name]))
# Translate bond angles into distances ("virtual bonds").
residue_virtual_bonds = {}
for resname, bond_angles in residue_bond_angles.items():
# Create a fast lookup dict for bond lengths.
bond_cache = {}
for b in residue_bonds[resname]:
bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
residue_virtual_bonds[resname] = []
for ba in bond_angles:
bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
# Compute distance between atom1 and atom3 using the law of cosines
# c^2 = a^2 + b^2 - 2ab*cos(gamma).
gamma = ba.angle_rad
length = np.sqrt(
bond1.length ** 2
+ bond2.length ** 2
- 2 * bond1.length * bond2.length * np.cos(gamma)
)
# Propagation of uncertainty assuming uncorrelated errors.
dl_outer = 0.5 / length
dl_dgamma = (
2 * bond1.length * bond2.length * np.sin(gamma)
) * dl_outer
dl_db1 = (
2 * bond1.length - 2 * bond2.length * np.cos(gamma)
) * dl_outer
dl_db2 = (
2 * bond2.length - 2 * bond1.length * np.cos(gamma)
) * dl_outer
stddev = np.sqrt(
(dl_dgamma * ba.stddev) ** 2
+ (dl_db1 * bond1.stddev) ** 2
+ (dl_db2 * bond2.stddev) ** 2
)
residue_virtual_bonds[resname].append(
Bond(ba.atom1_name, ba.atom3name, length, stddev)
)
return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
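
# Worked example for the "virtual bond" construction above (a sketch using
# approximate literature values, not values read from stereo_chemical_props.txt):
# with a = |N-CA| ~ 1.459 A, b = |CA-C| ~ 1.525 A and gamma = angle(N-CA-C)
# ~ 111 degrees, the law of cosines c^2 = a^2 + b^2 - 2*a*b*cos(gamma) gives
# c ~ 2.46 A for the virtual N...C distance within a residue.
def _example_virtual_bond_length(
    a: float = 1.459, b: float = 1.525, gamma_deg: float = 111.0
) -> float:
    """Sketch: apply the law of cosines used above to one bond angle."""
    gamma = np.deg2rad(gamma_deg)
    return float(np.sqrt(a ** 2 + b ** 2 - 2.0 * a * b * np.cos(gamma)))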
# Between-residue bond lengths for general bonds (first element) and for Proline
# (second element).
between_res_bond_length_c_n = [1.329, 1.341]
between_res_bond_length_stddev_c_n = [0.014, 0.016]
# Between-residue cos_angles.
between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
# This mapping is used when we need to store atom data in a format that requires
# fixed atom data size for every residue (e.g. a numpy array).
atom_types = [
"N",
"CA",
"C",
"CB",
"O",
"CG",
"CG1",
"CG2",
"OG",
"OG1",
"SG",
"CD",
"CD1",
"CD2",
"ND1",
"ND2",
"OD1",
"OD2",
"SD",
"CE",
"CE1",
"CE2",
"CE3",
"NE",
"NE1",
"NE2",
"OE1",
"OE2",
"CH2",
"NH1",
"NH2",
"OH",
"CZ",
"CZ2",
"CZ3",
"NZ",
"OXT",
]
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types) # := 37.
# A compact atom encoding with 14 columns
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
restype_name_to_atom14_names = {
"ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
"ARG": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"NE",
"CZ",
"NH1",
"NH2",
"",
"",
"",
],
"ASN": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"OD1",
"ND2",
"",
"",
"",
"",
"",
"",
],
"ASP": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"OD1",
"OD2",
"",
"",
"",
"",
"",
"",
],
"CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
"GLN": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"OE1",
"NE2",
"",
"",
"",
"",
"",
],
"GLU": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"OE1",
"OE2",
"",
"",
"",
"",
"",
],
"GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
"HIS": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"ND1",
"CD2",
"CE1",
"NE2",
"",
"",
"",
"",
],
"ILE": [
"N",
"CA",
"C",
"O",
"CB",
"CG1",
"CG2",
"CD1",
"",
"",
"",
"",
"",
"",
],
"LEU": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"",
"",
"",
"",
"",
"",
],
"LYS": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"CE",
"NZ",
"",
"",
"",
"",
"",
],
"MET": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"SD",
"CE",
"",
"",
"",
"",
"",
"",
],
"PHE": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"",
"",
"",
],
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
"SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
"THR": [
"N",
"CA",
"C",
"O",
"CB",
"OG1",
"CG2",
"",
"",
"",
"",
"",
"",
"",
],
"TRP": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"NE1",
"CE2",
"CE3",
"CZ2",
"CZ3",
"CH2",
],
"TYR": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"OH",
"",
"",
],
"VAL": [
"N",
"CA",
"C",
"O",
"CB",
"CG1",
"CG2",
"",
"",
"",
"",
"",
"",
"",
],
"UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
}
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace
# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = [
"A",
"R",
"N",
"D",
"C",
"Q",
"E",
"G",
"H",
"I",
"L",
"K",
"M",
"F",
"P",
"S",
"T",
"W",
"Y",
"V",
]
restype_order = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes) # := 20.
unk_restype_index = restype_num # Catch-all index for unknown restypes.
restypes_with_x = restypes + ["X"]
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
def sequence_to_onehot(
sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
) -> np.ndarray:
"""Maps the given sequence into a one-hot encoded matrix.
Args:
sequence: An amino acid sequence.
mapping: A dictionary mapping amino acids to integers.
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
mapped to the unknown amino acid 'X'. If the mapping doesn't contain
amino acid 'X', an error will be thrown. If False, any amino acid not in
the mapping will throw an error.
Returns:
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
the sequence.
Raises:
ValueError: If the mapping doesn't contain values from 0 to
num_unique_aas - 1 without any gaps.
"""
num_entries = max(mapping.values()) + 1
if sorted(set(mapping.values())) != list(range(num_entries)):
raise ValueError(
"The mapping must have values from 0 to num_unique_aas-1 "
"without any gaps. Got: %s" % sorted(mapping.values())
)
one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)
for aa_index, aa_type in enumerate(sequence):
if map_unknown_to_x:
if aa_type.isalpha() and aa_type.isupper():
aa_id = mapping.get(aa_type, mapping["X"])
else:
raise ValueError(
f"Invalid character in the sequence: {aa_type}"
)
else:
aa_id = mapping[aa_type]
one_hot_arr[aa_index, aa_id] = 1
return one_hot_arr
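
# Illustrative usage sketch (not part of the original module): encoding a
# short sequence with the 20+X mapping defined above; 'B' is not in the
# mapping and is therefore sent to 'X'.
def _example_sequence_onehot() -> np.ndarray:
    """One-hot encode 'ACDB' with map_unknown_to_x=True (sketch only)."""
    return sequence_to_onehot(
        "ACDB", restype_order_with_x, map_unknown_to_x=True
    )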
restype_1to3 = {
"A": "ALA",
"R": "ARG",
"N": "ASN",
"D": "ASP",
"C": "CYS",
"Q": "GLN",
"E": "GLU",
"G": "GLY",
"H": "HIS",
"I": "ILE",
"L": "LEU",
"K": "LYS",
"M": "MET",
"F": "PHE",
"P": "PRO",
"S": "SER",
"T": "THR",
"W": "TRP",
"Y": "TYR",
"V": "VAL",
}
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1 = {v: k for k, v in restype_1to3.items()}
# Define a restype name for all unknown residues.
unk_restype = "UNK"
resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
# The mapping here uses hhblits convention, so that B is mapped to D, J and O
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
# remaining 20 amino acids are kept in alphabetical order.
# There are 2 non-amino acid codes, X (representing any amino acid) and
# "-" representing a missing amino acid in an alignment. The id for these
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
HHBLITS_AA_TO_ID = {
"A": 0,
"B": 2,
"C": 1,
"D": 2,
"E": 3,
"F": 4,
"G": 5,
"H": 6,
"I": 7,
"J": 20,
"K": 8,
"L": 9,
"M": 10,
"N": 11,
"O": 20,
"P": 12,
"Q": 13,
"R": 14,
"S": 15,
"T": 16,
"U": 1,
"V": 17,
"W": 18,
"X": 20,
"Y": 19,
"Z": 3,
"-": 21,
}
# Partial inversion of HHBLITS_AA_TO_ID.
ID_TO_HHBLITS_AA = {
0: "A",
1: "C", # Also U.
2: "D", # Also B.
3: "E", # Also Z.
4: "F",
5: "G",
6: "H",
7: "I",
8: "K",
9: "L",
10: "M",
11: "N",
12: "P",
13: "Q",
14: "R",
15: "S",
16: "T",
17: "V",
18: "W",
19: "Y",
20: "X", # Includes J and O.
21: "-",
}
restypes_with_x_and_gap = restypes + ["X", "-"]
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
for i in range(len(restypes_with_x_and_gap))
)
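
# Illustrative sketch (not part of the original module): remapping
# hhblits-encoded aatype ids (0..21) into this module's restype ordering via
# the tuple defined above.
def _example_hhblits_remap(hhblits_aatype: np.ndarray) -> np.ndarray:
    """Map HHblits ids to restypes_with_x_and_gap ids (sketch only)."""
    return np.take(MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, hhblits_aatype)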
def _make_standard_atom_mask() -> np.ndarray:
"""Returns [num_res_types, num_atom_types] mask array."""
# +1 to account for unknown (all 0s).
mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
for restype, restype_letter in enumerate(restypes):
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = atom_order[atom_name]
mask[restype, atom_type] = 1
return mask
STANDARD_ATOM_MASK = _make_standard_atom_mask()
# A one hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
def chi_angle_atom(atom_index: int) -> np.ndarray:
"""Define chi-angle rigid groups via one-hot representations."""
chi_angles_index = {}
one_hots = []
for k, v in chi_angles_atoms.items():
indices = [atom_types.index(s[atom_index]) for s in v]
indices.extend([-1] * (4 - len(indices)))
chi_angles_index[k] = indices
for r in restypes:
res3 = restype_1to3[r]
one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
one_hots.append(one_hot)
one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
one_hot = np.stack(one_hots, axis=0)
one_hot = np.transpose(one_hot, [0, 2, 1])
return one_hot
chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)
# An array like chi_angles_atoms but using indices rather than names.
chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
chi_angles_atom_indices = tree.map_structure(
lambda atom_name: atom_order[atom_name], chi_angles_atom_indices
)
chi_angles_atom_indices = np.array(
[
chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
for chi_atoms in chi_angles_atom_indices
]
)
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group.
chi_groups_for_atom = collections.defaultdict(list)
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
for atom_i, atom in enumerate(chi_group):
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
chi_groups_for_atom = dict(chi_groups_for_atom)
def _make_rigid_transformation_4x4(ex, ey, translation):
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
# Normalize ex.
ex_normalized = ex / np.linalg.norm(ex)
# make ey perpendicular to ex
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
ey_normalized /= np.linalg.norm(ey_normalized)
# compute ez as cross product
eznorm = np.cross(ex_normalized, ey_normalized)
m = np.stack(
[ex_normalized, ey_normalized, eznorm, translation]
).transpose()
m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
return m
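
# The construction above is a Gram-Schmidt orthonormalization: ex is
# normalized, the component of ey along ex is removed before normalizing, and
# ez completes the right-handed frame, so the upper-left 3x3 block is a
# rotation matrix and the last column is the translation.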
# create an array with (restype, atomtype) --> rigid_group_idx
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
def _make_rigid_group_constants():
"""Fill the arrays above."""
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
for atomname, group_idx, atom_position in rigid_group_atom_positions[
resname
]:
atomtype = atom_order[atomname]
restype_atom37_to_rigid_group[restype, atomtype] = group_idx
restype_atom37_mask[restype, atomtype] = 1
restype_atom37_rigid_group_positions[
restype, atomtype, :
] = atom_position
atom14idx = restype_name_to_atom14_names[resname].index(atomname)
restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
restype_atom14_mask[restype, atom14idx] = 1
restype_atom14_rigid_group_positions[
restype, atom14idx, :
] = atom_position
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_positions = {
name: np.array(pos)
for name, _, pos in rigid_group_atom_positions[resname]
}
# backbone to backbone is the identity transform
restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
# pre-omega-frame to backbone (currently dummy identity matrix)
restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
# phi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions["N"] - atom_positions["CA"],
ey=np.array([1.0, 0.0, 0.0]),
translation=atom_positions["N"],
)
restype_rigid_group_default_frame[restype, 2, :, :] = mat
# psi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions["C"] - atom_positions["CA"],
ey=atom_positions["CA"] - atom_positions["N"],
translation=atom_positions["C"],
)
restype_rigid_group_default_frame[restype, 3, :, :] = mat
# chi1-frame to backbone
if chi_angles_mask[restype][0]:
base_atom_names = chi_angles_atoms[resname][0]
base_atom_positions = [
atom_positions[name] for name in base_atom_names
]
mat = _make_rigid_transformation_4x4(
ex=base_atom_positions[2] - base_atom_positions[1],
ey=base_atom_positions[0] - base_atom_positions[1],
translation=base_atom_positions[2],
)
restype_rigid_group_default_frame[restype, 4, :, :] = mat
# chi2-frame to chi1-frame
# chi3-frame to chi2-frame
# chi4-frame to chi3-frame
# luckily all rotation axes for the next frame start at (0,0,0) of the
# previous frame
for chi_idx in range(1, 4):
if chi_angles_mask[restype][chi_idx]:
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
axis_end_atom_position = atom_positions[axis_end_atom_name]
mat = _make_rigid_transformation_4x4(
ex=axis_end_atom_position,
ey=np.array([-1.0, 0.0, 0.0]),
translation=axis_end_atom_position,
)
restype_rigid_group_default_frame[
restype, 4 + chi_idx, :, :
] = mat
_make_rigid_group_constants()
def make_atom14_dists_bounds(
overlap_tolerance=1.5, bond_length_tolerance_factor=15
):
"""compute upper and lower bounds for bonds to assess violations."""
restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_list = restype_name_to_atom14_names[resname]
# create lower and upper bounds for clashes
for atom1_idx, atom1_name in enumerate(atom_list):
if not atom1_name:
continue
atom1_radius = van_der_waals_radius[atom1_name[0]]
for atom2_idx, atom2_name in enumerate(atom_list):
if (not atom2_name) or atom1_idx == atom2_idx:
continue
atom2_radius = van_der_waals_radius[atom2_name[0]]
lower = atom1_radius + atom2_radius - overlap_tolerance
upper = 1e10
restype_atom14_bond_lower_bound[
restype, atom1_idx, atom2_idx
] = lower
restype_atom14_bond_lower_bound[
restype, atom2_idx, atom1_idx
] = lower
restype_atom14_bond_upper_bound[
restype, atom1_idx, atom2_idx
] = upper
restype_atom14_bond_upper_bound[
restype, atom2_idx, atom1_idx
] = upper
# overwrite lower and upper bounds for bonds and angles
for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
atom1_idx = atom_list.index(b.atom1_name)
atom2_idx = atom_list.index(b.atom2_name)
lower = b.length - bond_length_tolerance_factor * b.stddev
upper = b.length + bond_length_tolerance_factor * b.stddev
restype_atom14_bond_lower_bound[
restype, atom1_idx, atom2_idx
] = lower
restype_atom14_bond_lower_bound[
restype, atom2_idx, atom1_idx
] = lower
restype_atom14_bond_upper_bound[
restype, atom1_idx, atom2_idx
] = upper
restype_atom14_bond_upper_bound[
restype, atom2_idx, atom1_idx
] = upper
restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
return {
"lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14)
"upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14)
"stddev": restype_atom14_bond_stddev, # shape (21,14,14)
}
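
# Illustrative usage sketch (not part of the original module): bounds of the
# kind consumed by structural violation checks can be materialized like this.
# Note that the first call triggers load_stereo_chemical_props(), which may
# download stereo_chemical_props.txt into the cache directory above; the
# tolerance factor of 12 mirrors the default violation loss setting.
def _example_atom14_bounds() -> dict:
    """Build default atom14 distance bounds (sketch only)."""
    bounds = make_atom14_dists_bounds(
        overlap_tolerance=1.5, bond_length_tolerance_factor=12.0
    )
    assert bounds["lower_bound"].shape == (21, 14, 14)
    return bounds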
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx = np.tile(
    np.arange(14, dtype=int), (21, 1)
)
def _make_atom14_ambiguity_feats():
for res, pairs in residue_atom_renaming_swaps.items():
res_idx = restype_order[restype_3to1[res]]
for atom1, atom2 in pairs.items():
atom1_idx = restype_name_to_atom14_names[res].index(atom1)
atom2_idx = restype_name_to_atom14_names[res].index(atom2)
restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
restype_atom14_ambiguous_atoms_swap_idx[
res_idx, atom1_idx
] = atom2_idx
restype_atom14_ambiguous_atoms_swap_idx[
res_idx, atom2_idx
] = atom1_idx
_make_atom14_ambiguity_feats()
def aatype_to_str_sequence(aatype):
return ''.join([
restypes_with_x[aatype[i]]
for i in range(len(aatype))
])
### ALPHAFOLD MULTIMER STUFF ###
def _make_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in residue_constants.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
    corresponding atom indices default to 0.
"""
chi_atom_indices = []
for residue_name in restypes:
residue_name = restype_1to3[residue_name]
residue_chi_angles = chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return np.array(chi_atom_indices)
def _make_renaming_matrices():
"""Matrices to map atoms to symmetry partners in ambiguous case."""
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative groundtruth coordinates where the naming is swapped
restype_3 = [
restype_1to3[res] for res in restypes
]
restype_3 += ['UNK']
# Matrices for renaming ambiguous atoms.
all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
for resname, swap in residue_atom_renaming_swaps.items():
correspondences = np.arange(14)
for source_atom_swap, target_atom_swap in swap.items():
source_index = restype_name_to_atom14_names[
resname].index(source_atom_swap)
target_index = restype_name_to_atom14_names[
resname].index(target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.
all_matrices[resname] = renaming_matrix.astype(np.float32)
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
return renaming_matrices
def _make_restype_atom37_mask():
"""Mask of which atoms are present for which residue type in atom37."""
# create the corresponding mask
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
for restype, restype_letter in enumerate(restypes):
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
return restype_atom37_mask
def _make_restype_atom14_mask():
"""Mask of which atoms are present for which residue type in atom14."""
restype_atom14_mask = []
for rt in restypes:
atom_names = restype_name_to_atom14_names[
restype_1to3[rt]]
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
restype_atom14_mask.append([0.] * 14)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
return restype_atom14_mask
def _make_restype_atom37_to_atom14():
"""Map from atom37 to atom14 per residue type."""
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
for rt in restypes:
atom_names = restype_name_to_atom14_names[
restype_1to3[rt]]
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in atom_types
])
restype_atom37_to_atom14.append([0] * 37)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
return restype_atom37_to_atom14
def _make_restype_atom14_to_atom37():
"""Map from atom14 to atom37 per residue type."""
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
for rt in restypes:
atom_names = restype_name_to_atom14_names[
restype_1to3[rt]]
restype_atom14_to_atom37.append([
(atom_order[name] if name else 0)
for name in atom_names
])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
return restype_atom14_to_atom37
def _make_restype_atom14_is_ambiguous():
"""Mask which atoms are ambiguous in atom14."""
# create an ambiguous atoms mask. shape: (21, 14)
restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
for resname, swap in residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = restype_order[
restype_3to1[resname]]
atom_idx1 = restype_name_to_atom14_names[resname].index(
atom_name1)
atom_idx2 = restype_name_to_atom14_names[resname].index(
atom_name2)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
return restype_atom14_is_ambiguous
def _make_restype_rigidgroup_base_atom37_idx():
"""Create Map from rigidgroups to atom37 indices."""
# Create an array with the atom names.
# shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
base_atom_names = np.full([21, 8, 3], '', dtype=object)
# 0: backbone frame
base_atom_names[:, 0, :] = ['C', 'CA', 'N']
# 3: 'psi-group'
base_atom_names[:, 3, :] = ['CA', 'C', 'O']
# 4,5,6,7: 'chi1,2,3,4-group'
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
for chi_idx in range(4):
if chi_angles_mask[restype][chi_idx]:
atom_names = chi_angles_atoms[resname][chi_idx]
base_atom_names[restype, chi_idx + 4, :] = atom_names[1:]
# Translate atom names into atom37 indices.
lookuptable = atom_order.copy()
lookuptable[''] = 0
restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])(
base_atom_names)
return restype_rigidgroup_base_atom37_idx
CHI_ATOM_INDICES = _make_chi_atom_indices()
RENAMING_MATRICES = _make_renaming_matrices()
RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()
RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()
RESTYPE_ATOM37_MASK = _make_restype_atom37_mask()
RESTYPE_ATOM14_MASK = _make_restype_atom14_mask()
RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous()
RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx()
# Create mask for existing rigid groups.
RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32)
RESTYPE_RIGIDGROUP_MASK[:, 0] = 1
RESTYPE_RIGIDGROUP_MASK[:, 3] = 1
RESTYPE_RIGIDGROUP_MASK[:20, 4:] = chi_angles_mask
# modified from openfold/openfold/config.py
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import ml_collections as mlc
def set_inf(c, inf):
for k, v in c.items():
if isinstance(v, mlc.ConfigDict):
set_inf(v, inf)
elif k == "inf":
c[k] = inf
def model_config(name, train=False, low_prec=False):
c = copy.deepcopy(config)
if name == "initial_training":
# AF2 Suppl. Table 4, "initial training" setting
pass
elif name == "finetuning":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.common.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
elif name == "model_1":
# AF2 Suppl. Table 5, Model 1.1.1
c.data.common.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
elif name == "model_2":
# AF2 Suppl. Table 5, Model 1.1.2
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
elif name == "model_3":
# AF2 Suppl. Table 5, Model 1.2.1
c.data.common.max_extra_msa = 5120
c.model.template.enabled = False
elif name == "model_4":
# AF2 Suppl. Table 5, Model 1.2.2
c.data.common.max_extra_msa = 5120
c.model.template.enabled = False
elif name == "model_5":
# AF2 Suppl. Table 5, Model 1.2.3
c.model.template.enabled = False
elif name == "model_1_ptm":
c.data.common.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_2_ptm":
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_3_ptm":
c.data.common.max_extra_msa = 5120
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_4_ptm":
c.data.common.max_extra_msa = 5120
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_5_ptm":
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "relax":
pass
elif "multimer" in name:
c.globals.is_multimer = True
c.data.predict.max_msa_clusters = 252 # 128 for monomer
c.model.structure_module.trans_scale_factor = 20 # 10 for monomer
for k, v in multimer_model_config_update.items():
c.model[k] = v
c.data.common.unsupervised_features.extend(
[
"msa_mask",
"seq_mask",
"asym_id",
"entity_id",
"sym_id",
]
)
else:
raise ValueError("Invalid model name")
if train:
c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None
if low_prec:
c.globals.eps = 1e-4
# If we want exact numerical parity with the original, inf can't be
# a global constant
set_inf(c, 1e4)
return c
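
# Illustrative usage sketch (not part of the original module): building a
# preset from the branches above and overriding one global. The chunk size
# value is hypothetical.
def _example_model_config() -> mlc.ConfigDict:
    """Create the "model_1_ptm" preset with a custom chunk size (sketch)."""
    cfg = model_config("model_1_ptm", train=False, low_prec=False)
    cfg.globals.chunk_size = 4
    return cfg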
c_z = mlc.FieldReference(128, field_type=int)
c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(None, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_EXTRA_SEQ = "extra msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
config = mlc.ConfigDict(
{
"data": {
"common": {
"feat": {
"aatype": [NUM_RES],
"all_atom_mask": [NUM_RES, None],
"all_atom_positions": [NUM_RES, None, None],
"alt_chi_angles": [NUM_RES, None],
"atom14_alt_gt_exists": [NUM_RES, None],
"atom14_alt_gt_positions": [NUM_RES, None, None],
"atom14_atom_exists": [NUM_RES, None],
"atom14_atom_is_ambiguous": [NUM_RES, None],
"atom14_gt_exists": [NUM_RES, None],
"atom14_gt_positions": [NUM_RES, None, None],
"atom37_atom_exists": [NUM_RES, None],
"backbone_rigid_mask": [NUM_RES],
"backbone_rigid_tensor": [NUM_RES, None, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
"chi_angles_sin_cos": [NUM_RES, None, None],
"chi_mask": [NUM_RES, None],
"extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES],
"extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_row_mask": [NUM_EXTRA_SEQ],
"is_distillation": [],
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
"msa_mask": [NUM_MSA_SEQ, NUM_RES],
"msa_row_mask": [NUM_MSA_SEQ],
"no_recycling_iters": [],
"pseudo_beta": [NUM_RES, None],
"pseudo_beta_mask": [NUM_RES],
"residue_index": [NUM_RES],
"residx_atom14_to_atom37": [NUM_RES, None],
"residx_atom37_to_atom14": [NUM_RES, None],
"resolution": [],
"rigidgroups_alt_gt_frames": [NUM_RES, None, None, None],
"rigidgroups_group_exists": [NUM_RES, None],
"rigidgroups_group_is_ambiguous": [NUM_RES, None],
"rigidgroups_gt_exists": [NUM_RES, None],
"rigidgroups_gt_frames": [NUM_RES, None, None, None],
"seq_length": [],
"seq_mask": [NUM_RES],
"target_feat": [NUM_RES, None],
"template_aatype": [NUM_TEMPLATES, NUM_RES],
"template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
"template_all_atom_positions": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_alt_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
"template_backbone_rigid_tensor": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_mask": [NUM_TEMPLATES],
"template_pseudo_beta": [NUM_TEMPLATES, NUM_RES, None],
"template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_RES],
"template_sum_probs": [NUM_TEMPLATES, None],
"template_torsion_angles_mask": [
NUM_TEMPLATES, NUM_RES, None,
],
"template_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"true_msa": [NUM_MSA_SEQ, NUM_RES],
"use_clamped_fape": [],
},
"masked_msa": {
"profile_prob": 0.1,
"same_prob": 0.1,
"uniform_prob": 0.1,
},
"max_extra_msa": 1024,
"max_recycling_iters": 3,
"msa_cluster_features": True,
"reduce_msa_clusters_by_max_templates": False,
"resample_msa_in_recycling": True,
"template_features": [
"template_all_atom_positions",
"template_sum_probs",
"template_aatype",
"template_all_atom_mask",
],
"unsupervised_features": [
"aatype",
"residue_index",
"msa",
"num_alignments",
"seq_length",
"between_segment_residues",
"deletion_matrix",
"no_recycling_iters",
],
"use_templates": templates_enabled,
"use_template_torsion_angles": embed_template_torsion_angles,
},
"supervised": {
"clamp_prob": 0.9,
"supervised_features": [
"all_atom_mask",
"all_atom_positions",
"resolution",
"use_clamped_fape",
"is_distillation",
],
},
"predict": {
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_template_hits": 4,
"max_templates": 4,
"crop": False,
"crop_size": None,
"supervised": False,
"uniform_recycling": False,
},
"eval": {
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_template_hits": 4,
"max_templates": 4,
"crop": False,
"crop_size": None,
"supervised": True,
"uniform_recycling": False,
},
"train": {
"fixed_size": True,
"subsample_templates": True,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_template_hits": 4,
"max_templates": 4,
"shuffle_top_k_prefiltered": 20,
"crop": True,
"crop_size": 256,
"supervised": True,
"clamp_prob": 0.9,
"max_distillation_msa_clusters": 1000,
"uniform_recycling": True,
"distillation_prob": 0.75,
},
"data_module": {
"use_small_bfd": False,
"data_loaders": {
"batch_size": 1,
"num_workers": 16,
},
},
},
# Recurring FieldReferences that can be changed globally here
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"c_z": c_z,
"c_m": c_m,
"c_t": c_t,
"c_e": c_e,
"c_s": c_s,
"eps": eps,
"is_multimer": False,
},
"model": {
"_mask_trans": False,
"input_embedder": {
"tf_dim": 22,
"msa_dim": 49,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
},
"recycling_embedder": {
"c_z": c_z,
"c_m": c_m,
"min_bin": 3.25,
"max_bin": 20.75,
"no_bins": 15,
"inf": 1e8,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_angle_embedder": {
# DISCREPANCY: c_in is supposed to be 51.
"c_in": 57,
"c_out": c_m,
},
"template_pair_embedder": {
"c_in": 88,
"c_out": c_t,
},
"template_pair_stack": {
"c_t": c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att": 16,
"c_hidden_tri_mul": 64,
"no_blocks": 2,
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
"template_pointwise_attention": {
"c_t": c_t,
"c_z": c_z,
# DISCREPANCY: c_hidden here is given in the supplement as 64.
# It's actually 16.
"c_hidden": 16,
"no_heads": 4,
"inf": 1e5, # 1e9,
},
"inf": 1e5, # 1e9,
"eps": eps, # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
"use_unit_vector": False,
},
"extra_msa": {
"extra_msa_embedder": {
"c_in": 25,
"c_out": c_e,
},
"extra_msa_stack": {
"c_m": c_e,
"c_z": c_z,
"c_hidden_msa_att": 8,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 4,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
},
"enabled": True,
},
"evoformer_stack": {
"c_m": c_m,
"c_z": c_z,
"c_hidden_msa_att": 32,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"c_s": c_s,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 48,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9,
"eps": eps, # 1e-10,
},
"structure_module": {
"c_s": c_s,
"c_z": c_z,
"c_ipa": 16,
"c_resnet": 128,
"no_heads_ipa": 12,
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 10,
"epsilon": eps, # 1e-12,
"inf": 1e5,
},
"heads": {
"lddt": {
"no_bins": 50,
"c_in": c_s,
"c_hidden": 128,
},
"distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
},
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": tm_enabled,
},
"masked_msa": {
"c_m": c_m,
"c_out": 23,
},
"experimentally_resolved": {
"c_s": c_s,
"c_out": 37,
},
},
},
"relax": {
"max_iterations": 0, # no max
"tolerance": 2.39,
"stiffness": 10.0,
"max_outer_iterations": 20,
"exclude_residues": [],
},
"loss": {
"distogram": {
"min_bin": 2.3125,
"max_bin": 21.6875,
"no_bins": 64,
"eps": eps, # 1e-6,
"weight": 0.3,
},
"experimentally_resolved": {
"eps": eps, # 1e-8,
"min_resolution": 0.1,
"max_resolution": 3.0,
"weight": 0.0,
},
"fape": {
"backbone": {
"clamp_distance": 10.0,
"loss_unit_distance": 10.0,
"weight": 0.5,
},
"sidechain": {
"clamp_distance": 10.0,
"length_scale": 10.0,
"weight": 0.5,
},
"eps": 1e-4,
"weight": 1.0,
},
"lddt": {
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.0,
"no_bins": 50,
"eps": eps, # 1e-10,
"weight": 0.01,
},
"masked_msa": {
"eps": eps, # 1e-8,
"weight": 2.0,
},
"supervised_chi": {
"chi_weight": 0.5,
"angle_norm_weight": 0.01,
"eps": eps, # 1e-6,
"weight": 1.0,
},
"violation": {
"violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5,
"eps": eps, # 1e-6,
"weight": 0.0,
},
"tm": {
"max_bin": 31,
"no_bins": 64,
"min_resolution": 0.1,
"max_resolution": 3.0,
"eps": eps, # 1e-8,
"weight": 0.0,
"enabled": tm_enabled,
},
"eps": eps,
},
"ema": {"decay": 0.999},
}
)
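
# Illustrative note (not part of the original module): values in the base
# config above can be read and overridden like any ml_collections ConfigDict,
# e.g. config.model.evoformer_stack.no_blocks is 48 and
# config.data.common.max_extra_msa defaults to 1024 (raised to 5120 by the
# presets in model_config above that override it).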
multimer_model_config_update = {
"input_embedder": {
"tf_dim": 21,
"msa_dim": 49,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_pair_embedder": {
"c_z": c_z,
"c_out": 64,
"c_dgram": 39,
"c_aatype": 22,
},
"template_single_embedder": {
"c_in": 34,
"c_m": c_m,
},
"template_pair_stack": {
"c_t": c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att": 16,
"c_hidden_tri_mul": 64,
"no_blocks": 2,
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
"c_t": c_t,
"c_z": c_z,
"inf": 1e5, # 1e9,
"eps": eps, # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
},
"heads": {
"lddt": {
"no_bins": 50,
"c_in": c_s,
"c_hidden": 128,
},
"distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
},
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": True,
},
"masked_msa": {
"c_m": c_m,
"c_out": 22,
},
"experimentally_resolved": {
"c_s": c_s,
"c_out": 37,
},
},
}
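# NOTE (illustrative comment, not part of the upstream config): the dict above only
# lists the fields that differ from the monomer model config. A typical way to apply
# it would be a recursive merge into the base ConfigDict, roughly:
#
#     config.model.update(mlc.ConfigDict(multimer_model_config_update))
#
# The exact merge point is an assumption here; it depends on how FastFold's
# model_config() selects between the monomer and multimer presets.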
# Copyright 2022 HPC-AI Tech Inc.
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import json
import logging
import os
from typing import Optional, Sequence, List, Any
import ml_collections as mlc
import torch
from colossalai.utils import is_using_ddp
from fastfold.data import (
data_pipeline,
feature_pipeline,
mmcif_parsing,
templates,
)
from fastfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
mapping_path: Optional[str] = None,
mode: str = "train",
_output_raw: bool = False,
_alignment_index: Optional[Any] = None
):
"""
Args:
data_dir:
A path to a directory containing mmCIF files (in train
mode) or FASTA files (in inference mode).
alignment_dir:
A path to a directory containing only data in the format
output by an AlignmentRunner
(defined in openfold.features.alignment_runner).
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
files.
template_mmcif_dir:
Path to a directory containing template mmCIF files.
config:
A dataset config object. See openfold.config
kalign_binary_path:
Path to kalign binary.
max_template_hits:
An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled
from this total quantity.
template_release_dates_cache_path:
Path to the output of scripts/generate_mmcif_cache.
obsolete_pdbs_file_path:
Path to the file containing replacements for obsolete PDBs.
shuffle_top_k_prefiltered:
If given, the top k template hits are uniformly shuffled before
max_template_hits of them are parsed. Can be used to
approximate DeepMind's training-time template subsampling
scheme much more performantly.
treat_pdb_as_distillation:
Whether to assume that .pdb files in the data_dir are from
the self-distillation set (and should be subjected to
special distillation set preprocessing steps).
mode:
"train", "val", or "predict"
"""
super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir
self.alignment_dir = alignment_dir
self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self._output_raw = _output_raw
self._alignment_index = _alignment_index
valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes):
raise ValueError(f'mode must be one of {valid_modes}')
if(template_release_dates_cache_path is None):
logging.warning(
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if(_alignment_index is not None):
self._chain_ids = list(_alignment_index.keys())
elif(mapping_path is None):
self._chain_ids = list(os.listdir(alignment_dir))
else:
with open(mapping_path, "r") as f:
self._chain_ids = [l.strip() for l in f.readlines()]
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
}
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
max_hits=max_template_hits,
kalign_binary_path=kalign_binary_path,
release_dates_path=template_release_dates_cache_path,
obsolete_pdbs_path=obsolete_pdbs_file_path,
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
)
self.data_pipeline = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
mmcif_object = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string
)
# Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage.
if(mmcif_object.mmcif_object is None):
raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object
data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
_alignment_index=_alignment_index
)
return data
def chain_id_to_idx(self, chain_id):
return self._chain_id_to_idx_dict[chain_id]
def idx_to_chain_id(self, idx):
return self._chain_ids[idx]
def __getitem__(self, idx):
name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name)
_alignment_index = None
if(self._alignment_index is not None):
alignment_dir = self.alignment_dir
_alignment_index = self._alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1)
if(len(spl) == 2):
file_id, chain_id = spl
else:
file_id, = spl
chain_id = None
path = os.path.join(self.data_dir, file_id)
if(os.path.exists(path + ".cif")):
data = self._parse_mmcif(
path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
)
elif(os.path.exists(path + ".core")):
data = self.data_pipeline.process_core(
path + ".core", alignment_dir, _alignment_index,
)
elif(os.path.exists(path + ".pdb")):
data = self.data_pipeline.process_pdb(
pdb_path=path + ".pdb",
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
_alignment_index=_alignment_index,
)
else:
raise ValueError("Invalid file type")
else:
path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=alignment_dir,
_alignment_index=_alignment_index,
)
if(self._output_raw):
return data
feats = self.feature_pipeline.process_features(
data, self.mode
)
return feats
def __len__(self):
return len(self._chain_ids)
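# Illustrative usage (a sketch, not part of the original file; paths and the
# config attribute are placeholders/assumptions):
#
#     dataset = OpenFoldSingleDataset(
#         data_dir="/data/mmcif_files",              # e.g. 1abc.cif per entry
#         alignment_dir="/data/alignments",          # e.g. 1abc_A/ with .a3m/.sto/.hhr files
#         template_mmcif_dir="/data/pdb_mmcif/mmcif_files",
#         max_template_date="2021-10-10",
#         config=config.data,                        # assumed: the data sub-config
#         mode="train",
#     )
#     feats = dataset[0]   # processed feature dict for the first chain in alignment_dir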
def deterministic_train_filter(
chain_data_cache_entry: Any,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
) -> bool:
# Hard filters
resolution = chain_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution):
return False
seq = chain_data_cache_entry["seq"]
counts = {}
for aa in seq:
counts.setdefault(aa, 0)
counts[aa] += 1
largest_aa_count = max(counts.values())
largest_single_aa_prop = largest_aa_count / len(seq)
if(largest_single_aa_prop > max_single_aa_prop):
return False
return True
def get_stochastic_train_filter_prob(
chain_data_cache_entry: Any,
) -> float:
# Stochastic filters
probabilities = []
cluster_size = chain_data_cache_entry.get("cluster_size", None)
if(cluster_size is not None and cluster_size > 0):
probabilities.append(1 / cluster_size)
chain_length = len(chain_data_cache_entry["seq"])
probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
# Risk of underflow here?
out = 1
for p in probabilities:
out *= p
return out
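# Example (illustrative): for a chain from a cluster of size 100 with 300 residues,
# the acceptance probability is (1 / 100) * (300 / 512) ~= 0.0059, so sequences from
# large clusters and short chains are down-weighted during training.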
class OpenFoldDataset(torch.utils.data.Dataset):
"""
Implements the stochastic filters applied during AlphaFold's training.
Because samples are selected from constituent datasets randomly, the
length of an OpenFoldDataset is arbitrary. Samples are selected
and filtered once at initialization.
"""
def __init__(self,
datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[float],
epoch_len: int,
chain_data_cache_paths: List[str],
generator: torch.Generator = None,
_roll_at_init: bool = True,
):
self.datasets = datasets
self.probabilities = probabilities
self.epoch_len = epoch_len
self.generator = generator
self.chain_data_caches = []
for path in chain_data_cache_paths:
with open(path, "r") as fp:
self.chain_data_caches.append(json.load(fp))
def looped_shuffled_dataset_idx(dataset_len):
while True:
# Uniformly shuffle each dataset's indices
weights = [1. for _ in range(dataset_len)]
shuf = torch.multinomial(
torch.tensor(weights),
num_samples=dataset_len,
replacement=False,
generator=self.generator,
)
for idx in shuf:
yield idx
def looped_samples(dataset_idx):
max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = self.chain_data_caches[dataset_idx]
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx)
chain_data_cache_entry = chain_data_cache[chain_id]
if(not deterministic_train_filter(chain_data_cache_entry)):
continue
p = get_stochastic_train_filter_prob(
chain_data_cache_entry,
)
weights.append([1. - p, p])
idx.append(candidate_idx)
samples = torch.multinomial(
torch.tensor(weights),
num_samples=1,
generator=self.generator,
)
samples = samples.squeeze()
cache = [i for i, s in zip(idx, samples) if s]
for datapoint_idx in cache:
yield datapoint_idx
self._samples = [looped_samples(i) for i in range(len(self.datasets))]
if(_roll_at_init):
self.reroll()
def __getitem__(self, idx):
dataset_idx, datapoint_idx = self.datapoints[idx]
return self.datasets[dataset_idx][datapoint_idx]
def __len__(self):
return self.epoch_len
def reroll(self):
dataset_choices = torch.multinomial(
torch.tensor(self.probabilities),
num_samples=self.epoch_len,
replacement=True,
generator=self.generator,
)
self.datapoints = []
for dataset_idx in dataset_choices:
samples = self._samples[dataset_idx]
datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx))
class OpenFoldBatchCollator:
def __init__(self, config, stage="train"):
self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def __call__(self, raw_prots):
processed_prots = []
for prot in raw_prots:
features = self.feature_pipeline.process_features(
prot, self.stage
)
processed_prots.append(features)
# Stacking adds the batch dimension here.
stack_fn = partial(torch.stack, dim=0)
# Note: this collator has been modified so that with batch size 1 the batch
# dimension is dropped (shapes are [...] rather than [1, ...]); with larger
# batch sizes (which are not supported) the batch dimension would still be
# present, e.g. [2, ...].
return dict_multimap(stack_fn, processed_prots)
class OpenFoldDataLoader(torch.utils.data.DataLoader):
def __init__(self, dataset, config, stage="train", generator=None, **kwargs):
super().__init__(dataset, **kwargs)
self.config = config
self.stage = stage
if(generator is None):
generator = torch.Generator()
self.generator = generator
self._prep_batch_properties_probs()
def _prep_batch_properties_probs(self):
keyed_probs = []
stage_cfg = self.config[self.stage]
max_iters = self.config.common.max_recycling_iters
if(stage_cfg.supervised):
clamp_prob = self.config.supervised.clamp_prob
keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob])
)
if(stage_cfg.uniform_recycling):
recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1)
]
else:
recycling_probs = [
0. for _ in range(max_iters + 1)
]
recycling_probs[-1] = 1.
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
keys, probs = zip(*keyed_probs)
max_len = max([len(p) for p in probs])
padding = [[0.] * (max_len - len(p)) for p in probs]
self.prop_keys = keys
self.prop_probs_tensor = torch.tensor(
[p + pad for p, pad in zip(probs, padding)],
dtype=torch.float32,
)
def _add_batch_properties(self, batch):
samples = torch.multinomial(
self.prop_probs_tensor,
num_samples=1, # 1 per row
replacement=True,
generator=self.generator
)
aatype = batch["aatype"]
batch_dims = aatype.shape[:-2]
recycling_dim = aatype.shape[-1]
no_recycling = recycling_dim
for i, key in enumerate(self.prop_keys):
sample = int(samples[i][0])
sample_tensor = torch.tensor(
sample,
device=aatype.device,
requires_grad=False
)
orig_shape = sample_tensor.shape
sample_tensor = sample_tensor.view(
(1,) * len(batch_dims) + sample_tensor.shape + (1,)
)
sample_tensor = sample_tensor.expand(
batch_dims + orig_shape + (recycling_dim,)
)
batch[key] = sample_tensor
if(key == "no_recycling_iters"):
no_recycling = sample
resample_recycling = lambda t: t[..., :no_recycling + 1]
batch = tensor_tree_map(resample_recycling, batch)
return batch
def __iter__(self):
it = super().__iter__()
def _batch_prop_gen(iterator):
for batch in iterator:
yield self._add_batch_properties(batch)
return _batch_prop_gen(it)
def SetupTrainDataset(
config: mlc.ConfigDict,
template_mmcif_dir: str,
max_template_date: str,
train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None,
train_chain_data_cache_path: Optional[str] = None,
distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None,
distillation_chain_data_cache_path: Optional[str] = None,
val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None,
distillation_mapping_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
train_epoch_len: int = 50000,
_alignment_index_path: Optional[str] = None,
**kwargs,
):
if(train_data_dir is None or train_alignment_dir is None):
raise ValueError(
'train_data_dir and train_alignment_dir must be specified'
)
elif(val_data_dir is not None and val_alignment_dir is None):
raise ValueError(
'If val_data_dir is specified, val_alignment_dir must '
'be specified as well'
)
_alignment_index = None
if(_alignment_index_path is not None):
with open(_alignment_index_path, "r") as fp:
_alignment_index = json.load(fp)
dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
config=config,
kalign_binary_path=kalign_binary_path,
template_release_dates_cache_path=
template_release_dates_cache_path,
obsolete_pdbs_file_path=
obsolete_pdbs_file_path,
)
train_dataset = dataset_gen(
data_dir=train_data_dir,
alignment_dir=train_alignment_dir,
mapping_path=train_mapping_path,
max_template_hits=config.train.max_template_hits,
shuffle_top_k_prefiltered=
config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train",
_output_raw=True,
_alignment_index=_alignment_index,
)
distillation_dataset = None
if(distillation_data_dir is not None):
distillation_dataset = dataset_gen(
data_dir=distillation_data_dir,
alignment_dir=distillation_alignment_dir,
mapping_path=distillation_mapping_path,
max_template_hits=config.train.max_template_hits,
treat_pdb_as_distillation=True,
mode="train",
_output_raw=True,
)
d_prob = config.train.distillation_prob
if(distillation_dataset is not None):
datasets = [train_dataset, distillation_dataset]
probabilities = [1 - d_prob, d_prob]
chain_data_cache_paths = [
train_chain_data_cache_path,
distillation_chain_data_cache_path,
]
else:
datasets = [train_dataset]
probabilities = [1.]
chain_data_cache_paths = [
train_chain_data_cache_path,
]
train_dataset = OpenFoldDataset(
datasets=datasets,
probabilities=probabilities,
epoch_len=train_epoch_len,
chain_data_cache_paths=chain_data_cache_paths,
_roll_at_init=False,
)
if(val_data_dir is not None):
eval_dataset = dataset_gen(
data_dir=val_data_dir,
alignment_dir=val_alignment_dir,
mapping_path=None,
max_template_hits=config.eval.max_template_hits,
mode="eval",
_output_raw=True,
)
else:
eval_dataset = None
return train_dataset, eval_dataset
def TrainDataLoader(
config: mlc.ConfigDict,
train_dataset: torch.utils.data.Dataset,
test_dataset: Optional[torch.utils.data.Dataset] = None,
batch_seed: Optional[int] = None,
):
if config.data_module.data_loaders.batch_size != 1:
raise ValueError("Only a batch size of 1 is supported")
generator = torch.Generator()
if(batch_seed is not None):
generator = generator.manual_seed(batch_seed)
train_batch_collator = OpenFoldBatchCollator(config, "train")
train_sampler = None
if is_using_ddp():
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_dataset.reroll()
train_dataloader = OpenFoldDataLoader(
dataset=train_dataset,
config=config,
stage="train",
generator=generator,
batch_size=config.data_module.data_loaders.batch_size,
num_workers=config.data_module.data_loaders.num_workers,
collate_fn=train_batch_collator,
sampler=train_sampler,
)
test_dataloader = None
if test_dataset is not None:
test_sampler = None
if is_using_ddp():
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
test_batch_collator = OpenFoldBatchCollator(config, "test")
test_dataloader = OpenFoldDataLoader(
dataset=test_dataset,
config=config,
stage="test",
generator=generator,
batch_size=config.data_module.data_loaders.batch_size,
num_workers=config.data_module.data_loaders.num_workers,
collate_fn=test_batch_collator,
sampler=test_sampler,
)
return train_dataloader, test_dataloader
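# Illustrative wiring of the two helpers above (a sketch, not part of the original
# file; paths are placeholders and `config.data` is assumed to be the sub-config
# holding the .train/.eval/.data_module sections):
#
#     train_dataset, eval_dataset = SetupTrainDataset(
#         config=config.data,
#         template_mmcif_dir="/data/pdb_mmcif/mmcif_files",
#         max_template_date="2021-10-10",
#         train_data_dir="/data/mmcif_files",
#         train_alignment_dir="/data/alignments",
#         train_chain_data_cache_path="/data/chain_data_cache.json",
#     )
#     train_loader, test_loader = TrainDataLoader(
#         config=config.data,
#         train_dataset=train_dataset,
#         test_dataset=eval_dataset,
#         batch_seed=42,
#     )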
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import collections
import contextlib
import dataclasses
import datetime
import json
import copy
from multiprocessing import cpu_count
import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import numpy as np
from fastfold.data import (
templates,
parsers,
mmcif_parsing,
msa_identifiers,
msa_pairing,
feature_processing_multimer,
)
from fastfold.data import templates
from fastfold.data.parsers import Msa
from fastfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from fastfold.data.tools.utils import to_date
from fastfold.common import residue_constants, protein
FeatureDict = Mapping[str, np.ndarray]
def empty_template_feats(n_res) -> FeatureDict:
return {
"template_aatype": np.zeros((0, n_res)).astype(np.int64),
"template_all_atom_positions":
np.zeros((0, n_res, 37, 3)).astype(np.float32),
"template_sum_probs": np.zeros((0, 1)).astype(np.float32),
"template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32),
}
def make_template_features(
input_sequence: str,
hits: Sequence[Any],
template_featurizer: Union[templates.TemplateHitFeaturizer, templates.HmmsearchHitFeaturizer],
query_pdb_code: Optional[str] = None,
query_release_date: Optional[str] = None,
) -> FeatureDict:
hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0 or template_featurizer is None):
template_features = empty_template_feats(len(input_sequence))
else:
if type(template_featurizer) == templates.TemplateHitFeaturizer:
templates_result = template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=query_pdb_code,
query_release_date=query_release_date,
hits=hits_cat,
)
else:
templates_result = template_featurizer.get_templates(
query_sequence=input_sequence,
hits=hits_cat,
)
template_features = templates_result.features
# The template featurizer doesn't format empty template features
# properly. This is a quick fix.
if(template_features["template_aatype"].shape[0] == 0):
template_features = empty_template_feats(len(input_sequence))
return template_features
def make_sequence_features(
sequence: str, description: str, num_res: int
) -> FeatureDict:
"""Construct a feature dict of sequence features."""
features = {}
features["aatype"] = residue_constants.sequence_to_onehot(
sequence=sequence,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True,
)
features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
features["domain_name"] = np.array(
[description.encode("utf-8")], dtype=np.object_
)
features["residue_index"] = np.array(range(num_res), dtype=np.int32)
features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
features["sequence"] = np.array(
[sequence.encode("utf-8")], dtype=np.object_
)
return features
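# Shapes of the features built above, for a sequence of length num_res: "aatype" is
# a one-hot (num_res, 21) array over the 20 standard residues plus X;
# "between_segment_residues" and "residue_index" are (num_res,) int32 arrays;
# "seq_length" repeats num_res once per residue; "sequence" and "domain_name" are
# length-1 object arrays holding the raw strings.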
def make_mmcif_features(
mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
input_sequence = mmcif_object.chain_to_seqres[chain_id]
description = "_".join([mmcif_object.file_id, chain_id])
num_res = len(input_sequence)
mmcif_feats = {}
mmcif_feats.update(
make_sequence_features(
sequence=input_sequence,
description=description,
num_res=num_res,
)
)
all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=chain_id
)
mmcif_feats["all_atom_positions"] = all_atom_positions
mmcif_feats["all_atom_mask"] = all_atom_mask
mmcif_feats["resolution"] = np.array(
[mmcif_object.header["resolution"]], dtype=np.float32
)
mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
)
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
return mmcif_feats
def _aatype_to_str_sequence(aatype):
return ''.join([
residue_constants.restypes_with_x[aatype[i]]
for i in range(len(aatype))
])
def make_protein_features(
protein_object: protein.Protein,
description: str,
_is_distillation: bool = False,
) -> FeatureDict:
pdb_feats = {}
aatype = protein_object.aatype
sequence = _aatype_to_str_sequence(aatype)
pdb_feats.update(
make_sequence_features(
sequence=sequence,
description=description,
num_res=len(protein_object.aatype),
)
)
all_atom_positions = protein_object.atom_positions
all_atom_mask = protein_object.atom_mask
pdb_feats["all_atom_positions"] = all_atom_positions.astype(np.float32)
pdb_feats["all_atom_mask"] = all_atom_mask.astype(np.float32)
pdb_feats["resolution"] = np.array([0.]).astype(np.float32)
pdb_feats["is_distillation"] = np.array(
1. if _is_distillation else 0.
).astype(np.float32)
return pdb_feats
def make_pdb_features(
protein_object: protein.Protein,
description: str,
confidence_threshold: float = 0.5,
is_distillation: bool = True,
) -> FeatureDict:
pdb_feats = make_protein_features(
protein_object, description, _is_distillation=True
)
if(is_distillation):
high_confidence = protein_object.b_factors > confidence_threshold
high_confidence = np.any(high_confidence, axis=-1)
for i, confident in enumerate(high_confidence):
if(not confident):
pdb_feats["all_atom_mask"][i] = 0
return pdb_feats
def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
"""Constructs a feature dict of MSA features."""
if not msas:
raise ValueError("At least one MSA must be provided.")
int_msa = []
deletion_matrix = []
species_ids = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(
f"MSA {msa_index} must contain at least one sequence."
)
for sequence_index, sequence in enumerate(msa.sequences):
if sequence in seen_sequences:
continue
seen_sequences.add(sequence)
int_msa.append(
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]
)
deletion_matrix.append(msa.deletion_matrix[sequence_index])
identifiers = msa_identifiers.get_identifiers(
msa.descriptions[sequence_index]
)
species_ids.append(identifiers.species_id.encode('utf-8'))
num_res = len(msas[0].sequences[0])
num_alignments = len(int_msa)
features = {}
features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
features["msa"] = np.array(int_msa, dtype=np.int32)
features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32
)
features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
return features
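# Shapes of the MSA features built above: "msa" and "deletion_matrix_int" are
# (num_alignments, num_res) int32 arrays over the deduplicated sequences;
# "num_alignments" repeats the alignment count once per residue; and
# "msa_species_identifiers" holds one species id per kept sequence, which is later
# used for cross-chain MSA pairing in the multimer pipeline.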
def run_msa_tool(
msa_runner,
fasta_path: str,
msa_out_path: str,
msa_format: str,
max_sto_sequences: Optional[int] = None,
) -> Mapping[str, Any]:
"""Runs an MSA tool, checking if output already exists first."""
if(msa_format == "sto"):
result = msa_runner.query(fasta_path, max_sto_sequences)[0]
else:
result = msa_runner.query(fasta_path)
with open(msa_out_path, "w") as f:
f.write(result[msa_format])
return result
class AlignmentRunner:
"""Runs alignment tools and saves the results"""
def __init__(
self,
jackhmmer_binary_path: Optional[str] = None,
hhblits_binary_path: Optional[str] = None,
hhsearch_binary_path: Optional[str] = None,
uniref90_database_path: Optional[str] = None,
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniref30_database_path: Optional[str] = None,
pdb70_database_path: Optional[str] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
uniprot_max_hits: int = 50000,
):
"""
Args:
jackhmmer_binary_path:
Path to jackhmmer binary
hhblits_binary_path:
Path to hhblits binary
hhsearch_binary_path:
Path to hhsearch binary
uniref90_database_path:
Path to uniref90 database. If provided, jackhmmer_binary_path
must also be provided
mgnify_database_path:
Path to mgnify database. If provided, jackhmmer_binary_path
must also be provided
bfd_database_path:
Path to BFD database. Depending on the value of use_small_bfd,
one of hhblits_binary_path or jackhmmer_binary_path must be
provided.
uniref30_database_path:
Path to uniref30. Searched alongside BFD if use_small_bfd is
false.
pdb70_database_path:
Path to pdb70 database.
use_small_bfd:
Whether to search the BFD database alone with jackhmmer or
in conjunction with uniref30 with hhblits.
no_cpus:
The number of CPUs available for alignment. By default, all
CPUs are used.
uniref_max_hits:
Max number of uniref hits
mgnify_max_hits:
Max number of mgnify hits
"""
db_map = {
"jackhmmer": {
"binary": jackhmmer_binary_path,
"dbs": [
uniref90_database_path,
mgnify_database_path,
bfd_database_path if use_small_bfd else None,
],
},
"hhblits": {
"binary": hhblits_binary_path,
"dbs": [
bfd_database_path if not use_small_bfd else None,
],
},
"hhsearch": {
"binary": hhsearch_binary_path,
"dbs": [
pdb70_database_path,
],
},
}
for name, dic in db_map.items():
binary, dbs = dic["binary"], dic["dbs"]
if(binary is None and not all([x is None for x in dbs])):
raise ValueError(
f"{name} DBs provided but {name} binary is None"
)
if(not all([x is None for x in db_map["hhsearch"]["dbs"]])
and uniref90_database_path is None):
raise ValueError(
"""uniref90_database_path must be specified in order to perform
template search"""
)
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
self.use_small_bfd = use_small_bfd
if(no_cpus is None):
no_cpus = cpu_count()
self.jackhmmer_uniref90_runner = None
if(jackhmmer_binary_path is not None and
uniref90_database_path is not None
):
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniref90_database_path,
n_cpu=no_cpus,
)
self.jackhmmer_small_bfd_runner = None
self.hhblits_bfd_uniref_runner = None
if(bfd_database_path is not None):
if use_small_bfd:
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=bfd_database_path,
n_cpu=no_cpus,
)
else:
dbs = [bfd_database_path]
if(uniref30_database_path is not None):
dbs.append(uniref30_database_path)
self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=dbs,
n_cpu=no_cpus,
)
self.jackhmmer_mgnify_runner = None
if(mgnify_database_path is not None):
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path,
n_cpu=no_cpus,
)
self.hhsearch_pdb70_runner = None
if(pdb70_database_path is not None):
self.hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path=hhsearch_binary_path,
databases=[pdb70_database_path],
n_cpu=no_cpus,
)
def run(
self,
fasta_path: str,
output_dir: str,
):
"""Runs alignment tools on a sequence"""
if(self.jackhmmer_uniref90_runner is not None):
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
fasta_path
)[0]
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_uniref90_result["sto"],
max_sequences=self.uniref_max_hits
)
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, "w") as f:
f.write(uniref90_msa_as_a3m)
if(self.hhsearch_pdb70_runner is not None):
hhsearch_result = self.hhsearch_pdb70_runner.query(
uniref90_msa_as_a3m
)
pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, "w") as f:
f.write(hhsearch_result)
if(self.jackhmmer_mgnify_runner is not None):
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
fasta_path
)[0]
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_mgnify_result["sto"],
max_sequences=self.mgnify_max_hits
)
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, "w") as f:
f.write(mgnify_msa_as_a3m)
if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
fasta_path
)[0]
bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
with open(bfd_out_path, "w") as f:
f.write(jackhmmer_small_bfd_result["sto"])
elif(self.hhblits_bfd_uniref_runner is not None):
hhblits_bfd_uniref_result = (
self.hhblits_bfd_uniref_runner.query(fasta_path)
)
if output_dir is not None:
bfd_out_path = os.path.join(output_dir, "bfd_uniref_hits.a3m")
with open(bfd_out_path, "w") as f:
f.write(hhblits_bfd_uniref_result["a3m"])
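# Illustrative usage (a sketch, not part of the original file; binary and database
# paths are placeholders):
#
#     runner = AlignmentRunner(
#         jackhmmer_binary_path="/usr/bin/jackhmmer",
#         hhblits_binary_path="/usr/bin/hhblits",
#         hhsearch_binary_path="/usr/bin/hhsearch",
#         uniref90_database_path="/data/uniref90/uniref90.fasta",
#         mgnify_database_path="/data/mgnify/mgy_clusters.fa",
#         bfd_database_path="/data/bfd/bfd_database",
#         uniref30_database_path="/data/uniref30/UniRef30_2021_03",
#         pdb70_database_path="/data/pdb70/pdb70",
#         use_small_bfd=False,
#         no_cpus=8,
#     )
#     runner.run("query.fasta", output_dir="alignments/query")
#
# The resulting uniref90_hits.a3m, mgnify_hits.a3m, bfd_uniref_hits.a3m and
# pdb70_hits.hhr files are the alignment_dir contents that OpenFoldSingleDataset
# expects.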
class AlignmentRunnerMultimer:
"""Runs alignment tools and saves the results"""
def __init__(
self,
jackhmmer_binary_path: Optional[str] = None,
hhblits_binary_path: Optional[str] = None,
hmmsearch_binary_path: Optional[str] = None,
hmmbuild_binary_path: Optional[str] = None,
uniref90_database_path: Optional[str] = None,
mgnify_database_path: Optional[str] = None,
bfd_database_path: Optional[str] = None,
uniref30_database_path: Optional[str] = None,
uniprot_database_path: Optional[str] = None,
pdb_seqres_database_path: Optional[str] = None,
use_small_bfd: Optional[bool] = None,
no_cpus: Optional[int] = None,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
uniprot_max_hits: int = 50000,
):
"""
Args:
jackhmmer_binary_path:
Path to jackhmmer binary
hhblits_binary_path:
Path to hhblits binary
uniref90_database_path:
Path to uniref90 database. If provided, jackhmmer_binary_path
must also be provided
mgnify_database_path:
Path to mgnify database. If provided, jackhmmer_binary_path
must also be provided
bfd_database_path:
Path to BFD database. Depending on the value of use_small_bfd,
one of hhblits_binary_path or jackhmmer_binary_path must be
provided.
uniref30_database_path:
Path to uniref30. Searched alongside BFD if use_small_bfd is
false.
use_small_bfd:
Whether to search the BFD database alone with jackhmmer or
in conjunction with uniref30 with hhblits.
no_cpus:
The number of CPUs available for alignment. By default, all
CPUs are used.
uniref_max_hits:
Max number of uniref hits
mgnify_max_hits:
Max number of mgnify hits
"""
db_map = {
"jackhmmer": {
"binary": jackhmmer_binary_path,
"dbs": [
uniref90_database_path,
mgnify_database_path,
bfd_database_path if use_small_bfd else None,
uniprot_database_path,
],
},
"hhblits": {
"binary": hhblits_binary_path,
"dbs": [
bfd_database_path if not use_small_bfd else None,
],
},
"hmmsearch": {
"binary": hmmsearch_binary_path,
"dbs": [
pdb_seqres_database_path,
],
},
}
for name, dic in db_map.items():
binary, dbs = dic["binary"], dic["dbs"]
if(binary is None and not all([x is None for x in dbs])):
raise ValueError(
f"{name} DBs provided but {name} binary is None"
)
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
self.uniprot_max_hits = uniprot_max_hits
self.use_small_bfd = use_small_bfd
if(no_cpus is None):
no_cpus = cpu_count()
self.jackhmmer_uniref90_runner = None
if(jackhmmer_binary_path is not None and
uniref90_database_path is not None
):
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniref90_database_path,
n_cpu=no_cpus,
)
self.jackhmmer_small_bfd_runner = None
self.hhblits_bfd_uniref_runner = None
if(bfd_database_path is not None):
if use_small_bfd:
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=bfd_database_path,
n_cpu=no_cpus,
)
else:
dbs = [bfd_database_path]
if(uniref30_database_path is not None):
dbs.append(uniref30_database_path)
self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=dbs,
n_cpu=no_cpus,
)
self.jackhmmer_mgnify_runner = None
if(mgnify_database_path is not None):
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path,
n_cpu=no_cpus,
)
self.jackhmmer_uniprot_runner = None
if(uniprot_database_path is not None):
self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniprot_database_path
)
self.hmmsearch_pdb_runner = None
if(pdb_seqres_database_path is not None):
self.hmmsearch_pdb_runner = hmmsearch.Hmmsearch(
binary_path=hmmsearch_binary_path,
hmmbuild_binary_path=hmmbuild_binary_path,
database_path=pdb_seqres_database_path,
)
def run(
self,
fasta_path: str,
output_dir: str,
):
"""Runs alignment tools on a sequence"""
if(self.jackhmmer_uniref90_runner is not None):
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.sto")
jackhmmer_uniref90_result = run_msa_tool(
msa_runner=self.jackhmmer_uniref90_runner,
fasta_path=fasta_path,
msa_out_path=uniref90_out_path,
msa_format='sto',
max_sto_sequences=self.uniref_max_hits,
)
template_msa = jackhmmer_uniref90_result["sto"]
template_msa = parsers.deduplicate_stockholm_msa(template_msa)
template_msa = parsers.remove_empty_columns_from_stockholm_msa(
template_msa
)
if(self.hmmsearch_pdb_runner is not None):
pdb_templates_result = self.hmmsearch_pdb_runner.query(
template_msa,
output_dir=output_dir
)
if(self.jackhmmer_mgnify_runner is not None):
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.sto")
jackhmmer_mgnify_result = run_msa_tool(
msa_runner=self.jackhmmer_mgnify_runner,
fasta_path=fasta_path,
msa_out_path=mgnify_out_path,
msa_format='sto',
max_sto_sequences=self.mgnify_max_hits
)
if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
jackhmmer_small_bfd_result = run_msa_tool(
msa_runner=self.jackhmmer_small_bfd_runner,
fasta_path=fasta_path,
msa_out_path=bfd_out_path,
msa_format="sto",
)
elif(self.hhblits_bfd_uniref_runner is not None):
bfd_out_path = os.path.join(output_dir, "bfd_uniref_hits.a3m")
hhblits_bfd_uniref_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniref_runner,
fasta_path=fasta_path,
msa_out_path=bfd_out_path,
msa_format="a3m",
)
if(self.jackhmmer_uniprot_runner is not None):
uniprot_out_path = os.path.join(output_dir, 'uniprot_hits.sto')
result = run_msa_tool(
self.jackhmmer_uniprot_runner,
fasta_path=fasta_path,
msa_out_path=uniprot_out_path,
msa_format='sto',
max_sto_sequences=self.uniprot_max_hits,
)
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
fasta_file.write(fasta_str)
fasta_file.seek(0)
yield fasta_file.name
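# Example (illustrative): `with temp_fasta_file(">chainA\nMKTAYIAK\n") as path:`
# yields the path of a temporary .fasta file containing the string; the file is
# removed when the context exits.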
def convert_monomer_features(
monomer_features: FeatureDict,
chain_id: str
) -> FeatureDict:
"""Reshapes and modifies monomer features for multimer models."""
converted = {}
converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
unnecessary_leading_dim_feats = {
'sequence', 'domain_name', 'num_alignments', 'seq_length'
}
for feature_name, feature in monomer_features.items():
if feature_name in unnecessary_leading_dim_feats:
# asarray ensures it's a np.ndarray.
feature = np.asarray(feature[0], dtype=feature.dtype)
elif feature_name == 'aatype':
# The multimer model performs the one-hot operation itself.
feature = np.argmax(feature, axis=-1).astype(np.int32)
elif feature_name == 'template_aatype':
feature = np.argmax(feature, axis=-1).astype(np.int32)
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
elif feature_name == 'template_all_atom_masks':
feature_name = 'template_all_atom_mask'
converted[feature_name] = feature
return converted
def int_id_to_str_id(num: int) -> str:
"""Encodes a number as a string, using reverse spreadsheet style naming.
Args:
num: A positive integer.
Returns:
A string that encodes the positive integer using reverse spreadsheet style,
naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
usual way to encode chain IDs in mmCIF files.
"""
if num <= 0:
raise ValueError(f'Only positive integers allowed, got {num}.')
num = num - 1 # 1-based indexing.
output = []
while num >= 0:
output.append(chr(num % 26 + ord('A')))
num = num // 26 - 1
return ''.join(output)
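# Examples: int_id_to_str_id(1) == 'A', int_id_to_str_id(26) == 'Z',
# int_id_to_str_id(27) == 'AA', int_id_to_str_id(28) == 'BA'.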
def add_assembly_features(
all_chain_features: MutableMapping[str, FeatureDict],
) -> MutableMapping[str, FeatureDict]:
"""Add features to distinguish between chains.
Args:
all_chain_features: A dictionary which maps chain_id to a dictionary of
features for each chain.
Returns:
all_chain_features: A dictionary which maps strings of the form
`<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
chains from a homodimer would have keys A_1 and A_2. Two chains from a
heterodimer would have keys A_1 and B_1.
"""
# Group the chains by sequence
seq_to_entity_id = {}
grouped_chains = collections.defaultdict(list)
for chain_id, chain_features in all_chain_features.items():
seq = str(chain_features['sequence'])
if seq not in seq_to_entity_id:
seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
grouped_chains[seq_to_entity_id[seq]].append(chain_features)
new_all_chain_features = {}
chain_id = 1
for entity_id, group_chain_features in grouped_chains.items():
for sym_id, chain_features in enumerate(group_chain_features, start=1):
new_all_chain_features[
f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features
seq_length = chain_features['seq_length']
chain_features['asym_id'] = (
chain_id * np.ones(seq_length)
).astype(np.int64)
chain_features['sym_id'] = (
sym_id * np.ones(seq_length)
).astype(np.int64)
chain_features['entity_id'] = (
entity_id * np.ones(seq_length)
).astype(np.int64)
chain_id += 1
return new_all_chain_features
def pad_msa(np_example, min_num_seq):
np_example = dict(np_example)
num_seq = np_example['msa'].shape[0]
if num_seq < min_num_seq:
for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
np_example[feat] = np.pad(
np_example[feat], ((0, min_num_seq - num_seq), (0, 0)))
np_example['cluster_bias_mask'] = np.pad(
np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq),))
return np_example
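# Example (illustrative): with min_num_seq=512 and an input "msa" of shape
# (120, num_res), 392 all-zero rows are appended to "msa", "deletion_matrix",
# "bert_mask" and "msa_mask", giving shape (512, num_res); "cluster_bias_mask"
# (assumed to have one entry per sequence) is likewise padded to length 512.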
class DataPipeline:
"""Assembles input features."""
def __init__(
self,
template_featurizer: Optional[templates.TemplateHitFeaturizer],
):
self.template_featurizer = template_featurizer
def _parse_msa_data(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]:
msa_data = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
def read_msa(start, size):
fp.seek(start)
msa = fp.read(size).decode("utf-8")
return msa
for (name, start, size) in _alignment_index["files"]:
filename, ext = os.path.splitext(name)
if(ext == ".a3m"):
msa = parsers.parse_a3m(
read_msa(start, size)
)
# The "hmm_output" exception is a crude way to exclude
# multimer template hits.
elif(ext == ".sto" and not "hmm_output" == filename):
msa = parsers.parse_stockholm(
read_msa(start, size)
)
else:
continue
msa_data[name] = msa
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
filename, ext = os.path.splitext(f)
if(ext == ".a3m"):
with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read())
elif(ext == ".sto" and not "hmm_output" == filename):
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
else:
continue
msa_data[f] = msa
return msa_data
def _parse_template_hits(
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
_alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
all_hits = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
elif(name == "hmmsearch_output.sto"):
hits = parsers.parse_hmmsearch_sto(
read_template(start, size),
input_sequence,
)
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
elif(f == "hmm_output.sto"):
with open(path, "r") as fp:
hits = parsers.parse_hmmsearch_sto(
fp.read(),
input_sequence,
)
all_hits[f] = hits
return all_hits
def _process_msa_feats(
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
_alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
if(len(msa_data) == 0):
if(input_sequence is None):
raise ValueError(
"""
If the alignment dir contains no MSAs, an input sequence
must be provided.
"""
)
msa_data["dummy"] = Msa(
[input_sequence],
[[0 for _ in input_sequence]],
["dummy"]
)
msa_features = make_msa_features(list(msa_data.values()))
return msa_features
def process_fasta(
self,
fasta_path: str,
alignment_dir: str,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f"More than one input sequence found in {fasta_path}."
)
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index,
)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return {
**sequence_features,
**msa_features,
**template_features
}
def process_mmcif(
self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.
If chain_id is None, the first chain in the object is used; a
ValueError is raised if the object contains no chains.
"""
if chain_id is None:
chains = mmcif.structure.get_chains()
chain = next(chains, None)
if chain is None:
raise ValueError("No chains in mmCIF file")
chain_id = chain.id
mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
query_release_date=to_date(mmcif.header["release_date"])
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return {**mmcif_feats, **template_features, **msa_features}
def process_pdb(
self,
pdb_path: str,
alignment_dir: str,
is_distillation: bool = True,
chain_id: Optional[str] = None,
_structure_index: Optional[str] = None,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.
"""
if(_structure_index is not None):
db_dir = os.path.dirname(pdb_path)
db = _structure_index["db"]
db_path = os.path.join(db_dir, db)
fp = open(db_path, "rb")
_, offset, length = _structure_index["files"][0]
fp.seek(offset)
pdb_str = fp.read(length).decode("utf-8")
fp.close()
else:
with open(pdb_path, 'r') as f:
pdb_str = f.read()
protein_object = protein.from_pdb_string(pdb_str, chain_id)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
pdb_feats = make_pdb_features(
protein_object,
description,
is_distillation=is_distillation
)
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index
)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
return {**pdb_feats, **template_features, **msa_features}
def process_core(
self,
core_path: str,
alignment_dir: str,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a ProteinNet .core file.
"""
with open(core_path, 'r') as f:
core_str = f.read()
protein_object = protein.from_proteinnet_string(core_str)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index
)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence)
return {**core_feats, **template_features, **msa_features}
class DataPipelineMultimer:
"""Runs the alignment tools and assembles the input features."""
def __init__(self,
monomer_data_pipeline: DataPipeline,
):
"""Initializes the data pipeline.
Args:
monomer_data_pipeline: An instance of DataPipeline that runs
the monomer AlphaFold data pipeline on each chain.
"""
self._monomer_data_pipeline = monomer_data_pipeline
def _process_single_chain(
self,
chain_id: str,
sequence: str,
description: str,
chain_alignment_dir: str,
is_homomer_or_monomer: bool
) -> FeatureDict:
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str = f'>{chain_id}\n{sequence}\n'
if not os.path.exists(chain_alignment_dir):
raise ValueError(f"Alignments for {chain_id} not found...")
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
chain_features = self._monomer_data_pipeline.process_fasta(
fasta_path=chain_fasta_path,
alignment_dir=chain_alignment_dir
)
# We only construct the pairing features if there are 2 or more unique
# sequences.
if not is_homomer_or_monomer:
all_seq_msa_features = self._all_seq_msa_features(
chain_fasta_path,
chain_alignment_dir
)
chain_features.update(all_seq_msa_features)
return chain_features
def _all_seq_msa_features(self, fasta_path, alignment_dir):
"""Get MSA features for unclustered uniprot, for pairing."""
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read()
msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers',
)
feats = {
f'{k}_all_seq': v for k, v in all_seq_features.items()
if k in valid_feats
}
return feats
def process_fasta(self,
fasta_path: str,
alignment_dir: str,
) -> FeatureDict:
"""Creates features."""
with open(fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1
for desc, seq in zip(input_descs, input_seqs):
if seq in sequence_features:
all_chain_features[desc] = copy.deepcopy(
sequence_features[seq]
)
continue
chain_features = self._process_single_chain(
chain_id=desc,
sequence=seq,
description=desc,
chain_alignment_dir=os.path.join(alignment_dir, desc),
is_homomer_or_monomer=is_homomer_or_monomer
)
chain_features = convert_monomer_features(
chain_features,
chain_id=desc
)
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
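# Illustrative end-to-end usage of the multimer pipeline (a sketch, not part of the
# original file; the featurizer and paths are placeholders):
#
#     monomer_pipeline = DataPipeline(template_featurizer=None)
#     multimer_pipeline = DataPipelineMultimer(monomer_data_pipeline=monomer_pipeline)
#     np_example = multimer_pipeline.process_fasta(
#         fasta_path="complex.fasta",          # one FASTA record per chain
#         alignment_dir="alignments/complex",  # one sub-directory per chain description
#     )
#
# For heteromers, each chain's alignment directory must also contain a
# uniprot_hits.sto file, which _all_seq_msa_features uses for MSA pairing.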