Commit 39a6d0e6 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merging in main branch

parents d8ee9c5f 84659c93
cff-version: 1.2.0 cff-version: 1.2.0
message: "For now, cite OpenFold with its DOI." preferred-citation:
authors: authors:
- family-names: "Ahdritz" - family-names: "Ahdritz"
given-names: "Gustaf" given-names: "Gustaf"
orcid: https://orcid.org/0000-0001-8283-5324 orcid: https://orcid.org/0000-0001-8283-5324
- family-names: "Bouatta" - family-names: "Bouatta"
given-names: "Nazim" given-names: "Nazim"
orcid: https://orcid.org/0000-0002-6524-874X orcid: https://orcid.org/0000-0002-6524-874X
- family-names: "Kadyan" - family-names: "Kadyan"
given-names: "Sachin" given-names: "Sachin"
- family-names: "Xia" orcid: https://orcid.org/0000-0002-6079-7627
- family-names: "Xia"
given-names: "Qinghui" given-names: "Qinghui"
- family-names: "Gerecke" - family-names: "Gerecke"
given-names: "William" given-names: "William"
- family-names: "AlQuraishi" orcid: https://orcid.org/0000-0002-9777-6192
- family-names: "O'Donnell"
given-names: "Timothy J"
orcid: https://orcid.org/0000-0002-9949-069X
- family-names: "Berenberg"
given-names: "Daniel"
orcid: https://orcid.org/0000-0003-4631-0947
- family-names: "Fisk"
given-names: "Ian"
- family-names: "Zanichelli"
given-names: "Niccolò"
orcid: https://orcid.org/0000-0002-3093-3587
- family-names: "Zhang"
given-names: "Bo"
orcid: https://orcid.org/0000-0002-9714-2827
- family-names: "Nowaczynski"
given-names: "Arkadiusz"
orcid: https://orcid.org/0000-0002-3351-9584
- family-names: "Wang"
given-names: "Bei"
orcid: https://orcid.org/0000-0003-4942-9652
- family-names: "Stepniewska-Dziubinska"
given-names: "Marta M"
orcid: https://orcid.org/0000-0003-4942-9652
- family-names: "Zhang"
given-names: "Shang"
orcid: https://orcid.org/0000-0003-0759-2080
- family-names: "Ojewole"
given-names: "Adegoke"
orcid: https://orcid.org/0000-0003-2661-4388
- family-names: "Guney"
given-names: "Murat Efe"
- family-names: "Biderman"
given-names: "Stella"
orcid: https://orcid.org/0000-0001-8228-1042
- family-names: "Watkins"
given-names: "Andrew M"
orcid: https://orcid.org/0000-0003-1617-1720
- family-names: "Ra"
given-names: "Stephen"
orcid: https://orcid.org/0000-0002-2820-0050
- family-names: "Lorenzo"
given-names: "Pablo Ribalta"
orcid: https://orcid.org/0000-0002-3657-8053
- family-names: "Nivon"
given-names: "Lucas"
- family-names: "Weitzner"
given-names: "Brian"
orcid: https://orcid.org/0000-0002-1909-0961
- family-names: "Ban"
given-names: "Yih-En"
orcid: https://orcid.org/0000-0003-3698-3574
- family-names: "Ban"
given-names: "Yih-En Andrew"
orcid: https://orcid.org/0000-0003-3698-3574
- family-names: "Sorger"
given-names: "Peter K"
orcid: https://orcid.org/0000-0002-3364-1838
- family-names: "Mostaque"
given-names: "Emad"
- family-names: "Zhang"
given-names: "Zhao"
orcid: https://orcid.org/0000-0001-5921-0035
- family-names: "Bonneau"
given-names: "Richard"
orcid: https://orcid.org/0000-0003-4354-7906
- family-names: "AlQuraishi"
given-names: "Mohammed" given-names: "Mohammed"
orcid: https://orcid.org/0000-0001-6817-1322 orcid: https://orcid.org/0000-0001-6817-1322
title: "OpenFold" title: "OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization"
doi: 10.5281/zenodo.5709539 type: article
doi: 10.1101/2022.11.20.517210
doi: 10.1101/2022.11.20.517210
date-released: 2021-11-12 date-released: 2021-11-12
url: "https://github.com/aqlaboratory/openfold" url: "https://doi.org/10.1101/2022.11.20.517210"
FROM nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04 FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu18.04
RUN apt-get update && apt-get install -y wget cuda-minimal-build-10-2 git # metainformation
LABEL org.opencontainers.image.version = "1.0.0"
LABEL org.opencontainers.image.authors = "Gustaf Ahdritz"
LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfold"
LABEL org.opencontainers.image.licenses = "Apache License 2.0"
LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04"
RUN apt-key del 7fa2af80
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
RUN wget -P /tmp \ RUN wget -P /tmp \
"https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \ "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \
&& bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \ && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \
......
![header ](imgs/OpenFold_viz_banner.jpg) ![header ](imgs/of_banner.png)
_Figure: Comparison of OpenFold and AlphaFold2 predictions to the experimental structure of PDB 7KDX, chain B._
# OpenFold # OpenFold
A faithful PyTorch reproduction of DeepMind's A faithful but trainable PyTorch reproduction of DeepMind's
[AlphaFold 2](https://github.com/deepmind/alphafold). [AlphaFold 2](https://github.com/deepmind/alphafold).
## Features ## Features
...@@ -14,37 +16,41 @@ DeepMind experiments. It is omitted here for the sake of reducing clutter. In ...@@ -14,37 +16,41 @@ DeepMind experiments. It is omitted here for the sake of reducing clutter. In
cases where the *Nature* paper differs from the source, we always defer to the cases where the *Nature* paper differs from the source, we always defer to the
latter. latter.
OpenFold is built to support inference with AlphaFold's original JAX weights. OpenFold is trainable in full precision, half precision, or `bfloat16` with or without DeepSpeed,
It's also faster than the official code on GPU. Try it out for yourself with and we've trained it from scratch, matching the performance of the original.
our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb). We've publicly released model weights and our training data — some 400,000
MSAs and PDB70 template hit files — under a permissive license. Model weights
Unlike DeepMind's public code, OpenFold is also trainable. It can be trained are available via scripts in this repository while the MSAs are hosted by the
with [DeepSpeed](https://github.com/microsoft/deepspeed) and with either `fp16` [Registry of Open Data on AWS (RODA)](https://registry.opendata.aws/openfold).
or `bfloat16` half-precision. Try out running inference for yourself with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
OpenFold is equipped with an implementation of low-memory attention OpenFold also supports inference using AlphaFold's official parameters, and
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)), which vice versa (see `scripts/convert_of_weights_to_jax.py`).
enables inference on extremely long chains.
OpenFold has the following advantages over the reference implementation:
We've modified [FastFold](https://github.com/hpcaitech/FastFold)'s custom CUDA
kernels to support in-place attention during inference and training. These use - **Faster inference** on GPU, sometimes by as much as 2x. The greatest speedups are achieved on (>= Ampere) GPUs.
- **Inference on extremely long chains**, made possible by our implementation of low-memory attention
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)). OpenFold can predict the structures of
sequences with more than 4000 residues on a single A100, and even longer ones with CPU offloading.
- **Custom CUDA attention kernels** modified from [FastFold](https://github.com/hpcaitech/FastFold)'s
kernels support in-place attention during inference and training. They use
4x and 5x less GPU memory than equivalent FastFold and stock PyTorch 4x and 5x less GPU memory than equivalent FastFold and stock PyTorch
implementations, respectively. implementations, respectively.
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.
We also make available efficient scripts for generating alignments. We've - **FlashAttention** support greatly speeds up MSA attention.
used them to generate millions of alignments that will be released alongside
original OpenFold weights, trained from scratch using our code (more on that soon).
## Installation (Linux) ## Installation (Linux)
All Python dependencies are specified in `environment.yml`. For producing sequence All Python dependencies are specified in `environment.yml`. For producing sequence
alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite), alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite),
and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)} and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)}
installed on your system. Finally, some download scripts require `aria2c`. installed on your system. You'll need `git-lfs` to download OpenFold parameters.
Finally, some download scripts require `aria2c` and `aws`.
For convenience, we provide a script that installs Miniconda locally, creates a For convenience, we provide a script that installs Miniconda locally, creates a
`conda` virtual environment, installs all Python dependencies, and downloads `conda` virtual environment, installs all Python dependencies, and downloads
useful resources (including DeepMind's pretrained parameters). Run: useful resources, including both sets of model parameters. Run:
```bash ```bash
scripts/install_third_party_dependencies.sh scripts/install_third_party_dependencies.sh
...@@ -76,14 +82,9 @@ To install the HH-suite to `/usr/bin`, run ...@@ -76,14 +82,9 @@ To install the HH-suite to `/usr/bin`, run
## Usage ## Usage
To download DeepMind's pretrained parameters and common ground truth data, run: If you intend to generate your own alignments, e.g. for inference, you have two
choices for downloading protein databases, depending on whether you want to use
```bash DeepMind's MSA generation pipeline (w/ HMMR & HHblits) or
bash scripts/download_data.sh data/
```
You have two choices for downloading protein databases, depending on whether
you want to use DeepMind's MSA generation pipeline (w/ HMMR & HHblits) or
[ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster
MMseqs2 instead. For the former, run: MMseqs2 instead. For the former, run:
...@@ -102,12 +103,25 @@ Make sure to run the latter command on the machine that will be used for MSA ...@@ -102,12 +103,25 @@ Make sure to run the latter command on the machine that will be used for MSA
generation (the script estimates how the precomputed database index used by generation (the script estimates how the precomputed database index used by
MMseqs2 should be split according to the memory available on the system). MMseqs2 should be split according to the memory available on the system).
If you're using your own precomputed MSAs or MSAs from the RODA repository,
there's no need to download these alignment databases. Simply make sure that
the `alignment_dir` contains one directory per chain and that each of these
contains alignments (.sto, .a3m, and .hhr) corresponding to that chain. You
can use `scripts/flatten_roda.sh` to reformat RODA downloads in this way.
Note that the RODA alignments are NOT compatible with the recent .cif ground
truth files downloaded by `scripts/download_alphafold_dbs.sh`. To fetch .cif
files that match the RODA MSAs, once the alignments are flattened, use
`scripts/download_roda_pdbs.sh`. That script outputs a list of alignment dirs
for which matching .cif files could not be found. These should be removed from
the alignment directory.
Alternatively, you can use raw MSAs from Alternatively, you can use raw MSAs from
[ProteinNet](https://github.com/aqlaboratory/proteinnet). After downloading [ProteinNet](https://github.com/aqlaboratory/proteinnet). After downloading
the database, use `scripts/prep_proteinnet_msas.py` to convert the data into that database, use `scripts/prep_proteinnet_msas.py` to convert the data
a format recognized by the OpenFold parser. The resulting directory becomes the into a format recognized by the OpenFold parser. The resulting directory
`alignment_dir` used in subsequent steps. Use `scripts/unpack_proteinnet.py` to becomes the `alignment_dir` used in subsequent steps. Use
extract `.core` files from ProteinNet text files. `scripts/unpack_proteinnet.py` to extract `.core` files from ProteinNet text
files.
For both inference and training, the model's hyperparameters can be tuned from For both inference and training, the model's hyperparameters can be tuned from
`openfold/config.py`. Of course, if you plan to perform inference using `openfold/config.py`. Of course, if you plan to perform inference using
...@@ -122,7 +136,7 @@ pretrained parameters, run e.g.: ...@@ -122,7 +136,7 @@ pretrained parameters, run e.g.:
```bash ```bash
python3 run_pretrained_openfold.py \ python3 run_pretrained_openfold.py \
target.fasta \ fasta_dir \
data/pdb_mmcif/mmcif_files/ \ data/pdb_mmcif/mmcif_files/ \
--uniref90_database_path data/uniref90/uniref90.fasta \ --uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \ --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
...@@ -130,22 +144,93 @@ python3 run_pretrained_openfold.py \ ...@@ -130,22 +144,93 @@ python3 run_pretrained_openfold.py \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \ --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--output_dir ./ \ --output_dir ./ \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \ --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--model_device cuda:1 \ --model_device "cuda:0" \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \ --jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \ --hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \ --hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign --kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt
``` ```
where `data` is the same directory as in the previous step. If `jackhmmer`, where `data` is the same directory as in the previous step. If `jackhmmer`,
`hhblits`, `hhsearch` and `kalign` are available at the default path of `hhblits`, `hhsearch` and `kalign` are available at the default path of
`/usr/bin`, their `binary_path` command-line arguments can be dropped. `/usr/bin`, their `binary_path` command-line arguments can be dropped.
If you've already computed alignments for the query, you have the option to If you've already computed alignments for the query, you have the option to
skip the expensive alignment computation here. skip the expensive alignment computation here with
`--use_precomputed_alignments`.
`--openfold_checkpoint_path` or `--jax_param_path` accept comma-delineated lists
of .pt/DeepSpeed OpenFold checkpoints and AlphaFold's .npz JAX parameter files,
respectively. For a breakdown of the differences between the different parameter
files, see the README downloaded to `openfold/resources/openfold_params/`. Since
OpenFold was trained under a newer training schedule than the one from which the
`model_n` config presets are derived, there is no clean correspondence between
`config_preset` settings and OpenFold checkpoints; the only restraints are that
`*_ptm` checkpoints must be run with `*_ptm` config presets and that `_no_templ_`
checkpoints are only compatible with template-less presets (`model_3` and above).
Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement) Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)
is enabled by default in inference mode. To disable it, set `globals.chunk_size` is enabled by default in inference mode. To disable it, set `globals.chunk_size`
to `None` in the config. to `None` in the config. If a value is specified, OpenFold will attempt to
dynamically tune it, considering the chunk size specified in the config as a
minimum. This tuning process automatically ensures consistently fast runtimes
regardless of input sequence length, but it also introduces some runtime
variability, which may be undesirable for certain users. It is also recommended
to disable this feature for very long chains (see below). To do so, set the
`tune_chunk_size` option in the config to `False`.
For large-scale batch inference, we offer an optional tracing mode, which
massively improves runtimes at the cost of a lengthy model compilation process.
To enable it, add `--trace_model` to the inference command.
To get a speedup during inference, enable [FlashAttention](https://github.com/HazyResearch/flash-attention)
in the config. Note that it appears to work best for sequences with < 1000 residues.
Input FASTA files containing multiple sequences are treated as complexes. In
this case, the inference script runs AlphaFold-Gap, a hack proposed
[here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using
the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer). To
run inference with AlphaFold-Multimer, use the (experimental) `multimer` branch
instead.
To minimize memory usage during inference on long sequences, consider the
following changes:
- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template
stack is a major memory bottleneck for inference on long sequences. OpenFold
supports two mutually exclusive inference modes to address this issue. One,
`average_templates` in the `template` section of the config, is similar to the
solution offered by AlphaFold-Multimer, which is simply to average individual
template representations. Our version is modified slightly to accommodate
weights trained using the standard template algorithm. Using said weights, we
notice no significant difference in performance between our averaged template
embeddings and the standard ones. The second, `offload_templates`, temporarily
offloads individual template embeddings into CPU memory. The former is an
approximation while the latter is slightly slower; both are memory-efficient
and allow the model to utilize arbitrarily many templates across sequence
lengths. Both are disabled by default, and it is up to the user to determine
which best suits their needs, if either.
- Inference-time low-memory attention (LMA) can be enabled in the model config.
This setting trades off speed for vastly improved memory usage. By default,
LMA is run with query and key chunk sizes of 1024 and 4096, respectively.
These represent a favorable tradeoff in most memory-constrained cases.
Powerusers can choose to tweak these settings in
`openfold/model/primitives.py`. For more information on the LMA algorithm,
  see the aforementioned Rabe & Staats preprint.
- Disable `tune_chunk_size` for long sequences. Past a certain point, it only
wastes time.
- As a last resort, consider enabling `offload_inference`. This enables more
extensive CPU offloading at various bottlenecks throughout the model.
- Disable FlashAttention, which seems unstable on long sequences.
Using the most conservative settings, we were able to run inference on a
4600-residue complex with a single A100. Compared to AlphaFold's own memory
offloading mode, ours is considerably faster; the same complex takes the more
efficient AlphaFold-Multimer more than double the time. Use the
`long_sequence_inference` config option to enable all of these interventions
at once. The `run_pretrained_openfold.py` script can enable this config option with the
`--long_sequence_inference` command line option.
### Training ### Training
...@@ -156,11 +241,10 @@ the following: ...@@ -156,11 +241,10 @@ the following:
```bash ```bash
python3 scripts/precompute_alignments.py mmcif_dir/ alignment_dir/ \ python3 scripts/precompute_alignments.py mmcif_dir/ alignment_dir/ \
data/uniref90/uniref90.fasta \ --uniref90_database_path data/uniref90/uniref90.fasta \
data/mgnify/mgy_clusters_2018_12.fa \ --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
data/pdb70/pdb70 \ --pdb70_database_path data/pdb70/pdb70 \
data/pdb_mmcif/mmcif_files/ \ --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \ --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--cpus 16 \ --cpus 16 \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \ --jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
...@@ -216,32 +300,55 @@ python3 scripts/generate_chain_data_cache.py \ ...@@ -216,32 +300,55 @@ python3 scripts/generate_chain_data_cache.py \
where the `cluster_file` argument is a file of chain clusters, one cluster where the `cluster_file` argument is a file of chain clusters, one cluster
per line (e.g. [PDB40](https://cdn.rcsb.org/resources/sequence/clusters/clusters-by-entity-40.txt)). per line (e.g. [PDB40](https://cdn.rcsb.org/resources/sequence/clusters/clusters-by-entity-40.txt)).
Optionally, download an AlphaFold-style validation set from
[CAMEO](https://cameo3d.org) using `scripts/download_cameo.py`. Use the
resulting FASTA files to generate validation alignments and then specify
the validation set's location using the `--val_...` family of training script
flags.
Finally, call the training script: Finally, call the training script:
```bash ```bash
python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ \ python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ output_dir/ \
2021-10-10 \ 2021-10-10 \
--template_release_dates_cache_path mmcif_cache.json \ --template_release_dates_cache_path mmcif_cache.json \
--precision 16 \ --precision bf16 \
--gpus 8 --replace_sampler_ddp=True \ --gpus 8 --replace_sampler_ddp=True \
--seed 42 \ # in multi-gpu settings, the seed must be specified --seed 4242022 \ # in multi-gpu settings, the seed must be specified
--deepspeed_config_path deepspeed_config.json \ --deepspeed_config_path deepspeed_config.json \
--checkpoint_every_epoch \ --checkpoint_every_epoch \
--resume_from_ckpt ckpt_dir/ \ --resume_from_ckpt ckpt_dir/ \
--train_chain_data_cache_path chain_data_cache.json --train_chain_data_cache_path chain_data_cache.json \
--obsolete_pdbs_file_path obsolete.dat
``` ```
where `--template_release_dates_cache_path` is a path to the mmCIF cache. where `--template_release_dates_cache_path` is a path to the mmCIF cache.
A suitable DeepSpeed configuration file can be generated with Note that `template_mmcif_dir` can be the same as `mmcif_dir` which contains
training targets. A suitable DeepSpeed configuration file can be generated with
`scripts/build_deepspeed_config.py`. The training script is `scripts/build_deepspeed_config.py`. The training script is
written with [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) written with [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
and supports the full range of training options that entails, including and supports the full range of training options that entails, including
multi-node distributed training. For more information, consult PyTorch multi-node distributed training, validation, and so on. For more information,
Lightning documentation and the `--help` flag of the training script. consult PyTorch Lightning documentation and the `--help` flag of the training
script.
Note that the data directory can also contain PDB files previously output by
the model. These are treated as members of the self-distillation set and are Note that, despite its variable name, `mmcif_dir` can also contain PDB files
subjected to distillation-set-only preprocessing steps. or even ProteinNet .core files.
To emulate the AlphaFold training procedure, which uses a self-distillation set
subject to special preprocessing steps, use the family of `--distillation` flags.
In cases where it may be burdensome to create separate files for each chain's
alignments, alignment directories can be consolidated using the scripts in
`scripts/alignment_db_scripts/`. First, run `create_alignment_db.py` to
consolidate an alignment directory into a pair of database and index files.
Once all alignment directories (or shards of a single alignment directory)
have been compiled, unify the indices with `unify_alignment_db_indices.py`. The
resulting index, `super.index`, can be passed to the training script flags
containing the phrase `alignment_index`. In this scenario, the `alignment_dir`
flags instead represent the directory containing the compiled alignment
databases. Both the training and distillation datasets can be compiled in this
way. Anecdotally, this can speed up training in I/O-bottlenecked environments.
## Testing ## Testing
...@@ -297,7 +404,7 @@ docker run \ ...@@ -297,7 +404,7 @@ docker run \
-v /mnt/alphafold_database/:/database \ -v /mnt/alphafold_database/:/database \
-ti openfold:latest \ -ti openfold:latest \
python3 /opt/openfold/run_pretrained_openfold.py \ python3 /opt/openfold/run_pretrained_openfold.py \
/data/input.fasta \ /data/fasta_dir \
/database/pdb_mmcif/mmcif_files/ \ /database/pdb_mmcif/mmcif_files/ \
--uniref90_database_path /database/uniref90/uniref90.fasta \ --uniref90_database_path /database/uniref90/uniref90.fasta \
--mgnify_database_path /database/mgnify/mgy_clusters_2018_12.fa \ --mgnify_database_path /database/mgnify/mgy_clusters_2018_12.fa \
...@@ -310,7 +417,7 @@ python3 /opt/openfold/run_pretrained_openfold.py \ ...@@ -310,7 +417,7 @@ python3 /opt/openfold/run_pretrained_openfold.py \
--hhblits_binary_path /opt/conda/bin/hhblits \ --hhblits_binary_path /opt/conda/bin/hhblits \
--hhsearch_binary_path /opt/conda/bin/hhsearch \ --hhsearch_binary_path /opt/conda/bin/hhsearch \
--kalign_binary_path /opt/conda/bin/kalign \ --kalign_binary_path /opt/conda/bin/kalign \
--param_path /database/params/params_model_1.npz --openfold_checkpoint_path /database/openfold_params/finetuning_ptm_2.pt
``` ```
## Copyright notice ## Copyright notice
...@@ -328,16 +435,20 @@ welcome pull requests from the community. ...@@ -328,16 +435,20 @@ welcome pull requests from the community.
## Citing this work ## Citing this work
For now, cite OpenFold as follows: Please cite our paper:
```bibtex ```bibtex
@software{Ahdritz_OpenFold_2021, @article {Ahdritz2022.11.20.517210,
author = {Ahdritz, Gustaf and Bouatta, Nazim and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and AlQuraishi, Mohammed}, author = {Ahdritz, Gustaf and Bouatta, Nazim and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and O{\textquoteright}Donnell, Timothy J and Berenberg, Daniel and Fisk, Ian and Zanichelli, Niccolò and Zhang, Bo and Nowaczynski, Arkadiusz and Wang, Bei and Stepniewska-Dziubinska, Marta M and Zhang, Shang and Ojewole, Adegoke and Guney, Murat Efe and Biderman, Stella and Watkins, Andrew M and Ra, Stephen and Lorenzo, Pablo Ribalta and Nivon, Lucas and Weitzner, Brian and Ban, Yih-En Andrew and Sorger, Peter K and Mostaque, Emad and Zhang, Zhao and Bonneau, Richard and AlQuraishi, Mohammed},
doi = {10.5281/zenodo.5709539}, title = {OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization},
month = {11}, elocation-id = {2022.11.20.517210},
title = {{OpenFold}}, year = {2022},
url = {https://github.com/aqlaboratory/openfold}, doi = {10.1101/2022.11.20.517210},
year = {2021} publisher = {Cold Spring Harbor Laboratory},
abstract = {AlphaFold2 revolutionized structural biology with the ability to predict protein structures with exceptionally high accuracy. Its implementation, however, lacks the code and data required to train new models. These are necessary to (i) tackle new tasks, like protein-ligand complex structure prediction, (ii) investigate the process by which the model learns, which remains poorly understood, and (iii) assess the model{\textquoteright}s generalization capacity to unseen regions of fold space. Here we report OpenFold, a fast, memory-efficient, and trainable implementation of AlphaFold2, and OpenProteinSet, the largest public database of protein multiple sequence alignments. We use OpenProteinSet to train OpenFold from scratch, fully matching the accuracy of AlphaFold2. Having established parity, we assess OpenFold{\textquoteright}s capacity to generalize across fold space by retraining it using carefully designed datasets. We find that OpenFold is remarkably robust at generalizing despite extreme reductions in training set size and diversity, including near-complete elisions of classes of secondary structure elements. By analyzing intermediate structures produced by OpenFold during training, we also gain surprising insights into the manner in which the model learns to fold proteins, discovering that spatial dimensions are learned sequentially. Taken together, our studies demonstrate the power and utility of OpenFold, which we believe will prove to be a crucial new resource for the protein modeling community.},
URL = {https://www.biorxiv.org/content/10.1101/2022.11.20.517210},
eprint = {https://www.biorxiv.org/content/early/2022/11/22/2022.11.20.517210.full.pdf},
journal = {bioRxiv}
} }
``` ```
......
...@@ -4,9 +4,19 @@ channels: ...@@ -4,9 +4,19 @@ channels:
- bioconda - bioconda
- pytorch - pytorch
dependencies: dependencies:
- conda-forge::python=3.7
- conda-forge::setuptools=59.5.0
- conda-forge::pip
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- conda-forge::cudatoolkit==11.3.*
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pytorch::pytorch=1.12.*
- pip: - pip:
- biopython==1.79 - biopython==1.79
- deepspeed==0.5.9 - deepspeed==0.5.10
- dm-tree==0.1.6 - dm-tree==0.1.6
- ml-collections==0.1.0 - ml-collections==0.1.0
- numpy==1.21.2 - numpy==1.21.2
...@@ -16,15 +26,5 @@ dependencies: ...@@ -16,15 +26,5 @@ dependencies:
- tqdm==4.62.2 - tqdm==4.62.2
- typing-extensions==3.10.0.2 - typing-extensions==3.10.0.2
- pytorch_lightning==1.5.10 - pytorch_lightning==1.5.10
- wandb==0.12.21
- git+https://github.com/NVIDIA/dllogger.git - git+https://github.com/NVIDIA/dllogger.git
- pytorch::pytorch=1.10.*
- conda-forge::python=3.7
- conda-forge::setuptools=59.5.0
- conda-forge::pip
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- conda-forge::cudatoolkit==10.2.*
- conda-forge::cudatoolkit-dev==10.*
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
"\n", "\n",
"OpenFold is a trainable PyTorch reimplementation of AlphaFold 2. For the purposes of inference, it is practically identical to the original (\"practically\" because ensembling is excluded from OpenFold (recycling is enabled, however)).\n", "OpenFold is a trainable PyTorch reimplementation of AlphaFold 2. For the purposes of inference, it is practically identical to the original (\"practically\" because ensembling is excluded from OpenFold (recycling is enabled, however)).\n",
"\n", "\n",
"In this notebook, OpenFold is run with DeepMind's publicly released parameters for AlphaFold 2.\n", "In this notebook, OpenFold is run with your choice of our original OpenFold parameters or DeepMind's publicly released parameters for AlphaFold 2.\n",
"\n", "\n",
"**Note**\n", "**Note**\n",
"\n", "\n",
...@@ -43,7 +43,7 @@ ...@@ -43,7 +43,7 @@
"\n", "\n",
"**Licenses**\n", "**Licenses**\n",
"\n", "\n",
"This Colab uses the [AlphaFold model parameters](https://github.com/deepmind/alphafold/#model-parameters-license), made available under the Creative Commons Attribution 4.0 International ([CC BY 4.0](https://creativecommons.org/licenses/by/4.0/legalcode)) license. The Colab itself is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0). See the full license statement below.\n", "This Colab supports inference with the [AlphaFold model parameters](https://github.com/deepmind/alphafold/#model-parameters-license), made available under the Creative Commons Attribution 4.0 International ([CC BY 4.0](https://creativecommons.org/licenses/by/4.0/legalcode)) license. The Colab itself is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0). See the full license statement below.\n",
"\n", "\n",
"**More information**\n", "**More information**\n",
"\n", "\n",
...@@ -55,6 +55,33 @@ ...@@ -55,6 +55,33 @@
"FAQ on how to interpret AlphaFold/OpenFold predictions are [here](https://alphafold.ebi.ac.uk/faq)." "FAQ on how to interpret AlphaFold/OpenFold predictions are [here](https://alphafold.ebi.ac.uk/faq)."
] ]
}, },
{
"cell_type": "code",
"metadata": {
"id": "rowN0bVYLe9n",
"cellView": "form"
},
"source": [
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
"\n",
"#@markdown ### Configure the model ⬇️\n",
"\n",
"weight_set = 'OpenFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
"relax_prediction = True #@param {type:\"boolean\"}\n",
"\n",
"# Remove all whitespaces, tabs and end lines; upper-case\n",
"sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
"if not set(sequence).issubset(aatypes):\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. OpenFold only supports 20 standard amino acids as inputs.')\n",
"\n",
"#@markdown After making your selections, execute this cell by pressing the\n",
"#@markdown *Play* button on the left."
],
"execution_count": null,
"outputs": []
},
{ {
"cell_type": "code", "cell_type": "code",
"metadata": { "metadata": {
...@@ -63,10 +90,9 @@ ...@@ -63,10 +90,9 @@
}, },
"source": [ "source": [
"#@title Install third-party software\n", "#@title Install third-party software\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n", "\n",
"#@markdown Please execute this cell by pressing the _Play_ button \n",
"#@markdown on the left to download and import third-party software \n",
"#@markdown in this Colab notebook. (See the [acknowledgements](https://github.com/deepmind/alphafold/#acknowledgements) in DeepMind's README.)\n",
"\n", "\n",
"#@markdown **Note**: This installs the software on the Colab \n", "#@markdown **Note**: This installs the software on the Colab \n",
"#@markdown notebook in the cloud and not on your computer.\n", "#@markdown notebook in the cloud and not on your computer.\n",
...@@ -79,39 +105,46 @@ ...@@ -79,39 +105,46 @@
"TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n", "TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
"\n", "\n",
"try:\n", "try:\n",
" with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" with io.capture_output() as captured:\n", " with io.capture_output() as captured:\n",
" # Uninstall default Colab version of PyTorch.\n",
" # %shell pip uninstall -y torch\n",
"\n",
" %shell sudo apt install --quiet --yes hmmer\n", " %shell sudo apt install --quiet --yes hmmer\n",
" pbar.update(6)\n",
"\n", "\n",
" # Install py3dmol.\n", " # Install py3dmol.\n",
" %shell pip install py3dmol\n", " %shell pip install py3dmol\n",
" pbar.update(2)\n",
"\n", "\n",
" # Install OpenMM and pdbfixer.\n",
" %shell rm -rf /opt/conda\n", " %shell rm -rf /opt/conda\n",
" %shell wget -q -P /tmp \\\n", " %shell wget -q -P /tmp \\\n",
" https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n", " https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n",
" && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \\\n", " && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \\\n",
" && rm /tmp/Miniconda3-latest-Linux-x86_64.sh\n", " && rm /tmp/Miniconda3-latest-Linux-x86_64.sh\n",
" pbar.update(9)\n",
"\n", "\n",
" PATH=%env PATH\n", " PATH=%env PATH\n",
" %env PATH=/opt/conda/bin:{PATH}\n", " %env PATH=/opt/conda/bin:{PATH}\n",
" pbar.update(80)\n", "\n",
" # Install the required versions of all dependencies.\n",
" %shell conda install -y -q -c conda-forge -c bioconda \\\n",
" kalign2=2.04 \\\n",
" hhsuite=3.3.0 \\\n",
" python=3.7 \\\n",
" 2>&1 1>/dev/null\n",
" %shell pip install -q \\\n",
" ml-collections==0.1.0 \\\n",
" PyYAML==5.4.1 \\\n",
" biopython==1.79\n",
"\n", "\n",
" # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n", " # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n",
" %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n", " %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n",
" %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n", " %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n",
" pbar.update(2)\n",
"\n", "\n",
" %shell wget -q -P /content \\\n", " %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n", " https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
" pbar.update(1)\n", "\n",
"except subprocess.CalledProcessError:\n", " # Install AWS CLI\n",
" %shell curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\"\n",
" %shell unzip -qq awscliv2.zip\n",
" %shell sudo ./aws/install\n",
" %shell rm awscliv2.zip\n",
" %shell rm -rf ./aws\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)\n", " print(captured)\n",
" raise" " raise"
], ],
...@@ -125,112 +158,85 @@ ...@@ -125,112 +158,85 @@
"cellView": "form" "cellView": "form"
}, },
"source": [ "source": [
"#@title Download OpenFold\n", "#@title Install OpenFold\n",
"\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n", "#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n", "#@markdown the left.\n",
"\n", "\n",
"GIT_REPO = 'https://github.com/aqlaboratory/openfold'\n", "# Define constants\n",
"\n", "GIT_REPO='https://github.com/aqlaboratory/openfold'\n",
"SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n", "ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n",
"PARAMS_DIR = './openfold/openfold/resources/params'\n", "OPENFOLD_PARAMS_DIR = './openfold/openfold/resources/openfold_params'\n",
"PARAMS_PATH = os.path.join(PARAMS_DIR, os.path.basename(SOURCE_URL))\n", "ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'\n",
"ALPHAFOLD_PARAMS_PATH = os.path.join(\n",
" ALPHAFOLD_PARAMS_DIR, os.path.basename(ALPHAFOLD_PARAM_SOURCE_URL)\n",
")\n",
"\n", "\n",
"try:\n", "try:\n",
" with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" with io.capture_output() as captured:\n", " with io.capture_output() as captured:\n",
" %shell rm -rf openfold\n",
" %shell git clone {GIT_REPO} openfold\n",
" pbar.update(8)\n",
" # Install the required versions of all dependencies.\n",
" %shell conda env update -n base --file openfold/environment.yml\n",
" # Run setup.py to install only Openfold.\n", " # Run setup.py to install only Openfold.\n",
" %shell pip3 install --no-dependencies ./openfold\n", " %shell rm -rf openfold\n",
" pbar.update(10)\n", " %shell git clone \"{GIT_REPO}\" openfold 2>&1 1> /dev/null\n",
"\n", " %shell mkdir -p /content/openfold/openfold/resources\n",
" %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n",
" %shell /usr/bin/python3 -m pip install -q ./openfold\n",
"\n",
" if(relax_prediction):\n",
" %shell conda install -y -q -c conda-forge \\\n",
" openmm=7.5.1 \\\n",
" pdbfixer=1.7\n",
" \n",
" # Apply OpenMM patch.\n", " # Apply OpenMM patch.\n",
" %shell pushd /opt/conda/lib/python3.7/site-packages/ && \\\n", " %shell pushd /opt/conda/lib/python3.7/site-packages/ && \\\n",
" patch -p0 < /content/openfold/lib/openmm.patch && \\\n", " patch -p0 < /content/openfold/lib/openmm.patch && \\\n",
" popd\n", " popd\n",
" \n", "\n",
" %shell mkdir -p /content/openfold/resources\n", " if(weight_set == 'AlphaFold'):\n",
" %shell cp -f /content/stereo_chemical_props.txt /content/openfold/resources\n", " %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n",
"\n", " %shell wget -O {ALPHAFOLD_PARAMS_PATH} {ALPHAFOLD_PARAM_SOURCE_URL}\n",
" %shell mkdir --parents \"{PARAMS_DIR}\"\n", " %shell tar --extract --verbose --file=\"{ALPHAFOLD_PARAMS_PATH}\" \\\n",
" %shell wget -O \"{PARAMS_PATH}\" \"{SOURCE_URL}\"\n", " --directory=\"{ALPHAFOLD_PARAMS_DIR}\" --preserve-permissions\n",
" pbar.update(27)\n", " %shell rm \"{ALPHAFOLD_PARAMS_PATH}\"\n",
"\n", " elif(weight_set == 'OpenFold'):\n",
" %shell tar --extract --verbose --file=\"{PARAMS_PATH}\" \\\n", " %shell mkdir --parents \"{OPENFOLD_PARAMS_DIR}\"\n",
" --directory=\"{PARAMS_DIR}\" --preserve-permissions\n", " %shell aws s3 cp \\\n",
" %shell rm \"{PARAMS_PATH}\"\n", " --no-sign-request \\\n",
" pbar.update(55)\n", " --region us-east-1 \\\n",
"except subprocess.CalledProcessError:\n", " s3://openfold/openfold_params \"{OPENFOLD_PARAMS_DIR}\" \\\n",
" --recursive\n",
" else:\n",
" raise ValueError(\"Invalid weight set\")\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)\n", " print(captured)\n",
" raise" " raise"
], ],
"execution_count": null, "execution_count": null,
"outputs": [] "outputs": []
}, },
{
"cell_type": "markdown",
"metadata": {
"id": "W4JpOs6oA-QS"
},
"source": [
"## Making a prediction\n",
"\n",
"Please paste the sequence of your protein in the text box below, then run the remaining cells via _Runtime_ > _Run after_. You can also run the cells individually by pressing the _Play_ button on the left.\n",
"\n",
"Note that the search against databases and the actual prediction can take some time, from minutes to hours, depending on the length of the protein and what type of GPU you are allocated by Colab (see FAQ below)."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"metadata": {
"id": "rowN0bVYLe9n",
"cellView": "form"
},
"source": [ "source": [
"#@title Enter the amino acid sequence to fold ⬇️\n", "#@title Import Python packages\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n", "#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n", "\n",
"MIN_SEQUENCE_LENGTH = 16\n", "import unittest.mock\n",
"MAX_SEQUENCE_LENGTH = 2500\n", "import sys\n",
"\n", "\n",
"# Remove all whitespaces, tabs and end lines; upper-case\n", "sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n",
"sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n", "sys.path.append('/opt/conda/lib/python3.7/site-packages')\n",
"aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
"if not set(sequence).issubset(aatypes):\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. AlphaFold only supports 20 standard amino acids as inputs.')\n",
"if len(sequence) < MIN_SEQUENCE_LENGTH:\n",
" raise Exception(f'Input sequence is too short: {len(sequence)} amino acids, while the minimum is {MIN_SEQUENCE_LENGTH}')\n",
"if len(sequence) > MAX_SEQUENCE_LENGTH:\n",
" raise Exception(f'Input sequence is too long: {len(sequence)} amino acids, while the maximum is {MAX_SEQUENCE_LENGTH}. Please use the full AlphaFold system for long sequences.')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2tTeTTsLKPjB",
"cellView": "form"
},
"source": [
"#@title Search against genetic databases\n",
"\n", "\n",
"#@markdown Once this cell has been executed, you will see\n", "# Allows us to skip installing these packages\n",
"#@markdown statistics about the multiple sequence alignment \n", "unnecessary_modules = [\n",
"#@markdown (MSA) that will be used by OpenFold. In particular, \n", " \"dllogger\",\n",
"#@markdown you’ll see how well each residue is covered by similar \n", " \"pytorch_lightning\",\n",
"#@markdown sequences in the MSA.\n", " \"pytorch_lightning.utilities\",\n",
" \"pytorch_lightning.callbacks.early_stopping\",\n",
" \"pytorch_lightning.utilities.seed\",\n",
"]\n",
"for unnecessary_module in unnecessary_modules:\n",
" sys.modules[unnecessary_module] = unittest.mock.MagicMock()\n",
"\n", "\n",
"# --- Python imports ---\n",
"import sys\n",
"sys.path.append('/opt/conda/lib/python3.7/site-packages')\n",
"import os\n", "import os\n",
"os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'\n",
"os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0'\n",
"\n", "\n",
"from urllib import request\n", "from urllib import request\n",
"from concurrent import futures\n", "from concurrent import futures\n",
...@@ -242,6 +248,18 @@ ...@@ -242,6 +248,18 @@
"import py3Dmol\n", "import py3Dmol\n",
"import torch\n", "import torch\n",
"\n", "\n",
"# A filthy hack to avoid slow Linear layer initialization\n",
"import openfold.model.primitives\n",
"\n",
"def __default_linear_init__(self, *args, **kwargs):\n",
" return torch.nn.Linear.__init__(\n",
" self, \n",
" *args[:2], \n",
" **{k:v for k,v in kwargs.items() if k == \"bias\"}\n",
" )\n",
"\n",
"openfold.model.primitives.Linear.__init__ = __default_linear_init__\n",
"\n",
"from openfold import config\n", "from openfold import config\n",
"from openfold.data import feature_pipeline\n", "from openfold.data import feature_pipeline\n",
"from openfold.data import parsers\n", "from openfold.data import parsers\n",
...@@ -249,20 +267,48 @@ ...@@ -249,20 +267,48 @@
"from openfold.data.tools import jackhmmer\n", "from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n", "from openfold.model import model\n",
"from openfold.np import protein\n", "from openfold.np import protein\n",
"from openfold.np.relax import relax\n", "if(relax_prediction):\n",
"from openfold.np.relax import utils\n", " from openfold.np.relax import relax\n",
" from openfold.np.relax import utils\n",
"from openfold.utils.import_weights import import_jax_weights_\n", "from openfold.utils.import_weights import import_jax_weights_\n",
"from openfold.utils.tensor_utils import tensor_tree_map\n", "from openfold.utils.tensor_utils import tensor_tree_map\n",
"\n", "\n",
"from IPython import display\n", "from IPython import display\n",
"from ipywidgets import GridspecLayout\n", "from ipywidgets import GridspecLayout\n",
"from ipywidgets import Output\n", "from ipywidgets import Output"
],
"metadata": {
"id": "_FpxxMo-mvcP",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "W4JpOs6oA-QS"
},
"source": [
"## Making a prediction\n",
"\n", "\n",
"# Color bands for visualizing plddt\n", "Note that the search against databases and the actual prediction can take some time, from minutes to hours, depending on the length of the protein and what type of GPU you are allocated by Colab (see FAQ below)."
"PLDDT_BANDS = [(0, 50, '#FF7D45'),\n", ]
" (50, 70, '#FFDB13'),\n", },
" (70, 90, '#65CBF3'),\n", {
" (90, 100, '#0053D6')]\n", "cell_type": "code",
"metadata": {
"id": "2tTeTTsLKPjB",
"cellView": "form"
},
"source": [
"#@title Search against genetic databases\n",
"\n",
"#@markdown Once this cell has been executed, you will see\n",
"#@markdown statistics about the multiple sequence alignment \n",
"#@markdown (MSA) that will be used by OpenFold. In particular, \n",
"#@markdown you’ll see how well each residue is covered by similar \n",
"#@markdown sequences in the MSA.\n",
"\n", "\n",
"# --- Find the closest source ---\n", "# --- Find the closest source ---\n",
"test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n", "test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n",
...@@ -387,16 +433,30 @@ ...@@ -387,16 +433,30 @@
"#@markdown the obtained prediction will be automatically downloaded \n", "#@markdown the obtained prediction will be automatically downloaded \n",
"#@markdown to your computer.\n", "#@markdown to your computer.\n",
"\n", "\n",
"# Color bands for visualizing plddt\n",
"PLDDT_BANDS = [\n",
" (0, 50, '#FF7D45'),\n",
" (50, 70, '#FFDB13'),\n",
" (70, 90, '#65CBF3'),\n",
" (90, 100, '#0053D6')\n",
"]\n",
"\n",
"# --- Run the model ---\n", "# --- Run the model ---\n",
"model_names = ['model_1', 'model_2', 'model_3', 'model_4', 'model_5', 'model_1_ptm']\n", "model_names = [ \n",
" 'finetuning_3.pt', \n",
" 'finetuning_4.pt', \n",
" 'finetuning_5.pt', \n",
" 'finetuning_ptm_2.pt',\n",
" 'finetuning_no_templ_ptm_1.pt'\n",
"]\n",
"\n", "\n",
"def _placeholder_template_feats(num_templates_, num_res_):\n", "def _placeholder_template_feats(num_templates_, num_res_):\n",
" return {\n", " return {\n",
" 'template_aatype': torch.zeros(num_templates_, num_res_, 22).long(),\n", " 'template_aatype': np.zeros((num_templates_, num_res_, 22), dtype=np.int64),\n",
" 'template_all_atom_positions': torch.zeros(num_templates_, num_res_, 37, 3),\n", " 'template_all_atom_positions': np.zeros((num_templates_, num_res_, 37, 3), dtype=np.float32),\n",
" 'template_all_atom_mask': torch.zeros(num_templates_, num_res_, 37),\n", " 'template_all_atom_mask': np.zeros((num_templates_, num_res_, 37), dtype=np.float32),\n",
" 'template_domain_names': torch.zeros(num_templates_),\n", " 'template_domain_names': np.zeros((num_templates_,), dtype=np.float32),\n",
" 'template_sum_probs': torch.zeros(num_templates_, 1),\n", " 'template_sum_probs': np.zeros((num_templates_, 1), dtype=np.float32),\n",
" }\n", " }\n",
"\n", "\n",
"output_dir = 'prediction'\n", "output_dir = 'prediction'\n",
...@@ -407,21 +467,44 @@ ...@@ -407,21 +467,44 @@
"unrelaxed_proteins = {}\n", "unrelaxed_proteins = {}\n",
"\n", "\n",
"with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n", "with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" for model_name in model_names:\n", " for i, model_name in list(enumerate(model_names)):\n",
" pbar.set_description(f'Running {model_name}')\n", " pbar.set_description(f'Running {model_name}')\n",
" num_templates = 1 # dummy number --- is ignored\n", " num_templates = 1 # dummy number --- is ignored\n",
" num_res = len(sequence)\n", " num_res = len(sequence)\n",
"\n", " \n",
" feature_dict = {}\n", " feature_dict = {}\n",
" feature_dict.update(data_pipeline.make_sequence_features(sequence, 'test', num_res))\n", " feature_dict.update(data_pipeline.make_sequence_features(sequence, 'test', num_res))\n",
" feature_dict.update(data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n", " feature_dict.update(data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n",
" feature_dict.update(_placeholder_template_feats(num_templates, num_res))\n", " feature_dict.update(_placeholder_template_feats(num_templates, num_res))\n",
"\n", "\n",
" cfg = config.model_config(model_name)\n", " if(weight_set == \"AlphaFold\"):\n",
" config_preset = f\"model_{i}\"\n",
" else:\n",
" if(\"_no_templ_\" in model_name):\n",
" config_preset = \"model_3\"\n",
" else:\n",
" config_preset = \"model_1\"\n",
" if(\"_ptm_\" in model_name):\n",
" config_preset += \"_ptm\"\n",
"\n",
" cfg = config.model_config(config_preset)\n",
" openfold_model = model.AlphaFold(cfg)\n", " openfold_model = model.AlphaFold(cfg)\n",
" openfold_model = openfold_model.eval()\n", " openfold_model = openfold_model.eval()\n",
" params_name = os.path.join(PARAMS_DIR, f\"params_{model_name}.npz\")\n", " if(weight_set == \"AlphaFold\"):\n",
" import_jax_weights_(openfold_model, params_name, version=model_name)\n", " params_name = os.path.join(\n",
" ALPHAFOLD_PARAMS_DIR, f\"params_{config_preset}.npz\"\n",
" )\n",
" import_jax_weights_(openfold_model, params_name, version=config_preset)\n",
" elif(weight_set == \"OpenFold\"):\n",
" params_name = os.path.join(\n",
" OPENFOLD_PARAMS_DIR,\n",
" model_name,\n",
" )\n",
" d = torch.load(params_name)\n",
" openfold_model.load_state_dict(d)\n",
" else:\n",
" raise ValueError(f\"Invalid weight set: {weight_set}\")\n",
"\n",
" openfold_model = openfold_model.cuda()\n", " openfold_model = openfold_model.cuda()\n",
"\n", "\n",
" pipeline = feature_pipeline.FeaturePipeline(cfg.data)\n", " pipeline = feature_pipeline.FeaturePipeline(cfg.data)\n",
...@@ -470,7 +553,12 @@ ...@@ -470,7 +553,12 @@
" del prediction_result\n", " del prediction_result\n",
" pbar.update(n=1)\n", " pbar.update(n=1)\n",
"\n", "\n",
" # Find the best model according to the mean pLDDT.\n",
" best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n",
" best_pdb = protein.to_pdb(unrelaxed_proteins[best_model_name])\n",
"\n",
" # --- AMBER relax the best model ---\n", " # --- AMBER relax the best model ---\n",
" if(relax_prediction):\n",
" pbar.set_description(f'AMBER relaxation')\n", " pbar.set_description(f'AMBER relaxation')\n",
" amber_relaxer = relax.AmberRelaxation(\n", " amber_relaxer = relax.AmberRelaxation(\n",
" max_iterations=0,\n", " max_iterations=0,\n",
...@@ -478,12 +566,19 @@ ...@@ -478,12 +566,19 @@
" stiffness=10.0,\n", " stiffness=10.0,\n",
" exclude_residues=[],\n", " exclude_residues=[],\n",
" max_outer_iterations=20,\n", " max_outer_iterations=20,\n",
" use_gpu=True,\n", " use_gpu=False,\n",
" )\n", " )\n",
" # Find the best model according to the mean pLDDT.\n",
" best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n",
" relaxed_pdb, _, _ = amber_relaxer.process(\n", " relaxed_pdb, _, _ = amber_relaxer.process(\n",
" prot=unrelaxed_proteins[best_model_name])\n", " prot=unrelaxed_proteins[best_model_name]\n",
" )\n",
"\n",
" # Write out the prediction\n",
" pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
" with open(pred_output_path, 'w') as f:\n",
" f.write(relaxed_pdb)\n",
"\n",
" best_pdb = relaxed_pdb\n",
"\n",
" pbar.update(n=1) # Finished AMBER relax.\n", " pbar.update(n=1) # Finished AMBER relax.\n",
"\n", "\n",
"# Construct multiclass b-factors to indicate confidence bands\n", "# Construct multiclass b-factors to indicate confidence bands\n",
...@@ -495,14 +590,7 @@ ...@@ -495,14 +590,7 @@
" banded_b_factors.append(idx)\n", " banded_b_factors.append(idx)\n",
" break\n", " break\n",
"banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n", "banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n",
"to_visualize_pdb = utils.overwrite_b_factors(relaxed_pdb, banded_b_factors)\n", "to_visualize_pdb = utils.overwrite_b_factors(best_pdb, banded_b_factors)\n",
"\n",
"\n",
"# Write out the prediction\n",
"pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
"with open(pred_output_path, 'w') as f:\n",
" f.write(relaxed_pdb)\n",
"\n",
"\n", "\n",
"# --- Visualise the prediction & confidence ---\n", "# --- Visualise the prediction & confidence ---\n",
"show_sidechains = True\n", "show_sidechains = True\n",
......
name: openfold_venv
channels:
- conda-forge
- bioconda
dependencies:
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pip:
- biopython==1.79
- dm-tree==0.1.6
- ml-collections==0.1.0
- PyYAML==5.4.1
- requests==2.26.0
- typing-extensions==3.10.0.2
import copy import copy
import importlib
import ml_collections as mlc import ml_collections as mlc
...@@ -10,20 +11,92 @@ def set_inf(c, inf): ...@@ -10,20 +11,92 @@ def set_inf(c, inf):
c[k] = inf c[k] = inf
def model_config(name, train=False, low_prec=False): def enforce_config_constraints(config):
def string_to_setting(s):
path = s.split('.')
setting = config
for p in path:
setting = setting.get(p)
return setting
mutually_exclusive_bools = [
(
"model.template.average_templates",
"model.template.offload_templates"
),
(
"globals.use_lma",
"globals.use_flash",
),
]
for s1, s2 in mutually_exclusive_bools:
s1_setting = string_to_setting(s1)
s2_setting = string_to_setting(s2)
if(s1_setting and s2_setting):
raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(config.globals.use_flash and not fa_is_installed):
raise ValueError("use_flash requires that FlashAttention is installed")
if(
config.globals.offload_inference and
not config.model.template.average_templates
):
config.model.template.offload_templates = True
def model_config(
name,
train=False,
low_prec=False,
long_sequence_inference=False
):
c = copy.deepcopy(config) c = copy.deepcopy(config)
# TRAINING PRESETS
if name == "initial_training": if name == "initial_training":
# AF2 Suppl. Table 4, "initial training" setting # AF2 Suppl. Table 4, "initial training" setting
pass pass
elif name == "finetuning": elif name == "finetuning":
# AF2 Suppl. Table 4, "finetuning" setting # AF2 Suppl. Table 4, "finetuning" setting
c.data.common.max_extra_msa = 5120
c.data.train.crop_size = 384 c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "finetuning_ptm":
c.data.train.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "finetuning_no_templ":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512 c.data.train.max_msa_clusters = 512
c.model.template.enabled = False
c.loss.violation.weight = 1. c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "finetuning_no_templ_ptm":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.model.template.enabled = False
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
# INFERENCE PRESETS
elif name == "model_1": elif name == "model_1":
# AF2 Suppl. Table 5, Model 1.1.1 # AF2 Suppl. Table 5, Model 1.1.1
c.data.common.max_extra_msa = 5120 c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True c.data.common.use_template_torsion_angles = True
...@@ -36,17 +109,20 @@ def model_config(name, train=False, low_prec=False): ...@@ -36,17 +109,20 @@ def model_config(name, train=False, low_prec=False):
c.model.template.enabled = True c.model.template.enabled = True
elif name == "model_3": elif name == "model_3":
# AF2 Suppl. Table 5, Model 1.2.1 # AF2 Suppl. Table 5, Model 1.2.1
c.data.common.max_extra_msa = 5120 c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False c.model.template.enabled = False
elif name == "model_4": elif name == "model_4":
# AF2 Suppl. Table 5, Model 1.2.2 # AF2 Suppl. Table 5, Model 1.2.2
c.data.common.max_extra_msa = 5120 c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False c.model.template.enabled = False
elif name == "model_5": elif name == "model_5":
# AF2 Suppl. Table 5, Model 1.2.3 # AF2 Suppl. Table 5, Model 1.2.3
c.model.template.enabled = False c.model.template.enabled = False
elif name == "model_1_ptm": elif name == "model_1_ptm":
c.data.common.max_extra_msa = 5120 c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True c.data.common.use_template_torsion_angles = True
...@@ -61,12 +137,14 @@ def model_config(name, train=False, low_prec=False): ...@@ -61,12 +137,14 @@ def model_config(name, train=False, low_prec=False):
c.model.heads.tm.enabled = True c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
elif name == "model_3_ptm": elif name == "model_3_ptm":
c.data.common.max_extra_msa = 5120 c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False c.model.template.enabled = False
c.model.heads.tm.enabled = True c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
elif name == "model_4_ptm": elif name == "model_4_ptm":
c.data.common.max_extra_msa = 5120 c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False c.model.template.enabled = False
c.model.heads.tm.enabled = True c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
...@@ -76,6 +154,7 @@ def model_config(name, train=False, low_prec=False): ...@@ -76,6 +154,7 @@ def model_config(name, train=False, low_prec=False):
c.loss.tm.weight = 0.1 c.loss.tm.weight = 0.1
elif "multimer" in name: elif "multimer" in name:
c.globals.is_multimer = True c.globals.is_multimer = True
c.loss.masked_msa.num_classes = 22
for k,v in multimer_model_config_update.items(): for k,v in multimer_model_config_update.items():
c.model[k] = v c.model[k] = v
...@@ -89,9 +168,23 @@ def model_config(name, train=False, low_prec=False): ...@@ -89,9 +168,23 @@ def model_config(name, train=False, low_prec=False):
else: else:
raise ValueError("Invalid model name") raise ValueError("Invalid model name")
if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True
c.globals.use_lma = True
c.globals.use_flash = False
c.model.template.offload_inference = True
c.model.template.template_pair_stack.tune_chunk_size = False
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
c.model.evoformer_stack.tune_chunk_size = False
if train: if train:
c.globals.blocks_per_ckpt = 1 c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None c.globals.chunk_size = None
c.globals.use_lma = False
c.globals.offload_inference = False
c.model.template.average_templates = False
c.model.template.offload_templates = False
if low_prec: if low_prec:
c.globals.eps = 1e-4 c.globals.eps = 1e-4
...@@ -99,6 +192,8 @@ def model_config(name, train=False, low_prec=False): ...@@ -99,6 +192,8 @@ def model_config(name, train=False, low_prec=False):
# a global constant # a global constant
set_inf(c, 1e4) set_inf(c, 1e4)
enforce_config_constraints(c)
return c return c
...@@ -114,6 +209,7 @@ tm_enabled = mlc.FieldReference(False, field_type=bool) ...@@ -114,6 +209,7 @@ tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float) eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool) templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool) embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
tune_chunk_size = mlc.FieldReference(True, field_type=bool)
NUM_RES = "num residues placeholder" NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder" NUM_MSA_SEQ = "msa placeholder"
...@@ -195,7 +291,6 @@ config = mlc.ConfigDict( ...@@ -195,7 +291,6 @@ config = mlc.ConfigDict(
"same_prob": 0.1, "same_prob": 0.1,
"uniform_prob": 0.1, "uniform_prob": 0.1,
}, },
"max_extra_msa": 1024,
"max_recycling_iters": 3, "max_recycling_iters": 3,
"msa_cluster_features": True, "msa_cluster_features": True,
"reduce_msa_clusters_by_max_templates": False, "reduce_msa_clusters_by_max_templates": False,
...@@ -233,7 +328,8 @@ config = mlc.ConfigDict( ...@@ -233,7 +328,8 @@ config = mlc.ConfigDict(
"fixed_size": True, "fixed_size": True,
"subsample_templates": False, # We want top templates. "subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128, "max_msa_clusters": 512,
"max_extra_msa": 1024,
"max_template_hits": 4, "max_template_hits": 4,
"max_templates": 4, "max_templates": 4,
"crop": False, "crop": False,
...@@ -246,6 +342,7 @@ config = mlc.ConfigDict( ...@@ -246,6 +342,7 @@ config = mlc.ConfigDict(
"subsample_templates": False, # We want top templates. "subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128, "max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4, "max_template_hits": 4,
"max_templates": 4, "max_templates": 4,
"crop": False, "crop": False,
...@@ -258,6 +355,7 @@ config = mlc.ConfigDict( ...@@ -258,6 +355,7 @@ config = mlc.ConfigDict(
"subsample_templates": True, "subsample_templates": True,
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128, "max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4, "max_template_hits": 4,
"max_templates": 4, "max_templates": 4,
"shuffle_top_k_prefiltered": 20, "shuffle_top_k_prefiltered": 20,
...@@ -274,6 +372,7 @@ config = mlc.ConfigDict( ...@@ -274,6 +372,7 @@ config = mlc.ConfigDict(
"data_loaders": { "data_loaders": {
"batch_size": 1, "batch_size": 1,
"num_workers": 16, "num_workers": 16,
"pin_memory": True,
}, },
}, },
}, },
...@@ -281,6 +380,13 @@ config = mlc.ConfigDict( ...@@ -281,6 +380,13 @@ config = mlc.ConfigDict(
"globals": { "globals": {
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size, "chunk_size": chunk_size,
# Use Staats & Rabe's low-memory attention algorithm. Mutually
# exclusive with use_flash.
"use_lma": False,
# Use FlashAttention in selected modules. Mutually exclusive with
# use_lma. Doesn't work that well on long sequences (>1000 residues).
"use_flash": False,
"offload_inference": False,
"c_z": c_z, "c_z": c_z,
"c_m": c_m, "c_m": c_m,
"c_t": c_t, "c_t": c_t,
...@@ -333,6 +439,7 @@ config = mlc.ConfigDict( ...@@ -333,6 +439,7 @@ config = mlc.ConfigDict(
"dropout_rate": 0.25, "dropout_rate": 0.25,
"tri_mul_first": False, "tri_mul_first": False,
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9, "inf": 1e9,
}, },
"template_pointwise_attention": { "template_pointwise_attention": {
...@@ -349,6 +456,17 @@ config = mlc.ConfigDict( ...@@ -349,6 +456,17 @@ config = mlc.ConfigDict(
"enabled": templates_enabled, "enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles, "embed_angles": embed_template_torsion_angles,
"use_unit_vector": False, "use_unit_vector": False,
# Approximate template computation, saving memory.
# In our experiments, results are equivalent to or better than
# the stock implementation. Should be enabled for all new
# training runs.
"average_templates": False,
# Offload template embeddings to CPU memory. Vastly reduced
# memory consumption at the cost of a modest increase in
# runtime. Useful for inference on very long sequences.
# Mutually exclusive with average_templates. Automatically
# enabled if offload_inference is set.
"offload_templates": False,
}, },
"extra_msa": { "extra_msa": {
"extra_msa_embedder": { "extra_msa_embedder": {
...@@ -369,7 +487,8 @@ config = mlc.ConfigDict( ...@@ -369,7 +487,8 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15, "msa_dropout": 0.15,
"pair_dropout": 0.25, "pair_dropout": 0.25,
"opm_first": False, "opm_first": False,
"clear_cache_between_blocks": True, "clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9, "inf": 1e9,
"eps": eps, # 1e-10, "eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None, "ckpt": blocks_per_ckpt is not None,
...@@ -393,6 +512,7 @@ config = mlc.ConfigDict( ...@@ -393,6 +512,7 @@ config = mlc.ConfigDict(
"opm_first": False, "opm_first": False,
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False, "clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9, "inf": 1e9,
"eps": eps, # 1e-10, "eps": eps, # 1e-10,
}, },
...@@ -473,7 +593,7 @@ config = mlc.ConfigDict( ...@@ -473,7 +593,7 @@ config = mlc.ConfigDict(
"eps": 1e-4, "eps": 1e-4,
"weight": 1.0, "weight": 1.0,
}, },
"lddt": { "plddt_loss": {
"min_resolution": 0.1, "min_resolution": 0.1,
"max_resolution": 3.0, "max_resolution": 3.0,
"cutoff": 15.0, "cutoff": 15.0,
...@@ -482,6 +602,7 @@ config = mlc.ConfigDict( ...@@ -482,6 +602,7 @@ config = mlc.ConfigDict(
"weight": 0.01, "weight": 0.01,
}, },
"masked_msa": { "masked_msa": {
"num_classes": 23,
"eps": eps, # 1e-8, "eps": eps, # 1e-8,
"weight": 2.0, "weight": 2.0,
}, },
...@@ -503,7 +624,7 @@ config = mlc.ConfigDict( ...@@ -503,7 +624,7 @@ config = mlc.ConfigDict(
"min_resolution": 0.1, "min_resolution": 0.1,
"max_resolution": 3.0, "max_resolution": 3.0,
"eps": eps, # 1e-8, "eps": eps, # 1e-8,
"weight": 0.0, "weight": 0.,
"enabled": tm_enabled, "enabled": tm_enabled,
}, },
"eps": eps, "eps": eps,
...@@ -607,6 +728,23 @@ multimer_model_config_update = { ...@@ -607,6 +728,23 @@ multimer_model_config_update = {
"inf": 1e9, "inf": 1e9,
"eps": eps, # 1e-10, "eps": eps, # 1e-10,
}, },
"structure_module": {
"c_s": c_s,
"c_z": c_z,
"c_ipa": 16,
"c_resnet": 128,
"no_heads_ipa": 12,
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 20,
"epsilon": eps, # 1e-12,
"inf": 1e5,
},
"heads": { "heads": {
"lddt": { "lddt": {
"no_bins": 50, "no_bins": 50,
......
...@@ -28,16 +28,18 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -28,16 +28,18 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
template_mmcif_dir: str, template_mmcif_dir: str,
max_template_date: str, max_template_date: str,
config: mlc.ConfigDict, config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign', kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4, max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None, obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None, shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True, treat_pdb_as_distillation: bool = True,
mapping_path: Optional[str] = None, filter_path: Optional[str] = None,
mode: str = "train", mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False, _output_raw: bool = False,
_alignment_index: Optional[Any] = None _structure_index: Optional[Any] = None,
): ):
""" """
Args: Args:
...@@ -55,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -55,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
Path to a directory containing template mmCIF files. Path to a directory containing template mmCIF files.
config: config:
A dataset config object. See openfold.config A dataset config object. See openfold.config
chain_data_cache_path:
Path to cache of data_dir generated by
scripts/generate_chain_data_cache.py
kalign_binary_path: kalign_binary_path:
Path to kalign binary. Path to kalign binary.
max_template_hits: max_template_hits:
...@@ -79,12 +84,22 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -79,12 +84,22 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
""" """
super(OpenFoldSingleDataset, self).__init__() super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir self.data_dir = data_dir
self.chain_data_cache = None
if chain_data_cache_path is not None:
with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
self.alignment_dir = alignment_dir self.alignment_dir = alignment_dir
self.config = config self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode self.mode = mode
self.alignment_index = alignment_index
self._output_raw = _output_raw self._output_raw = _output_raw
self._alignment_index = _alignment_index self._structure_index = _structure_index
self.supported_exts = [".cif", ".core", ".pdb"]
valid_modes = ["train", "eval", "predict"] valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes): if(mode not in valid_modes):
...@@ -96,13 +111,41 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -96,13 +111,41 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"scripts/generate_mmcif_cache.py before running OpenFold" "scripts/generate_mmcif_cache.py before running OpenFold"
) )
if(_alignment_index is not None): if(alignment_index is not None):
self._chain_ids = list(_alignment_index.keys()) self._chain_ids = list(alignment_index.keys())
elif(mapping_path is None):
self._chain_ids = list(os.listdir(alignment_dir))
else: else:
with open(mapping_path, "r") as f: self._chain_ids = list(os.listdir(alignment_dir))
self._chain_ids = [l.strip() for l in f.readlines()]
if(filter_path is not None):
with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()])
self._chain_ids = [
c for c in self._chain_ids if c in chains_to_include
]
if self.chain_data_cache is not None:
# Filter to include only chains where we have structure data
# (entries in chain_data_cache)
original_chain_ids = self._chain_ids
self._chain_ids = [
c for c in self._chain_ids if c in self.chain_data_cache
]
if len(self._chain_ids) < len(original_chain_ids):
missing = [
c for c in original_chain_ids
if c not in self.chain_data_cache
]
max_to_print = 10
missing_examples = ", ".join(missing[:max_to_print])
if len(missing) > max_to_print:
missing_examples += ", ..."
logging.warning(
"Removing %d alignment entries (%s) with no corresponding "
"entries in chain_data_cache (%s).",
len(missing),
missing_examples,
chain_data_cache_path)
self._chain_id_to_idx_dict = { self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids) chain: i for i, chain in enumerate(self._chain_ids)
...@@ -125,7 +168,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -125,7 +168,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if(not self._output_raw): if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config) self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index): def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
with open(path, 'r') as f: with open(path, 'r') as f:
mmcif_string = f.read() mmcif_string = f.read()
...@@ -144,7 +187,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -144,7 +187,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif=mmcif_object, mmcif=mmcif_object,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
chain_id=chain_id, chain_id=chain_id,
_alignment_index=_alignment_index alignment_index=alignment_index
) )
return data return data
...@@ -159,10 +202,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -159,10 +202,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
name = self.idx_to_chain_id(idx) name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name) alignment_dir = os.path.join(self.alignment_dir, name)
_alignment_index = None alignment_index = None
if(self._alignment_index is not None): if(self.alignment_index is not None):
alignment_dir = self.alignment_dir alignment_dir = self.alignment_dir
_alignment_index = self._alignment_index[name] alignment_index = self.alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'): if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1) spl = name.rsplit('_', 1)
...@@ -173,30 +216,51 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -173,30 +216,51 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
chain_id = None chain_id = None
path = os.path.join(self.data_dir, file_id) path = os.path.join(self.data_dir, file_id)
if(os.path.exists(path + ".cif")): structure_index_entry = None
if(self._structure_index is not None):
structure_index_entry = self._structure_index[name]
assert(len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1]
else:
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
ext = e
break
if(ext is None):
raise ValueError("Invalid file type")
path += ext
if(ext == ".cif"):
data = self._parse_mmcif( data = self._parse_mmcif(
path + ".cif", file_id, chain_id, alignment_dir, _alignment_index, path, file_id, chain_id, alignment_dir, alignment_index,
) )
elif(os.path.exists(path + ".core")): elif(ext == ".core"):
data = self.data_pipeline.process_core( data = self.data_pipeline.process_core(
path + ".core", alignment_dir, _alignment_index, path, alignment_dir, alignment_index,
) )
elif(os.path.exists(path + ".pdb")): elif(ext == ".pdb"):
structure_index = None
if(self._structure_index is not None):
structure_index = self._structure_index[name]
data = self.data_pipeline.process_pdb( data = self.data_pipeline.process_pdb(
pdb_path=path + ".pdb", pdb_path=path,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation, is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id, chain_id=chain_id,
_alignment_index=_alignment_index, alignment_index=alignment_index,
_structure_index=structure_index,
) )
else: else:
raise ValueError("Invalid file type") raise ValueError("Extension branch missing")
else: else:
path = os.path.join(name, name + ".fasta") path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta( data = self.data_pipeline.process_fasta(
fasta_path=path, fasta_path=path,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
_alignment_index=_alignment_index, alignment_index=alignment_index,
) )
if(self._output_raw): if(self._output_raw):
...@@ -206,6 +270,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -206,6 +270,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data, self.mode data, self.mode
) )
feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)
return feats return feats
def __len__(self): def __len__(self):
...@@ -265,9 +334,8 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -265,9 +334,8 @@ class OpenFoldDataset(torch.utils.data.Dataset):
""" """
def __init__(self, def __init__(self,
datasets: Sequence[OpenFoldSingleDataset], datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[int], probabilities: Sequence[float],
epoch_len: int, epoch_len: int,
chain_data_cache_paths: List[str],
generator: torch.Generator = None, generator: torch.Generator = None,
_roll_at_init: bool = True, _roll_at_init: bool = True,
): ):
...@@ -276,11 +344,6 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -276,11 +344,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self.epoch_len = epoch_len self.epoch_len = epoch_len
self.generator = generator self.generator = generator
self.chain_data_caches = []
for path in chain_data_cache_paths:
with open(path, "r") as fp:
self.chain_data_caches.append(json.load(fp))
def looped_shuffled_dataset_idx(dataset_len): def looped_shuffled_dataset_idx(dataset_len):
while True: while True:
# Uniformly shuffle each dataset's indices # Uniformly shuffle each dataset's indices
...@@ -298,7 +361,7 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -298,7 +361,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len = int(epoch_len * probabilities[dataset_idx]) max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx] dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset)) idx_iter = looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = self.chain_data_caches[dataset_idx] chain_data_cache = dataset.chain_data_cache
while True: while True:
weights = [] weights = []
idx = [] idx = []
...@@ -355,20 +418,9 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -355,20 +418,9 @@ class OpenFoldDataset(torch.utils.data.Dataset):
class OpenFoldBatchCollator: class OpenFoldBatchCollator:
def __init__(self, config, stage="train"): def __call__(self, prots):
self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def __call__(self, raw_prots):
processed_prots = []
for prot in raw_prots:
features = self.feature_pipeline.process_features(
prot, self.stage
)
processed_prots.append(features)
stack_fn = partial(torch.stack, dim=0) stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, processed_prots) return dict_multimap(stack_fn, prots)
class OpenFoldDataLoader(torch.utils.data.DataLoader): class OpenFoldDataLoader(torch.utils.data.DataLoader):
...@@ -388,11 +440,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -388,11 +440,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
stage_cfg = self.config[self.stage] stage_cfg = self.config[self.stage]
max_iters = self.config.common.max_recycling_iters max_iters = self.config.common.max_recycling_iters
if(stage_cfg.supervised):
clamp_prob = self.config.supervised.clamp_prob
keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob])
)
if(stage_cfg.uniform_recycling): if(stage_cfg.uniform_recycling):
recycling_probs = [ recycling_probs = [
...@@ -480,13 +527,15 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -480,13 +527,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
predict_data_dir: Optional[str] = None, predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None, predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign', kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None, train_filter_path: Optional[str] = None,
distillation_mapping_path: Optional[str] = None, distillation_filter_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None, obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None, batch_seed: Optional[int] = None,
train_epoch_len: int = 50000, train_epoch_len: int = 50000,
_alignment_index_path: Optional[str] = None, _distillation_structure_index_path: Optional[str] = None,
alignment_index_path: Optional[str] = None,
distillation_alignment_index_path: Optional[str] = None,
**kwargs **kwargs
): ):
super(OpenFoldDataModule, self).__init__() super(OpenFoldDataModule, self).__init__()
...@@ -507,8 +556,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -507,8 +556,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.predict_data_dir = predict_data_dir self.predict_data_dir = predict_data_dir
self.predict_alignment_dir = predict_alignment_dir self.predict_alignment_dir = predict_alignment_dir
self.kalign_binary_path = kalign_binary_path self.kalign_binary_path = kalign_binary_path
self.train_mapping_path = train_mapping_path self.train_filter_path = train_filter_path
self.distillation_mapping_path = distillation_mapping_path self.distillation_filter_path = distillation_filter_path
self.template_release_dates_cache_path = ( self.template_release_dates_cache_path = (
template_release_dates_cache_path template_release_dates_cache_path
) )
...@@ -539,10 +588,20 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -539,10 +588,20 @@ class OpenFoldDataModule(pl.LightningDataModule):
) )
# An ad-hoc measure for our particular filesystem restrictions # An ad-hoc measure for our particular filesystem restrictions
self._alignment_index = None self._distillation_structure_index = None
if(_alignment_index_path is not None): if(_distillation_structure_index_path is not None):
with open(_alignment_index_path, "r") as fp: with open(_distillation_structure_index_path, "r") as fp:
self._alignment_index = json.load(fp) self._distillation_structure_index = json.load(fp)
self.alignment_index = None
if(alignment_index_path is not None):
with open(alignment_index_path, "r") as fp:
self.alignment_index = json.load(fp)
self.distillation_alignment_index = None
if(distillation_alignment_index_path is not None):
with open(distillation_alignment_index_path, "r") as fp:
self.distillation_alignment_index = json.load(fp)
def setup(self): def setup(self):
# Most of the arguments are the same for the three datasets # Most of the arguments are the same for the three datasets
...@@ -560,27 +619,29 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -560,27 +619,29 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.training_mode): if(self.training_mode):
train_dataset = dataset_gen( train_dataset = dataset_gen(
data_dir=self.train_data_dir, data_dir=self.train_data_dir,
chain_data_cache_path=self.train_chain_data_cache_path,
alignment_dir=self.train_alignment_dir, alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path, filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered= shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered, self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False, treat_pdb_as_distillation=False,
mode="train", mode="train",
_output_raw=True, alignment_index=self.alignment_index,
_alignment_index=self._alignment_index,
) )
distillation_dataset = None distillation_dataset = None
if(self.distillation_data_dir is not None): if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen( distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir, data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
alignment_dir=self.distillation_alignment_dir, alignment_dir=self.distillation_alignment_dir,
mapping_path=self.distillation_mapping_path, filter_path=self.distillation_filter_path,
max_template_hits=self.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True, treat_pdb_as_distillation=True,
mode="train", mode="train",
_output_raw=True, alignment_index=self.distillation_alignment_index,
_structure_index=self._distillation_structure_index,
) )
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
...@@ -588,23 +649,21 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -588,23 +649,21 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(distillation_dataset is not None): if(distillation_dataset is not None):
datasets = [train_dataset, distillation_dataset] datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
probabilities = [1 - d_prob, d_prob] probabilities = [1. - d_prob, d_prob]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
self.distillation_chain_data_cache_path,
]
else: else:
datasets = [train_dataset] datasets = [train_dataset]
probabilities = [1.] probabilities = [1.]
chain_data_cache_paths = [
self.train_chain_data_cache_path, generator = None
] if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed + 1)
self.train_dataset = OpenFoldDataset( self.train_dataset = OpenFoldDataset(
datasets=datasets, datasets=datasets,
probabilities=probabilities, probabilities=probabilities,
epoch_len=self.train_epoch_len, epoch_len=self.train_epoch_len,
chain_data_cache_paths=chain_data_cache_paths, generator=generator,
_roll_at_init=False, _roll_at_init=False,
) )
...@@ -612,10 +671,9 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -612,10 +671,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.eval_dataset = dataset_gen( self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir, data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir, alignment_dir=self.val_alignment_dir,
mapping_path=None, filter_path=None,
max_template_hits=self.config.eval.max_template_hits, max_template_hits=self.config.eval.max_template_hits,
mode="eval", mode="eval",
_output_raw=True,
) )
else: else:
self.eval_dataset = None self.eval_dataset = None
...@@ -623,7 +681,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -623,7 +681,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.predict_dataset = dataset_gen( self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir, data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir, alignment_dir=self.predict_alignment_dir,
mapping_path=None, filter_path=None,
max_template_hits=self.config.predict.max_template_hits, max_template_hits=self.config.predict.max_template_hits,
mode="predict", mode="predict",
) )
...@@ -636,7 +694,6 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -636,7 +694,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
dataset = None dataset = None
if(stage == "train"): if(stage == "train"):
dataset = self.train_dataset dataset = self.train_dataset
# Filter the dataset, if necessary # Filter the dataset, if necessary
dataset.reroll() dataset.reroll()
elif(stage == "eval"): elif(stage == "eval"):
...@@ -646,7 +703,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -646,7 +703,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
else: else:
raise ValueError("Invalid stage") raise ValueError("Invalid stage")
batch_collator = OpenFoldBatchCollator(self.config, stage) batch_collator = OpenFoldBatchCollator()
dl = OpenFoldDataLoader( dl = OpenFoldDataLoader(
dataset, dataset,
......
...@@ -14,26 +14,17 @@ ...@@ -14,26 +14,17 @@
# limitations under the License. # limitations under the License.
import os import os
import copy
import collections import collections
import contextlib import contextlib
import dataclasses import dataclasses
import datetime
import json
from multiprocessing import cpu_count from multiprocessing import cpu_count
import tempfile import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import numpy as np import numpy as np
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data import ( from openfold.data.templates import get_custom_template_features
templates,
parsers,
mmcif_parsing,
msa_identifiers,
msa_pairing,
feature_processing_multimer,
)
from openfold.data.parsers import Msa
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
...@@ -78,6 +69,51 @@ def make_template_features( ...@@ -78,6 +69,51 @@ def make_template_features(
return template_features return template_features
def unify_template_features(
template_feature_list: Sequence[FeatureDict]
) -> FeatureDict:
out_dicts = []
seq_lens = [fd["template_aatype"].shape[1] for fd in template_feature_list]
for i, fd in enumerate(template_feature_list):
out_dict = {}
n_templates, n_res = fd["template_aatype"].shape[:2]
for k,v in fd.items():
seq_keys = [
"template_aatype",
"template_all_atom_positions",
"template_all_atom_mask",
]
if(k in seq_keys):
new_shape = list(v.shape)
assert(new_shape[1] == n_res)
new_shape[1] = sum(seq_lens)
new_array = np.zeros(new_shape, dtype=v.dtype)
if(k == "template_aatype"):
new_array[..., residue_constants.HHBLITS_AA_TO_ID['-']] = 1
offset = sum(seq_lens[:i])
new_array[:, offset:offset + seq_lens[i]] = v
out_dict[k] = new_array
else:
out_dict[k] = v
chain_indices = np.array(n_templates * [i])
out_dict["template_chain_index"] = chain_indices
if(n_templates != 0):
out_dicts.append(out_dict)
if(len(out_dicts) > 0):
out_dict = {
k: np.concatenate([od[k] for od in out_dicts]) for k in out_dicts[0]
}
else:
out_dict = empty_template_feats(sum(seq_lens))
return out_dict
def make_sequence_features( def make_sequence_features(
sequence: str, description: str, num_res: int sequence: str, description: str, num_res: int
) -> FeatureDict: ) -> FeatureDict:
...@@ -249,6 +285,41 @@ def run_msa_tool( ...@@ -249,6 +285,41 @@ def run_msa_tool(
return result return result
def make_sequence_features_with_custom_template(
sequence: str,
mmcif_path: str,
pdb_id: str,
chain_id: str,
kalign_binary_path: str) -> FeatureDict:
"""
process a single fasta file using features derived from a single template rather than an alignment
"""
num_res = len(sequence)
sequence_features = make_sequence_features(
sequence=sequence,
description=pdb_id,
num_res=num_res,
)
msa_data = [sequence]
deletion_matrix = [[0 for _ in sequence]]
msa_data_obj = parsers.Msa(sequences=msa_data, deletion_matrix=deletion_matrix, descriptions=None)
msa_features = make_msa_features([msa_data_obj])
template_features = get_custom_template_features(
mmcif_path=mmcif_path,
query_sequence=sequence,
pdb_id=pdb_id,
chain_id=chain_id,
kalign_binary_path=kalign_binary_path
)
return {
**sequence_features,
**msa_features,
**template_features.features
}
class AlignmentRunner: class AlignmentRunner:
"""Runs alignment tools and saves the results""" """Runs alignment tools and saves the results"""
...@@ -617,32 +688,30 @@ class DataPipeline: ...@@ -617,32 +688,30 @@ class DataPipeline:
def _parse_msa_data( def _parse_msa_data(
self, self,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[Any] = None, alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msas = {} msa_data = {}
if(_alignment_index is not None): if(alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb") fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size): def read_msa(start, size):
fp.seek(start) fp.seek(start)
msa = fp.read(size).decode("utf-8") msa = fp.read(size).decode("utf-8")
return msa return msa
for (name, start, size) in _alignment_index["files"]: for (name, start, size) in alignment_index["files"]:
filename, ext = os.path.splitext(name) filename, ext = os.path.splitext(name)
if(ext == ".a3m"): if(ext == ".a3m"):
msa, deletion_matrix = parsers.parse_a3m( msa = parsers.parse_a3m(
read_msa(start, size) read_msa(start, size)
) )
data = {"msa": msa, "deletion_matrix": deletion_matrix} data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
# The "hmm_output" exception is a crude way to exclude # The "hmm_output" exception is a crude way to exclude
# multimer template hits. # multimer template hits.
elif(ext == ".sto" and not "hmm_output" == filename): elif(ext == ".sto" and not "hmm_output" == filename):
msa, deletion_matrix, _ = parsers.parse_stockholm( msa = parsers.parse_stockholm(read_msa(start, size))
read_msa(start, size) data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else: else:
continue continue
...@@ -657,33 +726,35 @@ class DataPipeline: ...@@ -657,33 +726,35 @@ class DataPipeline:
if(ext == ".a3m"): if(ext == ".a3m"):
with open(path, "r") as fp: with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read()) msa = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
elif(ext == ".sto" and not "hmm_output" == filename): elif(ext == ".sto" and not "hmm_output" == filename):
with open(path, "r") as fp: with open(path, "r") as fp:
msa = parsers.parse_stockholm( msa = parsers.parse_stockholm(
fp.read() fp.read()
) )
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
else: else:
continue continue
msas[f] = msa msa_data[f] = data
return msas return msa_data
def _parse_template_hit_files( def _parse_template_hit_files(
self, self,
alignment_dir: str, alignment_dir: str,
input_sequence: str, input_sequence: str,
_alignment_index: Optional[Any] = None alignment_index: Optional[Any] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
all_hits = {} all_hits = {}
if(_alignment_index is not None): if(alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb') fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size): def read_template(start, size):
fp.seek(start) fp.seek(start)
return fp.read(size).decode("utf-8") return fp.read(size).decode("utf-8")
for (name, start, size) in _alignment_index["files"]: for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1] ext = os.path.splitext(name)[-1]
if(ext == ".hhr"): if(ext == ".hhr"):
...@@ -716,15 +787,46 @@ class DataPipeline: ...@@ -716,15 +787,46 @@ class DataPipeline:
return all_hits return all_hits
def _process_msa_feats( def _parse_template_hits(
self, self,
alignment_dir: str, alignment_dir: str,
input_sequence: Optional[str] = None, alignment_index: Optional[Any] = None
_alignment_index: Optional[str] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msas = self._parse_msa_data(alignment_dir, _alignment_index) all_hits = {}
if (alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if (ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if(len(msas) == 0): if (ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
return
def _get_msas(self,
alignment_dir: str,
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None,
):
msa_data = self._parse_msa_data(alignment_dir, alignment_index)
if(len(msa_data) == 0):
if(input_sequence is None): if(input_sequence is None):
raise ValueError( raise ValueError(
""" """
...@@ -732,13 +834,31 @@ class DataPipeline: ...@@ -732,13 +834,31 @@ class DataPipeline:
must be provided. must be provided.
""" """
) )
msa_data["dummy"] = Msa(
[input_sequence],
[[0 for _ in input_sequence]],
["dummy"]
)
msa_features = make_msa_features(list(msas.values())) deletion_matrix = [[0 for _ in input_sequence]]
msa_data["dummy"] = {
"msa": parsers.Msa(sequences=input_sequence, deletion_matrix=deletion_matrix, descriptions=None),
"deletion_matrix": deletion_matrix,
}
msas, deletion_matrices = zip(*[
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
])
return msas, deletion_matrices
def _process_msa_feats(
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msas, deletion_matrices = self._get_msas(
alignment_dir, input_sequence, alignment_index
)
msa_features = make_msa_features(
msas=msas
)
return msa_features return msa_features
...@@ -746,7 +866,7 @@ class DataPipeline: ...@@ -746,7 +866,7 @@ class DataPipeline:
self, self,
fasta_path: str, fasta_path: str,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file""" """Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f: with open(fasta_path) as f:
...@@ -763,7 +883,7 @@ class DataPipeline: ...@@ -763,7 +883,7 @@ class DataPipeline:
hits = self._parse_template_hit_files( hits = self._parse_template_hit_files(
alignment_dir, alignment_dir,
input_sequence, input_sequence,
_alignment_index, alignment_index,
) )
template_features = make_template_features( template_features = make_template_features(
...@@ -778,7 +898,7 @@ class DataPipeline: ...@@ -778,7 +898,7 @@ class DataPipeline:
num_res=num_res, num_res=num_res,
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index) msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return { return {
**sequence_features, **sequence_features,
...@@ -791,7 +911,7 @@ class DataPipeline: ...@@ -791,7 +911,7 @@ class DataPipeline:
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str, alignment_dir: str,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a specific chain in an mmCIF object. Assembles features for a specific chain in an mmCIF object.
...@@ -812,7 +932,8 @@ class DataPipeline: ...@@ -812,7 +932,8 @@ class DataPipeline:
hits = self._parse_template_hits( hits = self._parse_template_hits(
alignment_dir, alignment_dir,
input_sequence, input_sequence,
_alignment_index) alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
...@@ -820,7 +941,7 @@ class DataPipeline: ...@@ -820,7 +941,7 @@ class DataPipeline:
query_release_date=to_date(mmcif.header["release_date"]) query_release_date=to_date(mmcif.header["release_date"])
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index) msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**mmcif_feats, **template_features, **msa_features} return {**mmcif_feats, **template_features, **msa_features}
...@@ -831,7 +952,7 @@ class DataPipeline: ...@@ -831,7 +952,7 @@ class DataPipeline:
is_distillation: bool = True, is_distillation: bool = True,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
_structure_index: Optional[str] = None, _structure_index: Optional[str] = None,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a protein in a PDB file. Assembles features for a protein in a PDB file.
...@@ -861,15 +982,16 @@ class DataPipeline: ...@@ -861,15 +982,16 @@ class DataPipeline:
hits = self._parse_template_hits( hits = self._parse_template_hits(
alignment_dir, alignment_dir,
input_sequence, input_sequence,
_alignment_index alignment_index
) )
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
self.template_featurizer, self.template_featurizer,
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index) msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**pdb_feats, **template_features, **msa_features} return {**pdb_feats, **template_features, **msa_features}
...@@ -877,7 +999,7 @@ class DataPipeline: ...@@ -877,7 +999,7 @@ class DataPipeline:
self, self,
core_path: str, core_path: str,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a protein in a ProteinNet .core file. Assembles features for a protein in a ProteinNet .core file.
...@@ -892,9 +1014,9 @@ class DataPipeline: ...@@ -892,9 +1014,9 @@ class DataPipeline:
hits = self._parse_template_hits( hits = self._parse_template_hits(
alignment_dir, alignment_dir,
input_sequence, alignment_index
_alignment_index
) )
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
...@@ -905,6 +1027,98 @@ class DataPipeline: ...@@ -905,6 +1027,98 @@ class DataPipeline:
return {**core_feats, **template_features, **msa_features} return {**core_feats, **template_features, **msa_features}
def process_multiseq_fasta(self,
fasta_path: str,
super_alignment_dir: str,
ri_gap: int = 200,
) -> FeatureDict:
"""
Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
hack from Twitter (a.k.a. AlphaFold-Gap).
"""
with open(fasta_path, 'r') as f:
fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
# No whitespace allowed
input_descs = [i.split()[0] for i in input_descs]
# Stitch all of the sequences together
input_sequence = ''.join(input_seqs)
input_description = '-'.join(input_descs)
num_res = len(input_sequence)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res,
)
seq_lens = [len(s) for s in input_seqs]
total_offset = 0
for sl in seq_lens:
total_offset += sl
sequence_features["residue_index"][total_offset:] += ri_gap
msa_list = []
deletion_mat_list = []
for seq, desc in zip(input_seqs, input_descs):
alignment_dir = os.path.join(
super_alignment_dir, desc
)
msas, deletion_mats = self._get_msas(
alignment_dir, seq, None
)
msa_list.append(msas)
deletion_mat_list.append(deletion_mats)
final_msa = []
final_deletion_mat = []
final_msa_obj = []
msa_it = enumerate(zip(msa_list, deletion_mat_list))
for i, (msas, deletion_mats) in msa_it:
prec, post = sum(seq_lens[:i]), sum(seq_lens[i + 1:])
msas = [
[prec * '-' + seq + post * '-' for seq in msa] for msa in msas
]
deletion_mats = [
[prec * [0] + dml + post * [0] for dml in deletion_mat]
for deletion_mat in deletion_mats
]
assert (len(msas[0][-1]) == len(input_sequence))
final_msa.extend(msas)
final_deletion_mat.extend(deletion_mats)
final_msa_obj.extend([parsers.Msa(sequences=msas[k], deletion_matrix=deletion_mats[k], descriptions=None)
for k in range(len(msas))])
msa_features = make_msa_features(
msas=final_msa_obj
)
template_feature_list = []
for seq, desc in zip(input_seqs, input_descs):
alignment_dir = os.path.join(
super_alignment_dir, desc
)
hits = self._parse_template_hits(alignment_dir, alignment_index=None)
template_features = make_template_features(
seq,
hits,
self.template_featurizer,
)
template_feature_list.append(template_features)
template_features = unify_template_features(template_feature_list)
return {
**sequence_features,
**msa_features,
**template_features,
}
class DataPipelineMultimer: class DataPipelineMultimer:
"""Runs the alignment tools and assembles the input features.""" """Runs the alignment tools and assembles the input features."""
...@@ -913,7 +1127,6 @@ class DataPipelineMultimer: ...@@ -913,7 +1127,6 @@ class DataPipelineMultimer:
monomer_data_pipeline: DataPipeline, monomer_data_pipeline: DataPipeline,
): ):
"""Initializes the data pipeline. """Initializes the data pipeline.
Args: Args:
monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs
the data pipeline for the monomer AlphaFold system. the data pipeline for the monomer AlphaFold system.
...@@ -955,10 +1168,12 @@ class DataPipelineMultimer: ...@@ -955,10 +1168,12 @@ class DataPipelineMultimer:
def _all_seq_msa_features(self, fasta_path, alignment_dir): def _all_seq_msa_features(self, fasta_path, alignment_dir):
"""Get MSA features for unclustered uniprot, for pairing.""" """Get MSA features for unclustered uniprot, for pairing."""
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto") #TODO: Quick fix, change back to .sto after parsing fixed
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.a3m")
with open(uniprot_msa_path, "r") as fp: with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read() uniprot_msa_string = fp.read()
msa = parsers.parse_stockholm(uniprot_msa_string) msa = parsers.parse_a3m(uniprot_msa_string)
#msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features([msa]) all_seq_features = make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + ( valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers', 'msa_species_identifiers',
......
...@@ -23,6 +23,9 @@ import torch ...@@ -23,6 +23,9 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
from openfold.utils.rigid_utils import Rotation, Rigid from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -736,6 +739,7 @@ def make_atom14_positions(protein): ...@@ -736,6 +739,7 @@ def make_atom14_positions(protein):
for index, correspondence in enumerate(correspondences): for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.0 renaming_matrix[index, correspondence] = 1.0
all_matrices[resname] = renaming_matrix all_matrices[resname] = renaming_matrix
renaming_matrices = torch.stack( renaming_matrices = torch.stack(
[all_matrices[restype] for restype in restype_3] [all_matrices[restype] for restype in restype_3]
) )
...@@ -781,10 +785,14 @@ def make_atom14_positions(protein): ...@@ -781,10 +785,14 @@ def make_atom14_positions(protein):
def atom37_to_frames(protein, eps=1e-8): def atom37_to_frames(protein, eps=1e-8):
is_multimer = "asym_id" in protein
aatype = protein["aatype"] aatype = protein["aatype"]
all_atom_positions = protein["all_atom_positions"] all_atom_positions = protein["all_atom_positions"]
all_atom_mask = protein["all_atom_mask"] all_atom_mask = protein["all_atom_mask"]
if is_multimer:
all_atom_positions = Vec3Array.from_array(all_atom_positions)
batch_dims = len(aatype.shape[:-1]) batch_dims = len(aatype.shape[:-1])
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object) restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
...@@ -831,6 +839,15 @@ def atom37_to_frames(protein, eps=1e-8): ...@@ -831,6 +839,15 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims=batch_dims, no_batch_dims=batch_dims,
) )
if is_multimer:
base_atom_pos = [batched_gather(
pos,
residx_rigidgroup_base_atom37_idx,
dim=-1,
no_batch_dims=len(all_atom_positions.shape[:-1]),
) for pos in all_atom_positions]
base_atom_pos = Vec3Array.from_array(torch.stack(base_atom_pos, dim=-1))
else:
base_atom_pos = batched_gather( base_atom_pos = batched_gather(
all_atom_positions, all_atom_positions,
residx_rigidgroup_base_atom37_idx, residx_rigidgroup_base_atom37_idx,
...@@ -838,6 +855,15 @@ def atom37_to_frames(protein, eps=1e-8): ...@@ -838,6 +855,15 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims=len(all_atom_positions.shape[:-2]), no_batch_dims=len(all_atom_positions.shape[:-2]),
) )
if is_multimer:
point_on_neg_x_axis = base_atom_pos[:, :, 0]
origin = base_atom_pos[:, :, 1]
point_on_xy_plane = base_atom_pos[:, :, 2]
gt_rotation = Rot3Array.from_two_vectors(
origin - point_on_neg_x_axis, point_on_xy_plane - origin)
gt_frames = Rigid3Array(gt_rotation, origin)
else:
gt_frames = Rigid.from_3_points( gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :], p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :], origin=base_atom_pos[..., 1, :],
...@@ -864,8 +890,12 @@ def atom37_to_frames(protein, eps=1e-8): ...@@ -864,8 +890,12 @@ def atom37_to_frames(protein, eps=1e-8):
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1)) rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1 rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1 rots[..., 0, 2, 2] = -1
rots = Rotation(rot_mats=rots)
if is_multimer:
gt_frames = gt_frames.compose_rotation(
Rot3Array.from_array(rots))
else:
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(Rigid(rots, None)) gt_frames = gt_frames.compose(Rigid(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros( restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
...@@ -900,6 +930,12 @@ def atom37_to_frames(protein, eps=1e-8): ...@@ -900,6 +930,12 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims=batch_dims, no_batch_dims=batch_dims,
) )
if is_multimer:
ambiguity_rot = Rot3Array.from_array(residx_rigidgroup_ambiguity_rot)
# Create the alternative ground truth frames.
alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
else:
residx_rigidgroup_ambiguity_rot = Rotation( residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot rot_mats=residx_rigidgroup_ambiguity_rot
) )
......
...@@ -103,6 +103,21 @@ def np_example_to_features( ...@@ -103,6 +103,21 @@ def np_example_to_features(
cfg[mode], cfg[mode],
) )
if mode == "train":
p = torch.rand(1).item()
use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
features["use_clamped_fape"] = torch.full(
size=[cfg.common.max_recycling_iters + 1],
fill_value=use_clamped_fape_value,
dtype=torch.float32,
)
else:
features["use_clamped_fape"] = torch.full(
size=[cfg.common.max_recycling_iters + 1],
fill_value=0.0,
dtype=torch.float32,
)
return {k: v for k, v in features.items()} return {k: v for k, v in features.items()}
......
...@@ -84,7 +84,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -84,7 +84,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
pad_msa_clusters = mode_cfg.max_msa_clusters pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa max_extra_msa = mode_cfg.max_extra_msa
msa_seed = None msa_seed = None
if(not common_cfg.resample_msa_in_recycling): if(not common_cfg.resample_msa_in_recycling):
...@@ -137,7 +137,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -137,7 +137,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
data_transforms.make_fixed_size( data_transforms.make_fixed_size(
crop_feats, crop_feats,
pad_msa_clusters, pad_msa_clusters,
common_cfg.max_extra_msa, mode_cfg.max_extra_msa,
mode_cfg.crop_size, mode_cfg.crop_size,
mode_cfg.max_templates, mode_cfg.max_templates,
) )
......
...@@ -46,7 +46,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -46,7 +46,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
pad_msa_clusters = mode_cfg.max_msa_clusters pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa max_extra_msa = mode_cfg.max_extra_msa
msa_seed = None msa_seed = None
if(not common_cfg.resample_msa_in_recycling): if(not common_cfg.resample_msa_in_recycling):
......
...@@ -434,7 +434,7 @@ def _is_set(data: str) -> bool: ...@@ -434,7 +434,7 @@ def _is_set(data: str) -> bool:
def get_atom_coords( def get_atom_coords(
mmcif_object: MmcifObject, mmcif_object: MmcifObject,
chain_id: str, chain_id: str,
_zero_center_positions: bool = True _zero_center_positions: bool = False
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain # Locate the right chain
chains = list(mmcif_object.structure.get_chains()) chains = list(mmcif_object.structure.get_chains())
......
...@@ -89,6 +89,8 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: ...@@ -89,6 +89,8 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
descriptions.append(line[1:]) # Remove the '>' at the beginning. descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append("") sequences.append("")
continue continue
elif line.startswith("#"):
continue
elif not line: elif not line:
continue # Skip blank lines. continue # Skip blank lines.
sequences[index] += line sequences[index] += line
......
...@@ -128,6 +128,22 @@ def _is_after_cutoff( ...@@ -128,6 +128,22 @@ def _is_after_cutoff(
return False return False
def _replace_obsolete_references(obsolete_mapping) -> Mapping[str, str]:
"""Generates a new obsolete by tracing all cross-references and store the latest leaf to all referencing nodes"""
obsolete_new = {}
obsolete_keys = obsolete_mapping.keys()
def _new_target(k):
v = obsolete_mapping[k]
if v in obsolete_keys:
return _new_target(v)
return v
for k in obsolete_keys:
obsolete_new[k] = _new_target(k)
return obsolete_new
def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]: def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
"""Parses the data file from PDB that lists which PDB ids are obsolete.""" """Parses the data file from PDB that lists which PDB ids are obsolete."""
with open(obsolete_file_path) as f: with open(obsolete_file_path) as f:
...@@ -141,7 +157,7 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]: ...@@ -141,7 +157,7 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
from_id = line[20:24].lower() from_id = line[20:24].lower()
to_id = line[29:33].lower() to_id = line[29:33].lower()
result[from_id] = to_id result[from_id] = to_id
return result return _replace_obsolete_references(result)
def generate_release_dates_cache(mmcif_dir: str, out_path: str): def generate_release_dates_cache(mmcif_dir: str, out_path: str):
...@@ -495,7 +511,7 @@ def _get_atom_positions( ...@@ -495,7 +511,7 @@ def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject, mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str, auth_chain_id: str,
max_ca_ca_distance: float, max_ca_ca_distance: float,
_zero_center_positions: bool = True, _zero_center_positions: bool = False,
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues.""" """Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords( coords_with_mask = mmcif_parsing.get_atom_coords(
...@@ -912,6 +928,56 @@ def _process_single_hit( ...@@ -912,6 +928,56 @@ def _process_single_hit(
return SingleHitResult(features=None, error=error, warning=None) return SingleHitResult(features=None, error=error, warning=None)
def get_custom_template_features(
mmcif_path: str,
query_sequence: str,
pdb_id: str,
chain_id: str,
kalign_binary_path: str):
with open(mmcif_path, "r") as mmcif_path:
cif_string = mmcif_path.read()
mmcif_parse_result = mmcif_parsing.parse(
file_id=pdb_id, mmcif_string=cif_string
)
template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]
mapping = {x:x for x, _ in enumerate(query_sequence)}
features, warnings = _extract_template_features(
mmcif_object=mmcif_parse_result.mmcif_object,
pdb_id=pdb_id,
mapping=mapping,
template_sequence=template_sequence,
query_sequence=query_sequence,
template_chain_id=chain_id,
kalign_binary_path=kalign_binary_path,
_zero_center_positions=True
)
features["template_sum_probs"] = [1.0]
# TODO: clean up this logic
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
for k in template_features:
template_features[k].append(features[k])
for name in template_features:
template_features[name] = np.stack(
template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
return TemplateSearchResult(
features=template_features, errors=None, warnings=warnings
)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class TemplateSearchResult: class TemplateSearchResult:
features: Mapping[str, Any] features: Mapping[str, Any]
...@@ -1041,6 +1107,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer): ...@@ -1041,6 +1107,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
filtered = list( filtered = list(
sorted(filtered, key=lambda x: x.sum_probs, reverse=True) sorted(filtered, key=lambda x: x.sum_probs, reverse=True)
) )
idx = list(range(len(filtered))) idx = list(range(len(filtered)))
if(self._shuffle_top_k_prefiltered): if(self._shuffle_top_k_prefiltered):
stk = self._shuffle_top_k_prefiltered stk = self._shuffle_top_k_prefiltered
......
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
...@@ -17,7 +17,7 @@ from functools import partial ...@@ -17,7 +17,7 @@ from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Tuple from typing import Tuple, Optional
from openfold.utils import all_atom_multimer from openfold.utils import all_atom_multimer
from openfold.utils.feats import ( from openfold.utils.feats import (
...@@ -32,7 +32,7 @@ from openfold.model.template import ( ...@@ -32,7 +32,7 @@ from openfold.model.template import (
TemplatePointwiseAttention, TemplatePointwiseAttention,
) )
from openfold.utils import geometry from openfold.utils import geometry
from openfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap from openfold.utils.tensor_utils import add, one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module): class InputEmbedder(nn.Module):
...@@ -96,10 +96,21 @@ class InputEmbedder(nn.Module): ...@@ -96,10 +96,21 @@ class InputEmbedder(nn.Module):
boundaries = torch.arange( boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
) )
oh = one_hot(d, boundaries).type(ri.dtype) reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
return self.linear_relpos(oh) d = d[..., None] - reshaped_bins
d = torch.abs(d)
d = torch.argmin(d, dim=-1)
d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
d = d.to(ri.dtype)
return self.linear_relpos(d)
def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]: def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
batch: Dict containing batch: Dict containing
...@@ -116,17 +127,20 @@ class InputEmbedder(nn.Module): ...@@ -116,17 +127,20 @@ class InputEmbedder(nn.Module):
[*, N_res, N_res, C_z] pair embedding [*, N_res, N_res, C_z] pair embedding
""" """
tf = batch["target_feat"]
ri = batch["residue_index"]
msa = batch["msa_feat"]
# [*, N_res, c_z] # [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf) tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf) tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z] # [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype)) pair_emb = add(pair_emb,
tf_emb_i[..., None, :],
inplace=inplace_safe
)
pair_emb = add(pair_emb,
tf_emb_j[..., None, :, :],
inplace=inplace_safe
)
# [*, N_clust, N_res, c_m] # [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3] n_clust = msa.shape[-3]
...@@ -302,7 +316,6 @@ class RecyclingEmbedder(nn.Module): ...@@ -302,7 +316,6 @@ class RecyclingEmbedder(nn.Module):
Implements Algorithm 32. Implements Algorithm 32.
""" """
def __init__( def __init__(
self, self,
c_m: int, c_m: int,
...@@ -344,6 +357,7 @@ class RecyclingEmbedder(nn.Module): ...@@ -344,6 +357,7 @@ class RecyclingEmbedder(nn.Module):
m: torch.Tensor, m: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
...@@ -359,6 +373,19 @@ class RecyclingEmbedder(nn.Module): ...@@ -359,6 +373,19 @@ class RecyclingEmbedder(nn.Module):
z: z:
[*, N_res, N_res, C_z] pair embedding update [*, N_res, N_res, C_z] pair embedding update
""" """
# [*, N, C_m]
m_update = self.layer_norm_m(m)
if(inplace_safe):
m.copy_(m_update)
m_update = m
# [*, N, N, C_z]
z_update = self.layer_norm_z(z)
if(inplace_safe):
z.copy_(z_update)
z_update = z
# This squared method might become problematic in FP16 mode.
bins = torch.linspace( bins = torch.linspace(
self.min_bin, self.min_bin,
self.max_bin, self.max_bin,
...@@ -367,13 +394,6 @@ class RecyclingEmbedder(nn.Module): ...@@ -367,13 +394,6 @@ class RecyclingEmbedder(nn.Module):
device=x.device, device=x.device,
requires_grad=False, requires_grad=False,
) )
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = bins ** 2 squared_bins = bins ** 2
upper = torch.cat( upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1 [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
...@@ -387,7 +407,7 @@ class RecyclingEmbedder(nn.Module): ...@@ -387,7 +407,7 @@ class RecyclingEmbedder(nn.Module):
# [*, N, N, C_z] # [*, N, N, C_z]
d = self.linear(d) d = self.linear(d)
z_update = d + self.layer_norm_z(z) z_update = add(z_update, d, inplace_safe)
return m_update, z_update return m_update, z_update
...@@ -485,7 +505,6 @@ class ExtraMSAEmbedder(nn.Module): ...@@ -485,7 +505,6 @@ class ExtraMSAEmbedder(nn.Module):
Implements Algorithm 2, line 15 Implements Algorithm 2, line 15
""" """
def __init__( def __init__(
self, self,
c_in: int, c_in: int,
...@@ -544,30 +563,31 @@ class TemplateEmbedder(nn.Module): ...@@ -544,30 +563,31 @@ class TemplateEmbedder(nn.Module):
pair_mask, pair_mask,
templ_dim, templ_dim,
chunk_size, chunk_size,
_mask_trans=True _mask_trans=True,
use_lma=False,
inplace_safe=False
): ):
# Embed the templates one at a time (with a poor man's vmap) # Embed the templates one at a time (with a poor man's vmap)
template_embeds = [] pair_embeds = []
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim] n_templ = batch["template_aatype"].shape[templ_dim]
if (inplace_safe):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair = z.new_zeros(
z.shape[:-3] +
(n_templ, n, n, self.config.template_pair_embedder.c_t)
)
for i in range(n_templ): for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i) idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map( single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx), lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch, batch,
) )
single_template_embeds = {} # [*, N, N, C_t]
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat( t = build_template_pair_feat(
single_template_feats, single_template_feats,
use_unit_vector=self.config.use_unit_vector, use_unit_vector=self.config.use_unit_vector,
...@@ -577,38 +597,64 @@ class TemplateEmbedder(nn.Module): ...@@ -577,38 +597,64 @@ class TemplateEmbedder(nn.Module):
).to(z.dtype) ).to(z.dtype)
t = self.template_pair_embedder(t) t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t}) if (inplace_safe):
t_pair[..., i, :, :, :] = t
else:
pair_embeds.append(t)
template_embeds.append(single_template_embeds) del t
template_embeds = dict_multimap( if (not inplace_safe):
partial(torch.cat, dim=templ_dim), t_pair = torch.stack(pair_embeds, dim=templ_dim)
template_embeds,
) del pair_embeds
# [*, S_t, N, N, C_z] # [*, S_t, N, N, C_z]
t = self.template_pair_stack( t = self.template_pair_stack(
template_embeds["pair"], t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype), pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size, chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
) )
del t_pair
# [*, N, N, C_z] # [*, N, N, C_z]
t = self.template_pointwise_att( t = self.template_pointwise_att(
t, t,
z, z,
template_mask=batch["template_mask"].to(dtype=z.dtype), template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size, use_lma=use_lma,
)
t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
# Append singletons
t_mask = t_mask.reshape(
*t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
) )
t = t * (torch.sum(batch["template_mask"]) > 0)
if (inplace_safe):
t *= t_mask
else:
t = t * t_mask
ret = {} ret = {}
if self.config.embed_angles:
ret["template_single_embedding"] = template_embeds["angle"]
ret.update({"template_pair_embedding": t}) ret.update({"template_pair_embedding": t})
del t
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
batch
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
ret["template_single_embedding"] = a
return ret return ret
...@@ -751,6 +797,8 @@ class TemplateEmbedderMultimer(nn.Module): ...@@ -751,6 +797,8 @@ class TemplateEmbedderMultimer(nn.Module):
templ_dim, templ_dim,
chunk_size, chunk_size,
multichain_mask_2d, multichain_mask_2d,
use_lma=False,
inplace_safe=False
): ):
template_embeds = [] template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim] n_templ = batch["template_aatype"].shape[templ_dim]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment