Commit 39a6d0e6 authored by Christina Floristean

Merging in main branch

parents d8ee9c5f 84659c93
cff-version: 1.2.0
message: "For now, cite OpenFold with its DOI."
authors:
- family-names: "Ahdritz"
given-names: "Gustaf"
orcid: https://orcid.org/0000-0001-8283-5324
- family-names: "Bouatta"
given-names: "Nazim"
orcid: https://orcid.org/0000-0002-6524-874X
- family-names: "Kadyan"
given-names: "Sachin"
- family-names: "Xia"
given-names: "Qinghui"
- family-names: "Gerecke"
given-names: "William"
- family-names: "AlQuraishi"
given-names: "Mohammed"
orcid: https://orcid.org/0000-0001-6817-1322
title: "OpenFold"
doi: 10.5281/zenodo.5709539
preferred-citation:
authors:
- family-names: "Ahdritz"
given-names: "Gustaf"
orcid: https://orcid.org/0000-0001-8283-5324
- family-names: "Bouatta"
given-names: "Nazim"
orcid: https://orcid.org/0000-0002-6524-874X
- family-names: "Kadyan"
given-names: "Sachin"
orcid: https://orcid.org/0000-0002-6079-7627
- family-names: "Xia"
given-names: "Qinghui"
- family-names: "Gerecke"
given-names: "William"
orcid: https://orcid.org/0000-0002-9777-6192
- family-names: "O'Donnell"
given-names: "Timothy J"
orcid: https://orcid.org/0000-0002-9949-069X
- family-names: "Berenberg"
given-names: "Daniel"
orcid: https://orcid.org/0000-0003-4631-0947
- family-names: "Fisk"
given-names: "Ian"
- family-names: "Zanichelli"
given-names: "Niccolò"
orcid: https://orcid.org/0000-0002-3093-3587
- family-names: "Zhang"
given-names: "Bo"
orcid: https://orcid.org/0000-0002-9714-2827
- family-names: "Nowaczynski"
given-names: "Arkadiusz"
orcid: https://orcid.org/0000-0002-3351-9584
- family-names: "Wang"
given-names: "Bei"
orcid: https://orcid.org/0000-0003-4942-9652
- family-names: "Stepniewska-Dziubinska"
given-names: "Marta M"
orcid: https://orcid.org/0000-0003-4942-9652
- family-names: "Zhang"
given-names: "Shang"
orcid: https://orcid.org/0000-0003-0759-2080
- family-names: "Ojewole"
given-names: "Adegoke"
orcid: https://orcid.org/0000-0003-2661-4388
- family-names: "Guney"
given-names: "Murat Efe"
- family-names: "Biderman"
given-names: "Stella"
orcid: https://orcid.org/0000-0001-8228-1042
- family-names: "Watkins"
given-names: "Andrew M"
orcid: https://orcid.org/0000-0003-1617-1720
- family-names: "Ra"
given-names: "Stephen"
orcid: https://orcid.org/0000-0002-2820-0050
- family-names: "Lorenzo"
given-names: "Pablo Ribalta"
orcid: https://orcid.org/0000-0002-3657-8053
- family-names: "Nivon"
given-names: "Lucas"
- family-names: "Weitzner"
given-names: "Brian"
orcid: https://orcid.org/0000-0002-1909-0961
- family-names: "Ban"
given-names: "Yih-En"
orcid: https://orcid.org/0000-0003-3698-3574
- family-names: "Ban"
given-names: "Yih-En Andrew"
orcid: https://orcid.org/0000-0003-3698-3574
- family-names: "Sorger"
given-names: "Peter K"
orcid: https://orcid.org/0000-0002-3364-1838
- family-names: "Mostaque"
given-names: "Emad"
- family-names: "Zhang"
given-names: "Zhao"
orcid: https://orcid.org/0000-0001-5921-0035
- family-names: "Bonneau"
given-names: "Richard"
orcid: https://orcid.org/0000-0003-4354-7906
- family-names: "AlQuraishi"
given-names: "Mohammed"
orcid: https://orcid.org/0000-0001-6817-1322
title: "OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization"
type: article
doi: 10.1101/2022.11.20.517210
date-released: 2021-11-12
url: "https://github.com/aqlaboratory/openfold"
url: "https://doi.org/10.1101/2022.11.20.517210"
FROM nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04
FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu18.04
RUN apt-get update && apt-get install -y wget cuda-minimal-build-10-2 git
# metainformation
LABEL org.opencontainers.image.version = "1.0.0"
LABEL org.opencontainers.image.authors = "Gustaf Ahdritz"
LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfold"
LABEL org.opencontainers.image.licenses = "Apache License 2.0"
LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu18.04"
RUN apt-key del 7fa2af80
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
RUN wget -P /tmp \
"https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \
&& bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \
......
![header ](imgs/OpenFold_viz_banner.jpg)
![header ](imgs/of_banner.png)
_Figure: Comparison of OpenFold and AlphaFold2 predictions to the experimental structure of PDB 7KDX, chain B._
# OpenFold
A faithful PyTorch reproduction of DeepMind's
A faithful but trainable PyTorch reproduction of DeepMind's
[AlphaFold 2](https://github.com/deepmind/alphafold).
## Features
......@@ -12,39 +14,43 @@ source inference code (v2.0.1). The sole exception is model ensembling, which
fared poorly in DeepMind's own ablation testing and is being phased out in future
DeepMind experiments. It is omitted here for the sake of reducing clutter. In
cases where the *Nature* paper differs from the source, we always defer to the
latter.
OpenFold is built to support inference with AlphaFold's original JAX weights.
It's also faster than the official code on GPU. Try it out for yourself with
our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
Unlike DeepMind's public code, OpenFold is also trainable. It can be trained
with [DeepSpeed](https://github.com/microsoft/deepspeed) and with either `fp16`
or `bfloat16` half-precision.
OpenFold is equipped with an implementation of low-memory attention
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)), which
enables inference on extremely long chains.
We've modified [FastFold](https://github.com/hpcaitech/FastFold)'s custom CUDA
kernels to support in-place attention during inference and training. These use
latter.
OpenFold is trainable in full precision, half precision, or `bfloat16` with or without DeepSpeed,
and we've trained it from scratch, matching the performance of the original.
We've publicly released model weights and our training data — some 400,000
MSAs and PDB70 template hit files — under a permissive license. Model weights
are available via scripts in this repository while the MSAs are hosted by the
[Registry of Open Data on AWS (RODA)](https://registry.opendata.aws/openfold).
Try running inference for yourself with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
OpenFold also supports inference using AlphaFold's official parameters, and
vice versa (see `scripts/convert_of_weights_to_jax.py`).
OpenFold has the following advantages over the reference implementation:
- **Faster inference** on GPU, sometimes by as much as 2x. The greatest speedups are achieved on newer (>= Ampere) GPUs.
- **Inference on extremely long chains**, made possible by our implementation of low-memory attention
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)). OpenFold can predict the structures of
sequences with more than 4000 residues on a single A100, and even longer ones with CPU offloading.
- **Custom CUDA attention kernels** modified from [FastFold](https://github.com/hpcaitech/FastFold)'s
kernels support in-place attention during inference and training. They use
4x and 5x less GPU memory than equivalent FastFold and stock PyTorch
implementations, respectively.
We also make available efficient scripts for generating alignments. We've
used them to generate millions of alignments that will be released alongside
original OpenFold weights, trained from scratch using our code (more on that soon).
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.
- **FlashAttention** support greatly speeds up MSA attention.
## Installation (Linux)
All Python dependencies are specified in `environment.yml`. For producing sequence
alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite),
and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)}
installed on your system. Finally, some download scripts require `aria2c`.
installed on your system. You'll need `git-lfs` to download OpenFold parameters.
Finally, some download scripts require `aria2c` and `aws`.
For convenience, we provide a script that installs Miniconda locally, creates a
`conda` virtual environment, installs all Python dependencies, and downloads
useful resources (including DeepMind's pretrained parameters). Run:
useful resources, including both sets of model parameters. Run:
```bash
scripts/install_third_party_dependencies.sh
......@@ -76,14 +82,9 @@ To install the HH-suite to `/usr/bin`, run
## Usage
To download DeepMind's pretrained parameters and common ground truth data, run:
```bash
bash scripts/download_data.sh data/
```
You have two choices for downloading protein databases, depending on whether
you want to use DeepMind's MSA generation pipeline (w/ HMMER & HHblits) or
If you intend to generate your own alignments, e.g. for inference, you have two
choices for downloading protein databases, depending on whether you want to use
DeepMind's MSA generation pipeline (w/ HMMER & HHblits) or
[ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster
MMseqs2 instead. For the former, run:
......@@ -102,12 +103,25 @@ Make sure to run the latter command on the machine that will be used for MSA
generation (the script estimates how the precomputed database index used by
MMseqs2 should be split according to the memory available on the system).
If you're using your own precomputed MSAs or MSAs from the RODA repository,
there's no need to download these alignment databases. Simply make sure that
the `alignment_dir` contains one directory per chain and that each of these
contains alignments (.sto, .a3m, and .hhr) corresponding to that chain. You
can use `scripts/flatten_roda.sh` to reformat RODA downloads in this way.
Note that the RODA alignments are NOT compatible with the recent .cif ground
truth files downloaded by `scripts/download_alphafold_dbs.sh`. To fetch .cif
files that match the RODA MSAs, once the alignments are flattened, use
`scripts/download_roda_pdbs.sh`. That script outputs a list of alignment dirs
for which matching .cif files could not be found. These should be removed from
the alignment directory.
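
As a quick sanity check, the expected layout can be verified with a few lines
of Python (a minimal sketch; the directory path is illustrative):

```python
import os

alignment_dir = "alignment_dir/"  # one subdirectory per chain
for chain in os.listdir(alignment_dir):
    chain_dir = os.path.join(alignment_dir, chain)
    if not os.path.isdir(chain_dir):
        continue
    exts = {os.path.splitext(f)[1] for f in os.listdir(chain_dir)}
    # Each chain directory should contain that chain's alignments
    # (.sto, .a3m, and/or .hhr files).
    if not exts & {".sto", ".a3m", ".hhr"}:
        print(f"No alignments found for chain {chain}")
```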
Alternatively, you can use raw MSAs from
[ProteinNet](https://github.com/aqlaboratory/proteinnet). After downloading
that database, use `scripts/prep_proteinnet_msas.py` to convert the data
into a format recognized by the OpenFold parser. The resulting directory
becomes the `alignment_dir` used in subsequent steps. Use
`scripts/unpack_proteinnet.py` to extract `.core` files from ProteinNet text
files.
For both inference and training, the model's hyperparameters can be tuned from
`openfold/config.py`. Of course, if you plan to perform inference using
......@@ -122,7 +136,7 @@ pretrained parameters, run e.g.:
```bash
python3 run_pretrained_openfold.py \
target.fasta \
fasta_dir \
data/pdb_mmcif/mmcif_files/ \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
......@@ -130,22 +144,93 @@ python3 run_pretrained_openfold.py \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--output_dir ./ \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--model_device cuda:1 \
--model_device "cuda:0" \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt
```
where `data` is the same directory as in the previous step. If `jackhmmer`,
`hhblits`, `hhsearch` and `kalign` are available at the default path of
`/usr/bin`, their `binary_path` command-line arguments can be dropped.
If you've already computed alignments for the query, you have the option to
skip the expensive alignment computation here.
skip the expensive alignment computation here with
`--use_precomputed_alignments`.
`--openfold_checkpoint_path` or `--jax_param_path` accept comma-separated lists
of .pt/DeepSpeed OpenFold checkpoints and AlphaFold's .npz JAX parameter files,
respectively. For a breakdown of the differences between the parameter
files, see the README downloaded to `openfold/resources/openfold_params/`. Since
OpenFold was trained under a newer training schedule than the one from which the
`model_n` config presets are derived, there is no clean correspondence between
`config_preset` settings and OpenFold checkpoints; the only constraints are that
`*_ptm` checkpoints must be run with `*_ptm` config presets and that `_no_templ_`
checkpoints are only compatible with template-less presets (`model_3` and above).
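
For example, pairing a `*_ptm` OpenFold checkpoint with a `*_ptm` config preset
from Python might look like the following (a minimal sketch mirroring the
loading logic in our Colab notebook; the checkpoint path is illustrative):

```python
import torch

from openfold import config
from openfold.model import model

# *_ptm checkpoints must be paired with *_ptm config presets.
cfg = config.model_config("model_1_ptm")
openfold_model = model.AlphaFold(cfg).eval()

ckpt = torch.load("openfold/resources/openfold_params/finetuning_ptm_2.pt")
openfold_model.load_state_dict(ckpt)
```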
Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)
is enabled by default in inference mode. To disable it, set `globals.chunk_size`
to `None` in the config.
to `None` in the config. If a value is specified, OpenFold will attempt to
dynamically tune it, considering the chunk size specified in the config as a
minimum. This tuning process automatically ensures consistently fast runtimes
regardless of input sequence length, but it also introduces some runtime
variability, which may be undesirable for certain users. It is also recommended
to disable this feature for very long chains (see below). To do so, set the
`tune_chunk_size` option in the config to `False`.
For large-scale batch inference, we offer an optional tracing mode, which
massively improves runtimes at the cost of a lengthy model compilation process.
To enable it, add `--trace_model` to the inference command.
To get a speedup during inference, enable [FlashAttention](https://github.com/HazyResearch/flash-attention)
in the config. Note that it appears to work best for sequences with < 1000 residues.
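
In code, the chunking and FlashAttention tweaks described above amount to a few
config assignments (a sketch; the preset name is illustrative):

```python
from openfold import config

cfg = config.model_config("model_1_ptm")

# Disable chunking entirely:
# cfg.globals.chunk_size = None

# Or keep chunking but disable dynamic chunk-size tuning, e.g. for very
# long chains (each of the three chunked stacks exposes the flag):
cfg.model.template.template_pair_stack.tune_chunk_size = False
cfg.model.extra_msa.extra_msa_stack.tune_chunk_size = False
cfg.model.evoformer_stack.tune_chunk_size = False

# Optionally enable FlashAttention (requires that flash-attn is installed;
# works best for sequences with < 1000 residues):
cfg.globals.use_flash = True
```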
Input FASTA files containing multiple sequences are treated as complexes. In
this case, the inference script runs AlphaFold-Gap, a hack proposed
[here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using
the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer). To
run inference with AlphaFold-Multimer, use the (experimental) `multimer` branch
instead.
To minimize memory usage during inference on long sequences, consider the
following changes:
- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template
stack is a major memory bottleneck for inference on long sequences. OpenFold
supports two mutually exclusive inference modes to address this issue. One,
`average_templates` in the `template` section of the config, is similar to the
solution offered by AlphaFold-Multimer, which is simply to average individual
template representations. Our version is modified slightly to accommodate
weights trained using the standard template algorithm. Using said weights, we
notice no significant difference in performance between our averaged template
embeddings and the standard ones. The second, `offload_templates`, temporarily
offloads individual template embeddings into CPU memory. The former is an
approximation while the latter is slightly slower; both are memory-efficient
and allow the model to utilize arbitrarily many templates across sequence
lengths. Both are disabled by default, and it is up to the user to determine
which best suits their needs, if either.
- Inference-time low-memory attention (LMA) can be enabled in the model config.
This setting trades off speed for vastly improved memory usage. By default,
LMA is run with query and key chunk sizes of 1024 and 4096, respectively.
These represent a favorable tradeoff in most memory-constrained cases.
Power users can choose to tweak these settings in
`openfold/model/primitives.py`. For more information on the LMA algorithm,
see the aforementioned Rabe & Staats preprint and the sketch after this list.
- Disable `tune_chunk_size` for long sequences. Past a certain point, it only
wastes time.
- As a last resort, consider enabling `offload_inference`. This enables more
extensive CPU offloading at various bottlenecks throughout the model.
- Disable FlashAttention, which seems unstable on long sequences.
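
For intuition, below is a minimal, self-contained sketch of the chunked-softmax
idea behind LMA. It is not OpenFold's actual implementation (which lives in
`openfold/model/primitives.py` and additionally handles heads, masks, and
biases), but it computes exact attention without ever materializing the full
attention matrix:

```python
import torch

def lma(q, k, v, q_chunk=1024, kv_chunk=4096):
    """Exact softmax attention computed in chunks (Rabe & Staats, 2021).

    q: [..., s_q, d]; k, v: [..., s_kv, d]. Peak memory scales with the
    chunk sizes rather than with s_q * s_kv.
    """
    scale = q.shape[-1] ** -0.5
    out_chunks = []
    for qi in q.split(q_chunk, dim=-2):
        maxes, sums, values = [], [], []
        for ki, vi in zip(k.split(kv_chunk, dim=-2), v.split(kv_chunk, dim=-2)):
            a = (qi * scale) @ ki.transpose(-1, -2)   # [..., cq, ckv]
            m = a.max(dim=-1, keepdim=True).values    # chunk-local max
            e = torch.exp(a - m)                      # stabilized exponents
            maxes.append(m.squeeze(-1))
            sums.append(e.sum(dim=-1))
            values.append(e @ vi)
        # Renormalize the per-chunk partial results against the global max.
        m = torch.stack(maxes, dim=-1)                # [..., cq, n_chunks]
        rescale = torch.exp(m - m.max(dim=-1, keepdim=True).values)
        denom = (torch.stack(sums, dim=-1) * rescale).sum(dim=-1)
        numer = (torch.stack(values, dim=-2) * rescale.unsqueeze(-1)).sum(dim=-2)
        out_chunks.append(numer / denom.unsqueeze(-1))
    return torch.cat(out_chunks, dim=-2)
```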
Using the most conservative settings, we were able to run inference on a
4600-residue complex with a single A100. Compared to AlphaFold's own memory
offloading mode, ours is considerably faster; the same complex takes the more
efficient AlphaFold-Multimer more than double the time. Use the
`long_sequence_inference` config option to enable all of these interventions
at once. The `run_pretrained_openfold.py` script can enable this config option
with the `--long_sequence_inference` command line option.
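
From Python, the same interventions can be enabled when building the config
(a sketch; the preset name is illustrative):

```python
from openfold import config

# Equivalent to passing --long_sequence_inference on the command line:
# enables offload_inference and LMA and disables chunk-size tuning.
cfg = config.model_config("model_1_ptm", long_sequence_inference=True)
```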
### Training
......@@ -156,11 +241,10 @@ the following:
```bash
python3 scripts/precompute_alignments.py mmcif_dir/ alignment_dir/ \
data/uniref90/uniref90.fasta \
data/mgnify/mgy_clusters_2018_12.fa \
data/pdb70/pdb70 \
data/pdb_mmcif/mmcif_files/ \
data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--cpus 16 \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
......@@ -216,32 +300,55 @@ python3 scripts/generate_chain_data_cache.py \
where the `cluster_file` argument is a file of chain clusters, one cluster
per line (e.g. [PDB40](https://cdn.rcsb.org/resources/sequence/clusters/clusters-by-entity-40.txt)).
Optionally, download an AlphaFold-style validation set from
[CAMEO](https://cameo3d.org) using `scripts/download_cameo.py`. Use the
resulting FASTA files to generate validation alignments and then specify
the validation set's location using the `--val_...` family of training script
flags.
Finally, call the training script:
```bash
python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ \
python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ output_dir/ \
2021-10-10 \
--template_release_dates_cache_path mmcif_cache.json \
--precision 16 \
--precision bf16 \
--gpus 8 --replace_sampler_ddp=True \
--seed 42 \ # in multi-gpu settings, the seed must be specified
--seed 4242022 \ # in multi-gpu settings, the seed must be specified
--deepspeed_config_path deepspeed_config.json \
--checkpoint_every_epoch \
--resume_from_ckpt ckpt_dir/ \
--train_chain_data_cache_path chain_data_cache.json
--train_chain_data_cache_path chain_data_cache.json \
--obsolete_pdbs_file_path obsolete.dat
```
where `--template_release_dates_cache_path` is a path to the mmCIF cache.
A suitable DeepSpeed configuration file can be generated with
Note that `template_mmcif_dir` can be the same as `mmcif_dir`, which contains
the training targets. A suitable DeepSpeed configuration file can be generated with
`scripts/build_deepspeed_config.py`. The training script is
written with [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
and supports the full range of training options that entails, including
multi-node distributed training. For more information, consult PyTorch
Lightning documentation and the `--help` flag of the training script.
Note that the data directory can also contain PDB files previously output by
the model. These are treated as members of the self-distillation set and are
subjected to distillation-set-only preprocessing steps.
multi-node distributed training, validation, and so on. For more information,
consult PyTorch Lightning documentation and the `--help` flag of the training
script.
Note that, despite its variable name, `mmcif_dir` can also contain PDB files
or even ProteinNet .core files.
To emulate the AlphaFold training procedure, which uses a self-distillation set
subject to special preprocessing steps, use the family of `--distillation` flags.
In cases where it may be burdensome to create separate files for each chain's
alignments, alignment directories can be consolidated using the scripts in
`scripts/alignment_db_scripts/`. First, run `create_alignment_db.py` to
consolidate an alignment directory into a pair of database and index files.
Once all alignment directories (or shards of a single alignment directory)
have been compiled, unify the indices with `unify_alignment_db_indices.py`. The
resulting index, `super.index`, can be passed to the training script flags
containing the phrase `alignment_index`. In this scenario, the `alignment_dir`
flags instead represent the directory containing the compiled alignment
databases. Both the training and distillation datasets can be compiled in this
way. Anecdotally, this can speed up training in I/O-bottlenecked environments.
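
A hypothetical end-to-end pass over two alignment shards might look like the
following (the script arguments here are illustrative, not the scripts' real
interfaces; consult each script's `--help`):

```python
import subprocess

# Step 1 (sketch): compile each alignment directory (or shard) into a
# database/index pair.
for shard in ["alignments_0/", "alignments_1/"]:  # hypothetical shard dirs
    subprocess.run(
        ["python3", "scripts/alignment_db_scripts/create_alignment_db.py",
         shard, "alignment_dbs/"],  # illustrative arguments
        check=True,
    )

# Step 2 (sketch): unify the per-shard indices into a single super.index,
# which is then passed to the training script's *alignment_index* flags
# (with the *alignment_dir* flags pointing at alignment_dbs/).
subprocess.run(
    ["python3", "scripts/alignment_db_scripts/unify_alignment_db_indices.py",
     "alignment_dbs/"],  # illustrative arguments
    check=True,
)
```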
## Testing
......@@ -297,7 +404,7 @@ docker run \
-v /mnt/alphafold_database/:/database \
-ti openfold:latest \
python3 /opt/openfold/run_pretrained_openfold.py \
/data/input.fasta \
/data/fasta_dir \
/database/pdb_mmcif/mmcif_files/ \
--uniref90_database_path /database/uniref90/uniref90.fasta \
--mgnify_database_path /database/mgnify/mgy_clusters_2018_12.fa \
......@@ -310,7 +417,7 @@ python3 /opt/openfold/run_pretrained_openfold.py \
--hhblits_binary_path /opt/conda/bin/hhblits \
--hhsearch_binary_path /opt/conda/bin/hhsearch \
--kalign_binary_path /opt/conda/bin/kalign \
--param_path /database/params/params_model_1.npz
--openfold_checkpoint_path /database/openfold_params/finetuning_ptm_2.pt
```
## Copyright notice
......@@ -328,16 +435,20 @@ welcome pull requests from the community.
## Citing this work
For now, cite OpenFold as follows:
Please cite our paper:
```bibtex
@software{Ahdritz_OpenFold_2021,
author = {Ahdritz, Gustaf and Bouatta, Nazim and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and AlQuraishi, Mohammed},
doi = {10.5281/zenodo.5709539},
month = {11},
title = {{OpenFold}},
url = {https://github.com/aqlaboratory/openfold},
year = {2021}
@article {Ahdritz2022.11.20.517210,
author = {Ahdritz, Gustaf and Bouatta, Nazim and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and O{\textquoteright}Donnell, Timothy J and Berenberg, Daniel and Fisk, Ian and Zanichelli, Niccolò and Zhang, Bo and Nowaczynski, Arkadiusz and Wang, Bei and Stepniewska-Dziubinska, Marta M and Zhang, Shang and Ojewole, Adegoke and Guney, Murat Efe and Biderman, Stella and Watkins, Andrew M and Ra, Stephen and Lorenzo, Pablo Ribalta and Nivon, Lucas and Weitzner, Brian and Ban, Yih-En Andrew and Sorger, Peter K and Mostaque, Emad and Zhang, Zhao and Bonneau, Richard and AlQuraishi, Mohammed},
title = {OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization},
elocation-id = {2022.11.20.517210},
year = {2022},
doi = {10.1101/2022.11.20.517210},
publisher = {Cold Spring Harbor Laboratory},
abstract = {AlphaFold2 revolutionized structural biology with the ability to predict protein structures with exceptionally high accuracy. Its implementation, however, lacks the code and data required to train new models. These are necessary to (i) tackle new tasks, like protein-ligand complex structure prediction, (ii) investigate the process by which the model learns, which remains poorly understood, and (iii) assess the model{\textquoteright}s generalization capacity to unseen regions of fold space. Here we report OpenFold, a fast, memory-efficient, and trainable implementation of AlphaFold2, and OpenProteinSet, the largest public database of protein multiple sequence alignments. We use OpenProteinSet to train OpenFold from scratch, fully matching the accuracy of AlphaFold2. Having established parity, we assess OpenFold{\textquoteright}s capacity to generalize across fold space by retraining it using carefully designed datasets. We find that OpenFold is remarkably robust at generalizing despite extreme reductions in training set size and diversity, including near-complete elisions of classes of secondary structure elements. By analyzing intermediate structures produced by OpenFold during training, we also gain surprising insights into the manner in which the model learns to fold proteins, discovering that spatial dimensions are learned sequentially. Taken together, our studies demonstrate the power and utility of OpenFold, which we believe will prove to be a crucial new resource for the protein modeling community.},
URL = {https://www.biorxiv.org/content/10.1101/2022.11.20.517210},
eprint = {https://www.biorxiv.org/content/early/2022/11/22/2022.11.20.517210.full.pdf},
journal = {bioRxiv}
}
```
......
......@@ -4,9 +4,19 @@ channels:
- bioconda
- pytorch
dependencies:
- conda-forge::python=3.7
- conda-forge::setuptools=59.5.0
- conda-forge::pip
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- conda-forge::cudatoolkit==11.3.*
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pytorch::pytorch=1.12.*
- pip:
- biopython==1.79
- deepspeed==0.5.9
- deepspeed==0.5.10
- dm-tree==0.1.6
- ml-collections==0.1.0
- numpy==1.21.2
......@@ -16,15 +26,5 @@ dependencies:
- tqdm==4.62.2
- typing-extensions==3.10.0.2
- pytorch_lightning==1.5.10
- wandb==0.12.21
- git+https://github.com/NVIDIA/dllogger.git
- pytorch::pytorch=1.10.*
- conda-forge::python=3.7
- conda-forge::setuptools=59.5.0
- conda-forge::pip
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- conda-forge::cudatoolkit==10.2.*
- conda-forge::cudatoolkit-dev==10.*
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
......@@ -31,7 +31,7 @@
"\n",
"OpenFold is a trainable PyTorch reimplementation of AlphaFold 2. For the purposes of inference, it is practically identical to the original (\"practically\" because ensembling is excluded from OpenFold (recycling is enabled, however)).\n",
"\n",
"In this notebook, OpenFold is run with DeepMind's publicly released parameters for AlphaFold 2.\n",
"In this notebook, OpenFold is run with your choice of our original OpenFold parameters or DeepMind's publicly released parameters for AlphaFold 2.\n",
"\n",
"**Note**\n",
"\n",
......@@ -43,7 +43,7 @@
"\n",
"**Licenses**\n",
"\n",
"This Colab uses the [AlphaFold model parameters](https://github.com/deepmind/alphafold/#model-parameters-license), made available under the Creative Commons Attribution 4.0 International ([CC BY 4.0](https://creativecommons.org/licenses/by/4.0/legalcode)) license. The Colab itself is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0). See the full license statement below.\n",
"This Colab supports inference with the [AlphaFold model parameters](https://github.com/deepmind/alphafold/#model-parameters-license), made available under the Creative Commons Attribution 4.0 International ([CC BY 4.0](https://creativecommons.org/licenses/by/4.0/legalcode)) license. The Colab itself is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0). See the full license statement below.\n",
"\n",
"**More information**\n",
"\n",
......@@ -55,6 +55,33 @@
"FAQ on how to interpret AlphaFold/OpenFold predictions are [here](https://alphafold.ebi.ac.uk/faq)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "rowN0bVYLe9n",
"cellView": "form"
},
"source": [
"#@markdown ### Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
"\n",
"#@markdown ### Configure the model ⬇️\n",
"\n",
"weight_set = 'OpenFold' #@param [\"OpenFold\", \"AlphaFold\"]\n",
"relax_prediction = True #@param {type:\"boolean\"}\n",
"\n",
"# Remove all whitespaces, tabs and end lines; upper-case\n",
"sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
"if not set(sequence).issubset(aatypes):\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. OpenFold only supports 20 standard amino acids as inputs.')\n",
"\n",
"#@markdown After making your selections, execute this cell by pressing the\n",
"#@markdown *Play* button on the left."
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
......@@ -63,10 +90,9 @@
},
"source": [
"#@title Install third-party software\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n",
"#@markdown Please execute this cell by pressing the _Play_ button \n",
"#@markdown on the left to download and import third-party software \n",
"#@markdown in this Colab notebook. (See the [acknowledgements](https://github.com/deepmind/alphafold/#acknowledgements) in DeepMind's README.)\n",
"\n",
"#@markdown **Note**: This installs the software on the Colab \n",
"#@markdown notebook in the cloud and not on your computer.\n",
......@@ -79,39 +105,46 @@
"TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
"\n",
"try:\n",
" with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" with io.capture_output() as captured:\n",
" # Uninstall default Colab version of PyTorch.\n",
" # %shell pip uninstall -y torch\n",
"\n",
" %shell sudo apt install --quiet --yes hmmer\n",
" pbar.update(6)\n",
"\n",
" # Install py3dmol.\n",
" %shell pip install py3dmol\n",
" pbar.update(2)\n",
"\n",
" # Install OpenMM and pdbfixer.\n",
" %shell rm -rf /opt/conda\n",
" %shell wget -q -P /tmp \\\n",
" https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n",
" && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \\\n",
" && rm /tmp/Miniconda3-latest-Linux-x86_64.sh\n",
" pbar.update(9)\n",
"\n",
" PATH=%env PATH\n",
" %env PATH=/opt/conda/bin:{PATH}\n",
" pbar.update(80)\n",
"\n",
" # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n",
" %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n",
" %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n",
" pbar.update(2)\n",
"\n",
" %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
" pbar.update(1)\n",
"except subprocess.CalledProcessError:\n",
" with io.capture_output() as captured:\n",
" %shell sudo apt install --quiet --yes hmmer\n",
"\n",
" # Install py3dmol.\n",
" %shell pip install py3dmol\n",
"\n",
" %shell rm -rf /opt/conda\n",
" %shell wget -q -P /tmp \\\n",
" https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n",
" && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \\\n",
" && rm /tmp/Miniconda3-latest-Linux-x86_64.sh\n",
"\n",
" PATH=%env PATH\n",
" %env PATH=/opt/conda/bin:{PATH}\n",
"\n",
" # Install the required versions of all dependencies.\n",
" %shell conda install -y -q -c conda-forge -c bioconda \\\n",
" kalign2=2.04 \\\n",
" hhsuite=3.3.0 \\\n",
" python=3.7 \\\n",
" 2>&1 1>/dev/null\n",
" %shell pip install -q \\\n",
" ml-collections==0.1.0 \\\n",
" PyYAML==5.4.1 \\\n",
" biopython==1.79\n",
"\n",
" # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n",
" %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n",
" %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n",
"\n",
" %shell wget -q -P /content \\\n",
" https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n",
"\n",
" # Install AWS CLI\n",
" %shell curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\"\n",
" %shell unzip -qq awscliv2.zip\n",
" %shell sudo ./aws/install\n",
" %shell rm awscliv2.zip\n",
" %shell rm -rf ./aws\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)\n",
" raise"
],
......@@ -125,112 +158,85 @@
"cellView": "form"
},
"source": [
"#@title Download OpenFold\n",
"\n",
"#@title Install OpenFold\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n",
"GIT_REPO = 'https://github.com/aqlaboratory/openfold'\n",
"\n",
"SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n",
"PARAMS_DIR = './openfold/openfold/resources/params'\n",
"PARAMS_PATH = os.path.join(PARAMS_DIR, os.path.basename(SOURCE_URL))\n",
"# Define constants\n",
"GIT_REPO='https://github.com/aqlaboratory/openfold'\n",
"ALPHAFOLD_PARAM_SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar'\n",
"OPENFOLD_PARAMS_DIR = './openfold/openfold/resources/openfold_params'\n",
"ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'\n",
"ALPHAFOLD_PARAMS_PATH = os.path.join(\n",
" ALPHAFOLD_PARAMS_DIR, os.path.basename(ALPHAFOLD_PARAM_SOURCE_URL)\n",
")\n",
"\n",
"try:\n",
" with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" with io.capture_output() as captured:\n",
" %shell rm -rf openfold\n",
" %shell git clone {GIT_REPO} openfold\n",
" pbar.update(8)\n",
" # Install the required versions of all dependencies.\n",
" %shell conda env update -n base --file openfold/environment.yml\n",
" # Run setup.py to install only Openfold.\n",
" %shell pip3 install --no-dependencies ./openfold\n",
" pbar.update(10)\n",
"\n",
" with io.capture_output() as captured:\n",
" # Run setup.py to install only Openfold.\n",
" %shell rm -rf openfold\n",
" %shell git clone \"{GIT_REPO}\" openfold 2>&1 1> /dev/null\n",
" %shell mkdir -p /content/openfold/openfold/resources\n",
" %shell cp -f /content/stereo_chemical_props.txt /content/openfold/openfold/resources\n",
" %shell /usr/bin/python3 -m pip install -q ./openfold\n",
"\n",
" if(relax_prediction):\n",
" %shell conda install -y -q -c conda-forge \\\n",
" openmm=7.5.1 \\\n",
" pdbfixer=1.7\n",
" \n",
" # Apply OpenMM patch.\n",
" %shell pushd /opt/conda/lib/python3.7/site-packages/ && \\\n",
" patch -p0 < /content/openfold/lib/openmm.patch && \\\n",
" popd\n",
" \n",
" %shell mkdir -p /content/openfold/resources\n",
" %shell cp -f /content/stereo_chemical_props.txt /content/openfold/resources\n",
"\n",
" %shell mkdir --parents \"{PARAMS_DIR}\"\n",
" %shell wget -O \"{PARAMS_PATH}\" \"{SOURCE_URL}\"\n",
" pbar.update(27)\n",
"\n",
" %shell tar --extract --verbose --file=\"{PARAMS_PATH}\" \\\n",
" --directory=\"{PARAMS_DIR}\" --preserve-permissions\n",
" %shell rm \"{PARAMS_PATH}\"\n",
" pbar.update(55)\n",
"except subprocess.CalledProcessError:\n",
"\n",
" if(weight_set == 'AlphaFold'):\n",
" %shell mkdir --parents \"{ALPHAFOLD_PARAMS_DIR}\"\n",
" %shell wget -O {ALPHAFOLD_PARAMS_PATH} {ALPHAFOLD_PARAM_SOURCE_URL}\n",
" %shell tar --extract --verbose --file=\"{ALPHAFOLD_PARAMS_PATH}\" \\\n",
" --directory=\"{ALPHAFOLD_PARAMS_DIR}\" --preserve-permissions\n",
" %shell rm \"{ALPHAFOLD_PARAMS_PATH}\"\n",
" elif(weight_set == 'OpenFold'):\n",
" %shell mkdir --parents \"{OPENFOLD_PARAMS_DIR}\"\n",
" %shell aws s3 cp \\\n",
" --no-sign-request \\\n",
" --region us-east-1 \\\n",
" s3://openfold/openfold_params \"{OPENFOLD_PARAMS_DIR}\" \\\n",
" --recursive\n",
" else:\n",
" raise ValueError(\"Invalid weight set\")\n",
"except subprocess.CalledProcessError as captured:\n",
" print(captured)\n",
" raise"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "W4JpOs6oA-QS"
},
"source": [
"## Making a prediction\n",
"\n",
"Please paste the sequence of your protein in the text box below, then run the remaining cells via _Runtime_ > _Run after_. You can also run the cells individually by pressing the _Play_ button on the left.\n",
"\n",
"Note that the search against databases and the actual prediction can take some time, from minutes to hours, depending on the length of the protein and what type of GPU you are allocated by Colab (see FAQ below)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "rowN0bVYLe9n",
"cellView": "form"
},
"source": [
"#@title Enter the amino acid sequence to fold ⬇️\n",
"sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
"#@title Import Python packages\n",
"#@markdown Please execute this cell by pressing the *Play* button on \n",
"#@markdown the left.\n",
"\n",
"MIN_SEQUENCE_LENGTH = 16\n",
"MAX_SEQUENCE_LENGTH = 2500\n",
"import unittest.mock\n",
"import sys\n",
"\n",
"# Remove all whitespaces, tabs and end lines; upper-case\n",
"sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n",
"aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n",
"if not set(sequence).issubset(aatypes):\n",
" raise Exception(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. AlphaFold only supports 20 standard amino acids as inputs.')\n",
"if len(sequence) < MIN_SEQUENCE_LENGTH:\n",
" raise Exception(f'Input sequence is too short: {len(sequence)} amino acids, while the minimum is {MIN_SEQUENCE_LENGTH}')\n",
"if len(sequence) > MAX_SEQUENCE_LENGTH:\n",
" raise Exception(f'Input sequence is too long: {len(sequence)} amino acids, while the maximum is {MAX_SEQUENCE_LENGTH}. Please use the full AlphaFold system for long sequences.')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2tTeTTsLKPjB",
"cellView": "form"
},
"source": [
"#@title Search against genetic databases\n",
"sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n",
"sys.path.append('/opt/conda/lib/python3.7/site-packages')\n",
"\n",
"#@markdown Once this cell has been executed, you will see\n",
"#@markdown statistics about the multiple sequence alignment \n",
"#@markdown (MSA) that will be used by OpenFold. In particular, \n",
"#@markdown you’ll see how well each residue is covered by similar \n",
"#@markdown sequences in the MSA.\n",
"# Allows us to skip installing these packages\n",
"unnecessary_modules = [\n",
" \"dllogger\",\n",
" \"pytorch_lightning\",\n",
" \"pytorch_lightning.utilities\",\n",
" \"pytorch_lightning.callbacks.early_stopping\",\n",
" \"pytorch_lightning.utilities.seed\",\n",
"]\n",
"for unnecessary_module in unnecessary_modules:\n",
" sys.modules[unnecessary_module] = unittest.mock.MagicMock()\n",
"\n",
"# --- Python imports ---\n",
"import sys\n",
"sys.path.append('/opt/conda/lib/python3.7/site-packages')\n",
"import os\n",
"os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'\n",
"os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0'\n",
"\n",
"from urllib import request\n",
"from concurrent import futures\n",
......@@ -242,6 +248,18 @@
"import py3Dmol\n",
"import torch\n",
"\n",
"# A filthy hack to avoid slow Linear layer initialization\n",
"import openfold.model.primitives\n",
"\n",
"def __default_linear_init__(self, *args, **kwargs):\n",
" return torch.nn.Linear.__init__(\n",
" self, \n",
" *args[:2], \n",
" **{k:v for k,v in kwargs.items() if k == \"bias\"}\n",
" )\n",
"\n",
"openfold.model.primitives.Linear.__init__ = __default_linear_init__\n",
"\n",
"from openfold import config\n",
"from openfold.data import feature_pipeline\n",
"from openfold.data import parsers\n",
......@@ -249,20 +267,48 @@
"from openfold.data.tools import jackhmmer\n",
"from openfold.model import model\n",
"from openfold.np import protein\n",
"from openfold.np.relax import relax\n",
"from openfold.np.relax import utils\n",
"if(relax_prediction):\n",
" from openfold.np.relax import relax\n",
" from openfold.np.relax import utils\n",
"from openfold.utils.import_weights import import_jax_weights_\n",
"from openfold.utils.tensor_utils import tensor_tree_map\n",
"\n",
"from IPython import display\n",
"from ipywidgets import GridspecLayout\n",
"from ipywidgets import Output\n",
"from ipywidgets import Output"
],
"metadata": {
"id": "_FpxxMo-mvcP",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "W4JpOs6oA-QS"
},
"source": [
"## Making a prediction\n",
"\n",
"# Color bands for visualizing plddt\n",
"PLDDT_BANDS = [(0, 50, '#FF7D45'),\n",
" (50, 70, '#FFDB13'),\n",
" (70, 90, '#65CBF3'),\n",
" (90, 100, '#0053D6')]\n",
"Note that the search against databases and the actual prediction can take some time, from minutes to hours, depending on the length of the protein and what type of GPU you are allocated by Colab (see FAQ below)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "2tTeTTsLKPjB",
"cellView": "form"
},
"source": [
"#@title Search against genetic databases\n",
"\n",
"#@markdown Once this cell has been executed, you will see\n",
"#@markdown statistics about the multiple sequence alignment \n",
"#@markdown (MSA) that will be used by OpenFold. In particular, \n",
"#@markdown you’ll see how well each residue is covered by similar \n",
"#@markdown sequences in the MSA.\n",
"\n",
"# --- Find the closest source ---\n",
"test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n",
......@@ -387,16 +433,30 @@
"#@markdown the obtained prediction will be automatically downloaded \n",
"#@markdown to your computer.\n",
"\n",
"# Color bands for visualizing plddt\n",
"PLDDT_BANDS = [\n",
" (0, 50, '#FF7D45'),\n",
" (50, 70, '#FFDB13'),\n",
" (70, 90, '#65CBF3'),\n",
" (90, 100, '#0053D6')\n",
"]\n",
"\n",
"# --- Run the model ---\n",
"model_names = ['model_1', 'model_2', 'model_3', 'model_4', 'model_5', 'model_1_ptm']\n",
"model_names = [ \n",
" 'finetuning_3.pt', \n",
" 'finetuning_4.pt', \n",
" 'finetuning_5.pt', \n",
" 'finetuning_ptm_2.pt',\n",
" 'finetuning_no_templ_ptm_1.pt'\n",
"]\n",
"\n",
"def _placeholder_template_feats(num_templates_, num_res_):\n",
" return {\n",
" 'template_aatype': torch.zeros(num_templates_, num_res_, 22).long(),\n",
" 'template_all_atom_positions': torch.zeros(num_templates_, num_res_, 37, 3),\n",
" 'template_all_atom_mask': torch.zeros(num_templates_, num_res_, 37),\n",
" 'template_domain_names': torch.zeros(num_templates_),\n",
" 'template_sum_probs': torch.zeros(num_templates_, 1),\n",
" 'template_aatype': np.zeros((num_templates_, num_res_, 22), dtype=np.int64),\n",
" 'template_all_atom_positions': np.zeros((num_templates_, num_res_, 37, 3), dtype=np.float32),\n",
" 'template_all_atom_mask': np.zeros((num_templates_, num_res_, 37), dtype=np.float32),\n",
" 'template_domain_names': np.zeros((num_templates_,), dtype=np.float32),\n",
" 'template_sum_probs': np.zeros((num_templates_, 1), dtype=np.float32),\n",
" }\n",
"\n",
"output_dir = 'prediction'\n",
......@@ -407,21 +467,44 @@
"unrelaxed_proteins = {}\n",
"\n",
"with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
" for model_name in model_names:\n",
" for i, model_name in list(enumerate(model_names)):\n",
" pbar.set_description(f'Running {model_name}')\n",
" num_templates = 1 # dummy number --- is ignored\n",
" num_res = len(sequence)\n",
"\n",
" \n",
" feature_dict = {}\n",
" feature_dict.update(data_pipeline.make_sequence_features(sequence, 'test', num_res))\n",
" feature_dict.update(data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n",
" feature_dict.update(_placeholder_template_feats(num_templates, num_res))\n",
"\n",
" cfg = config.model_config(model_name)\n",
" if(weight_set == \"AlphaFold\"):\n",
" config_preset = f\"model_{i}\"\n",
" else:\n",
" if(\"_no_templ_\" in model_name):\n",
" config_preset = \"model_3\"\n",
" else:\n",
" config_preset = \"model_1\"\n",
" if(\"_ptm_\" in model_name):\n",
" config_preset += \"_ptm\"\n",
"\n",
" cfg = config.model_config(config_preset)\n",
" openfold_model = model.AlphaFold(cfg)\n",
" openfold_model = openfold_model.eval()\n",
" params_name = os.path.join(PARAMS_DIR, f\"params_{model_name}.npz\")\n",
" import_jax_weights_(openfold_model, params_name, version=model_name)\n",
" if(weight_set == \"AlphaFold\"):\n",
" params_name = os.path.join(\n",
" ALPHAFOLD_PARAMS_DIR, f\"params_{config_preset}.npz\"\n",
" )\n",
" import_jax_weights_(openfold_model, params_name, version=config_preset)\n",
" elif(weight_set == \"OpenFold\"):\n",
" params_name = os.path.join(\n",
" OPENFOLD_PARAMS_DIR,\n",
" model_name,\n",
" )\n",
" d = torch.load(params_name)\n",
" openfold_model.load_state_dict(d)\n",
" else:\n",
" raise ValueError(f\"Invalid weight set: {weight_set}\")\n",
"\n",
" openfold_model = openfold_model.cuda()\n",
"\n",
" pipeline = feature_pipeline.FeaturePipeline(cfg.data)\n",
......@@ -470,20 +553,32 @@
" del prediction_result\n",
" pbar.update(n=1)\n",
"\n",
" # --- AMBER relax the best model ---\n",
" pbar.set_description(f'AMBER relaxation')\n",
" amber_relaxer = relax.AmberRelaxation(\n",
" max_iterations=0,\n",
" tolerance=2.39,\n",
" stiffness=10.0,\n",
" exclude_residues=[],\n",
" max_outer_iterations=20,\n",
" use_gpu=True,\n",
" )\n",
" # Find the best model according to the mean pLDDT.\n",
" best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n",
" relaxed_pdb, _, _ = amber_relaxer.process(\n",
" prot=unrelaxed_proteins[best_model_name])\n",
" best_pdb = protein.to_pdb(unrelaxed_proteins[best_model_name])\n",
"\n",
" # --- AMBER relax the best model ---\n",
" if(relax_prediction):\n",
" pbar.set_description(f'AMBER relaxation')\n",
" amber_relaxer = relax.AmberRelaxation(\n",
" max_iterations=0,\n",
" tolerance=2.39,\n",
" stiffness=10.0,\n",
" exclude_residues=[],\n",
" max_outer_iterations=20,\n",
" use_gpu=False,\n",
" )\n",
" relaxed_pdb, _, _ = amber_relaxer.process(\n",
" prot=unrelaxed_proteins[best_model_name]\n",
" )\n",
"\n",
" # Write out the prediction\n",
" pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
" with open(pred_output_path, 'w') as f:\n",
" f.write(relaxed_pdb)\n",
"\n",
" best_pdb = relaxed_pdb\n",
"\n",
" pbar.update(n=1) # Finished AMBER relax.\n",
"\n",
"# Construct multiclass b-factors to indicate confidence bands\n",
......@@ -495,14 +590,7 @@
" banded_b_factors.append(idx)\n",
" break\n",
"banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n",
"to_visualize_pdb = utils.overwrite_b_factors(relaxed_pdb, banded_b_factors)\n",
"\n",
"\n",
"# Write out the prediction\n",
"pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n",
"with open(pred_output_path, 'w') as f:\n",
" f.write(relaxed_pdb)\n",
"\n",
"to_visualize_pdb = utils.overwrite_b_factors(best_pdb, banded_b_factors)\n",
"\n",
"# --- Visualise the prediction & confidence ---\n",
"show_sidechains = True\n",
......@@ -699,4 +787,4 @@
]
}
]
}
}
\ No newline at end of file
name: openfold_venv
channels:
- conda-forge
- bioconda
dependencies:
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pip:
- biopython==1.79
- dm-tree==0.1.6
- ml-collections==0.1.0
- PyYAML==5.4.1
- requests==2.26.0
- typing-extensions==3.10.0.2
import copy
import importlib
import ml_collections as mlc
......@@ -10,20 +11,92 @@ def set_inf(c, inf):
c[k] = inf
def model_config(name, train=False, low_prec=False):
def enforce_config_constraints(config):
def string_to_setting(s):
path = s.split('.')
setting = config
for p in path:
setting = setting.get(p)
return setting
mutually_exclusive_bools = [
(
"model.template.average_templates",
"model.template.offload_templates"
),
(
"globals.use_lma",
"globals.use_flash",
),
]
for s1, s2 in mutually_exclusive_bools:
s1_setting = string_to_setting(s1)
s2_setting = string_to_setting(s2)
if(s1_setting and s2_setting):
raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(config.globals.use_flash and not fa_is_installed):
raise ValueError("use_flash requires that FlashAttention is installed")
if(
config.globals.offload_inference and
not config.model.template.average_templates
):
config.model.template.offload_templates = True
def model_config(
name,
train=False,
low_prec=False,
long_sequence_inference=False
):
c = copy.deepcopy(config)
# TRAINING PRESETS
if name == "initial_training":
# AF2 Suppl. Table 4, "initial training" setting
pass
elif name == "finetuning":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.common.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "finetuning_ptm":
c.data.train.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "finetuning_no_templ":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.model.template.enabled = False
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "finetuning_no_templ_ptm":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.train.crop_size = 384
c.data.train.max_extra_msa = 5120
c.data.train.max_msa_clusters = 512
c.model.template.enabled = False
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
# INFERENCE PRESETS
elif name == "model_1":
# AF2 Suppl. Table 5, Model 1.1.1
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
......@@ -36,17 +109,20 @@ def model_config(name, train=False, low_prec=False):
c.model.template.enabled = True
elif name == "model_3":
# AF2 Suppl. Table 5, Model 1.2.1
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False
elif name == "model_4":
# AF2 Suppl. Table 5, Model 1.2.2
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False
elif name == "model_5":
# AF2 Suppl. Table 5, Model 1.2.3
c.model.template.enabled = False
elif name == "model_1_ptm":
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
......@@ -61,12 +137,14 @@ def model_config(name, train=False, low_prec=False):
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_3_ptm":
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_4_ptm":
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
......@@ -76,6 +154,7 @@ def model_config(name, train=False, low_prec=False):
c.loss.tm.weight = 0.1
elif "multimer" in name:
c.globals.is_multimer = True
c.loss.masked_msa.num_classes = 22
for k,v in multimer_model_config_update.items():
c.model[k] = v
......@@ -89,16 +168,32 @@ def model_config(name, train=False, low_prec=False):
else:
raise ValueError("Invalid model name")
if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True
c.globals.use_lma = True
c.globals.use_flash = False
c.model.template.offload_inference = True
c.model.template.template_pair_stack.tune_chunk_size = False
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
c.model.evoformer_stack.tune_chunk_size = False
if train:
c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None
c.globals.use_lma = False
c.globals.offload_inference = False
c.model.template.average_templates = False
c.model.template.offload_templates = False
if low_prec:
c.globals.eps = 1e-4
# If we want exact numerical parity with the original, inf can't be
# a global constant
set_inf(c, 1e4)
enforce_config_constraints(c)
return c
......@@ -114,6 +209,7 @@ tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
tune_chunk_size = mlc.FieldReference(True, field_type=bool)
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
......@@ -195,7 +291,6 @@ config = mlc.ConfigDict(
"same_prob": 0.1,
"uniform_prob": 0.1,
},
"max_extra_msa": 1024,
"max_recycling_iters": 3,
"msa_cluster_features": True,
"reduce_msa_clusters_by_max_templates": False,
......@@ -233,7 +328,8 @@ config = mlc.ConfigDict(
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_msa_clusters": 512,
"max_extra_msa": 1024,
"max_template_hits": 4,
"max_templates": 4,
"crop": False,
......@@ -246,6 +342,7 @@ config = mlc.ConfigDict(
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4,
"max_templates": 4,
"crop": False,
......@@ -258,6 +355,7 @@ config = mlc.ConfigDict(
"subsample_templates": True,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4,
"max_templates": 4,
"shuffle_top_k_prefiltered": 20,
......@@ -274,6 +372,7 @@ config = mlc.ConfigDict(
"data_loaders": {
"batch_size": 1,
"num_workers": 16,
"pin_memory": True,
},
},
},
......@@ -281,6 +380,13 @@ config = mlc.ConfigDict(
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
# Use Staats & Rabe's low-memory attention algorithm. Mutually
# exclusive with use_flash.
"use_lma": False,
# Use FlashAttention in selected modules. Mutually exclusive with
# use_lma. Doesn't work that well on long sequences (>1000 residues).
"use_flash": False,
"offload_inference": False,
"c_z": c_z,
"c_m": c_m,
"c_t": c_t,
......@@ -333,6 +439,7 @@ config = mlc.ConfigDict(
"dropout_rate": 0.25,
"tri_mul_first": False,
"blocks_per_ckpt": blocks_per_ckpt,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
},
"template_pointwise_attention": {
......@@ -349,6 +456,17 @@ config = mlc.ConfigDict(
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
"use_unit_vector": False,
# Approximate template computation, saving memory.
# In our experiments, results are equivalent to or better than
# the stock implementation. Should be enabled for all new
# training runs.
"average_templates": False,
# Offload template embeddings to CPU memory. Vastly reduced
# memory consumption at the cost of a modest increase in
# runtime. Useful for inference on very long sequences.
# Mutually exclusive with average_templates. Automatically
# enabled if offload_inference is set.
"offload_templates": False,
},
"extra_msa": {
"extra_msa_embedder": {
......@@ -369,7 +487,8 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": False,
"clear_cache_between_blocks": True,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
"eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
......@@ -393,6 +512,7 @@ config = mlc.ConfigDict(
"opm_first": False,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
"eps": eps, # 1e-10,
},
......@@ -473,7 +593,7 @@ config = mlc.ConfigDict(
"eps": 1e-4,
"weight": 1.0,
},
"lddt": {
"plddt_loss": {
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.0,
......@@ -482,6 +602,7 @@ config = mlc.ConfigDict(
"weight": 0.01,
},
"masked_msa": {
"num_classes": 23,
"eps": eps, # 1e-8,
"weight": 2.0,
},
......@@ -503,7 +624,7 @@ config = mlc.ConfigDict(
"min_resolution": 0.1,
"max_resolution": 3.0,
"eps": eps, # 1e-8,
"weight": 0.0,
"weight": 0.,
"enabled": tm_enabled,
},
"eps": eps,
......@@ -607,6 +728,23 @@ multimer_model_config_update = {
"inf": 1e9,
"eps": eps, # 1e-10,
},
"structure_module": {
"c_s": c_s,
"c_z": c_z,
"c_ipa": 16,
"c_resnet": 128,
"no_heads_ipa": 12,
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 20,
"epsilon": eps, # 1e-12,
"inf": 1e5,
},
"heads": {
"lddt": {
"no_bins": 50,
......
......@@ -28,16 +28,18 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
mapping_path: Optional[str] = None,
filter_path: Optional[str] = None,
mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False,
_alignment_index: Optional[Any] = None
_structure_index: Optional[Any] = None,
):
"""
Args:
......@@ -55,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
Path to a directory containing template mmCIF files.
config:
A dataset config object. See openfold.config
chain_data_cache_path:
Path to cache of data_dir generated by
scripts/generate_chain_data_cache.py
kalign_binary_path:
Path to kalign binary.
max_template_hits:
......@@ -79,12 +84,22 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"""
super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir
self.chain_data_cache = None
if chain_data_cache_path is not None:
with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
self.alignment_dir = alignment_dir
self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self.alignment_index = alignment_index
self._output_raw = _output_raw
self._alignment_index = _alignment_index
self._structure_index = _structure_index
self.supported_exts = [".cif", ".core", ".pdb"]
valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes):
......@@ -96,14 +111,42 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if(_alignment_index is not None):
self._chain_ids = list(_alignment_index.keys())
elif(mapping_path is None):
self._chain_ids = list(os.listdir(alignment_dir))
if(alignment_index is not None):
self._chain_ids = list(alignment_index.keys())
else:
with open(mapping_path, "r") as f:
self._chain_ids = [l.strip() for l in f.readlines()]
self._chain_ids = list(os.listdir(alignment_dir))
if(filter_path is not None):
with open(filter_path, "r") as f:
chains_to_include = set([l.strip() for l in f.readlines()])
self._chain_ids = [
c for c in self._chain_ids if c in chains_to_include
]
if self.chain_data_cache is not None:
# Filter to include only chains where we have structure data
# (entries in chain_data_cache)
original_chain_ids = self._chain_ids
self._chain_ids = [
c for c in self._chain_ids if c in self.chain_data_cache
]
if len(self._chain_ids) < len(original_chain_ids):
missing = [
c for c in original_chain_ids
if c not in self.chain_data_cache
]
max_to_print = 10
missing_examples = ", ".join(missing[:max_to_print])
if len(missing) > max_to_print:
missing_examples += ", ..."
logging.warning(
"Removing %d alignment entries (%s) with no corresponding "
"entries in chain_data_cache (%s).",
len(missing),
missing_examples,
chain_data_cache_path)
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
}
......@@ -125,7 +168,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
......@@ -144,7 +187,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
_alignment_index=_alignment_index
alignment_index=alignment_index
)
return data
......@@ -159,10 +202,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name)
_alignment_index = None
if(self._alignment_index is not None):
alignment_index = None
if(self.alignment_index is not None):
alignment_dir = self.alignment_dir
_alignment_index = self._alignment_index[name]
alignment_index = self.alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1)
......@@ -173,30 +216,51 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
chain_id = None
path = os.path.join(self.data_dir, file_id)
if(os.path.exists(path + ".cif")):
structure_index_entry = None
if(self._structure_index is not None):
structure_index_entry = self._structure_index[name]
assert(len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1]
else:
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
ext = e
break
if(ext is None):
raise ValueError("Invalid file type")
path += ext
if(ext == ".cif"):
data = self._parse_mmcif(
path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
path, file_id, chain_id, alignment_dir, alignment_index,
)
elif(os.path.exists(path + ".core")):
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path + ".core", alignment_dir, _alignment_index,
path, alignment_dir, alignment_index,
)
elif(os.path.exists(path + ".pdb")):
elif(ext == ".pdb"):
structure_index = None
if(self._structure_index is not None):
structure_index = self._structure_index[name]
data = self.data_pipeline.process_pdb(
pdb_path=path + ".pdb",
pdb_path=path,
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
_alignment_index=_alignment_index,
alignment_index=alignment_index,
_structure_index=structure_index,
)
else:
raise ValueError("Invalid file type")
raise ValueError("Extension branch missing")
else:
path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=alignment_dir,
_alignment_index=_alignment_index,
alignment_index=alignment_index,
)
if(self._output_raw):
......@@ -206,6 +270,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data, self.mode
)
feats["batch_idx"] = torch.tensor(
[idx for _ in range(feats["aatype"].shape[-1])],
dtype=torch.int64,
device=feats["aatype"].device)
return feats
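# --- Illustrative sketch (not part of this commit): the batch_idx feature
# added above tags every residue of an example with that example's dataset
# index, so examples can be told apart after collation. With 4 residues and
# idx = 7 (hypothetical values):
import torch

aatype = torch.zeros(4, dtype=torch.int64)  # stand-in for feats["aatype"]
batch_idx = torch.tensor(
    [7 for _ in range(aatype.shape[-1])], dtype=torch.int64
)
assert batch_idx.tolist() == [7, 7, 7, 7]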
def __len__(self):
......@@ -265,9 +334,8 @@ class OpenFoldDataset(torch.utils.data.Dataset):
"""
def __init__(self,
datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[int],
probabilities: Sequence[float],
epoch_len: int,
chain_data_cache_paths: List[str],
generator: torch.Generator = None,
_roll_at_init: bool = True,
):
......@@ -275,11 +343,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self.probabilities = probabilities
self.epoch_len = epoch_len
self.generator = generator
self.chain_data_caches = []
for path in chain_data_cache_paths:
with open(path, "r") as fp:
self.chain_data_caches.append(json.load(fp))
def looped_shuffled_dataset_idx(dataset_len):
while True:
......@@ -298,7 +361,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset))
chain_data_cache = self.chain_data_caches[dataset_idx]
chain_data_cache = dataset.chain_data_cache
while True:
weights = []
idx = []
......@@ -355,20 +418,9 @@ class OpenFoldDataset(torch.utils.data.Dataset):
class OpenFoldBatchCollator:
def __init__(self, config, stage="train"):
self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def __call__(self, raw_prots):
processed_prots = []
for prot in raw_prots:
features = self.feature_pipeline.process_features(
prot, self.stage
)
processed_prots.append(features)
def __call__(self, prots):
stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, processed_prots)
return dict_multimap(stack_fn, prots)
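# --- Illustrative sketch (not part of this commit): with feature processing
# moved out of the collator, __call__ reduces to key-wise stacking. Assuming
# each example is a dict of equally-shaped tensors, the dict_multimap call
# behaves like this comprehension:
import torch

prots = [{"aatype": torch.zeros(5)}, {"aatype": torch.ones(5)}]
batch = {k: torch.stack([p[k] for p in prots], dim=0) for k in prots[0]}
assert batch["aatype"].shape == (2, 5)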
class OpenFoldDataLoader(torch.utils.data.DataLoader):
......@@ -388,11 +440,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
stage_cfg = self.config[self.stage]
max_iters = self.config.common.max_recycling_iters
if(stage_cfg.supervised):
clamp_prob = self.config.supervised.clamp_prob
keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob])
)
if(stage_cfg.uniform_recycling):
recycling_probs = [
......@@ -480,13 +527,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None,
distillation_mapping_path: Optional[str] = None,
train_filter_path: Optional[str] = None,
distillation_filter_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
train_epoch_len: int = 50000,
_alignment_index_path: Optional[str] = None,
_distillation_structure_index_path: Optional[str] = None,
alignment_index_path: Optional[str] = None,
distillation_alignment_index_path: Optional[str] = None,
**kwargs
):
super(OpenFoldDataModule, self).__init__()
......@@ -507,8 +556,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.predict_data_dir = predict_data_dir
self.predict_alignment_dir = predict_alignment_dir
self.kalign_binary_path = kalign_binary_path
self.train_mapping_path = train_mapping_path
self.distillation_mapping_path = distillation_mapping_path
self.train_filter_path = train_filter_path
self.distillation_filter_path = distillation_filter_path
self.template_release_dates_cache_path = (
template_release_dates_cache_path
)
......@@ -539,10 +588,20 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
# An ad-hoc measure for our particular filesystem restrictions
self._alignment_index = None
if(_alignment_index_path is not None):
with open(_alignment_index_path, "r") as fp:
self._alignment_index = json.load(fp)
self._distillation_structure_index = None
if(_distillation_structure_index_path is not None):
with open(_distillation_structure_index_path, "r") as fp:
self._distillation_structure_index = json.load(fp)
self.alignment_index = None
if(alignment_index_path is not None):
with open(alignment_index_path, "r") as fp:
self.alignment_index = json.load(fp)
self.distillation_alignment_index = None
if(distillation_alignment_index_path is not None):
with open(distillation_alignment_index_path, "r") as fp:
self.distillation_alignment_index = json.load(fp)
def setup(self):
# Most of the arguments are the same for the three datasets
......@@ -560,27 +619,29 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.training_mode):
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
chain_data_cache_path=self.train_chain_data_cache_path,
alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path,
filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train",
_output_raw=True,
_alignment_index=self._alignment_index,
alignment_index=self.alignment_index,
)
distillation_dataset = None
if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
alignment_dir=self.distillation_alignment_dir,
mapping_path=self.distillation_mapping_path,
max_template_hits=self.train.max_template_hits,
filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True,
mode="train",
_output_raw=True,
alignment_index=self.distillation_alignment_index,
_structure_index=self._distillation_structure_index,
)
d_prob = self.config.train.distillation_prob
......@@ -588,23 +649,21 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(distillation_dataset is not None):
datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob
probabilities = [1 - d_prob, d_prob]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
self.distillation_chain_data_cache_path,
]
probabilities = [1. - d_prob, d_prob]
else:
datasets = [train_dataset]
probabilities = [1.]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
]
probabilities = [1.]
generator = None
if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed + 1)
self.train_dataset = OpenFoldDataset(
datasets=datasets,
probabilities=probabilities,
epoch_len=self.train_epoch_len,
chain_data_cache_paths=chain_data_cache_paths,
generator=generator,
_roll_at_init=False,
)
......@@ -612,10 +671,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir,
mapping_path=None,
filter_path=None,
max_template_hits=self.config.eval.max_template_hits,
mode="eval",
_output_raw=True,
)
else:
self.eval_dataset = None
......@@ -623,7 +681,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir,
mapping_path=None,
filter_path=None,
max_template_hits=self.config.predict.max_template_hits,
mode="predict",
)
......@@ -636,7 +694,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
dataset = None
if(stage == "train"):
dataset = self.train_dataset
# Filter the dataset, if necessary
dataset.reroll()
elif(stage == "eval"):
......@@ -646,7 +703,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
else:
raise ValueError("Invalid stage")
batch_collator = OpenFoldBatchCollator(self.config, stage)
batch_collator = OpenFoldBatchCollator()
dl = OpenFoldDataLoader(
dataset,
......
......@@ -14,28 +14,19 @@
# limitations under the License.
import os
import copy
import collections
import contextlib
import dataclasses
import datetime
import json
from multiprocessing import cpu_count
import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import numpy as np
from openfold.data import (
templates,
parsers,
mmcif_parsing,
msa_identifiers,
msa_pairing,
feature_processing_multimer,
)
from openfold.data.parsers import Msa
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein
......@@ -46,7 +37,7 @@ TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
def empty_template_feats(n_res) -> FeatureDict:
return {
"template_aatype": np.zeros((0, n_res)).astype(np.int64),
"template_all_atom_positions":
"template_all_atom_positions":
np.zeros((0, n_res, 37, 3)).astype(np.float32),
"template_sum_probs": np.zeros((0, 1)).astype(np.float32),
"template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32),
......@@ -78,6 +69,51 @@ def make_template_features(
return template_features
def unify_template_features(
template_feature_list: Sequence[FeatureDict]
) -> FeatureDict:
out_dicts = []
seq_lens = [fd["template_aatype"].shape[1] for fd in template_feature_list]
for i, fd in enumerate(template_feature_list):
out_dict = {}
n_templates, n_res = fd["template_aatype"].shape[:2]
for k,v in fd.items():
seq_keys = [
"template_aatype",
"template_all_atom_positions",
"template_all_atom_mask",
]
if(k in seq_keys):
new_shape = list(v.shape)
assert(new_shape[1] == n_res)
new_shape[1] = sum(seq_lens)
new_array = np.zeros(new_shape, dtype=v.dtype)
if(k == "template_aatype"):
new_array[..., residue_constants.HHBLITS_AA_TO_ID['-']] = 1
offset = sum(seq_lens[:i])
new_array[:, offset:offset + seq_lens[i]] = v
out_dict[k] = new_array
else:
out_dict[k] = v
chain_indices = np.array(n_templates * [i])
out_dict["template_chain_index"] = chain_indices
if(n_templates != 0):
out_dicts.append(out_dict)
if(len(out_dicts) > 0):
out_dict = {
k: np.concatenate([od[k] for od in out_dicts]) for k in out_dicts[0]
}
else:
out_dict = empty_template_feats(sum(seq_lens))
return out_dict
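# --- Illustrative sketch (not part of this commit): unify_template_features
# writes chain i's per-residue features into its slice of the concatenated
# sequence, leaving gap values elsewhere. For two chains of lengths 3 and 2:
import numpy as np

seq_lens = [3, 2]
i = 1                                   # the second chain
offset = sum(seq_lens[:i])              # columns taken by earlier chains
row = np.zeros(sum(seq_lens))
row[offset:offset + seq_lens[i]] = 1.0  # this chain's residues
assert row.tolist() == [0.0, 0.0, 0.0, 1.0, 1.0]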
def make_sequence_features(
sequence: str, description: str, num_res: int
) -> FeatureDict:
......@@ -138,13 +174,13 @@ def make_mmcif_features(
def _aatype_to_str_sequence(aatype):
return ''.join([
residue_constants.restypes_with_x[aatype[i]]
residue_constants.restypes_with_x[aatype[i]]
for i in range(len(aatype))
])
def make_protein_features(
protein_object: protein.Protein,
protein_object: protein.Protein,
description: str,
_is_distillation: bool = False,
) -> FeatureDict:
......@@ -243,12 +279,47 @@ def run_msa_tool(
result = msa_runner.query(fasta_path, max_sto_sequences)[0]
else:
result = msa_runner.query(fasta_path)[0]
with open(msa_out_path, "w") as f:
f.write(result[msa_format])
return result
def make_sequence_features_with_custom_template(
sequence: str,
mmcif_path: str,
pdb_id: str,
chain_id: str,
kalign_binary_path: str) -> FeatureDict:
"""
Process a single FASTA file using features derived from a single template rather than from an alignment.
"""
num_res = len(sequence)
sequence_features = make_sequence_features(
sequence=sequence,
description=pdb_id,
num_res=num_res,
)
msa_data = [sequence]
deletion_matrix = [[0 for _ in sequence]]
msa_data_obj = parsers.Msa(sequences=msa_data, deletion_matrix=deletion_matrix, descriptions=None)
msa_features = make_msa_features([msa_data_obj])
template_features = get_custom_template_features(
mmcif_path=mmcif_path,
query_sequence=sequence,
pdb_id=pdb_id,
chain_id=chain_id,
kalign_binary_path=kalign_binary_path
)
return {
**sequence_features,
**msa_features,
**template_features.features
}
class AlignmentRunner:
"""Runs alignment tools and saves the results"""
......@@ -282,13 +353,13 @@ class AlignmentRunner:
must also be provided
bfd_database_path:
Path to BFD database. Depending on the value of use_small_bfd,
one of hhblits_binary_path or jackhmmer_binary_path must be
one of hhblits_binary_path or jackhmmer_binary_path must be
provided.
uniclust30_database_path:
Path to uniclust30. Searched alongside BFD if use_small_bfd is
Path to uniclust30. Searched alongside BFD if use_small_bfd is
false.
use_small_bfd:
Whether to search the BFD database alone with jackhmmer or
Whether to search the BFD database alone with jackhmmer or
in conjunction with uniclust30 with hhblits.
no_cpus:
The number of CPUs available for alignment. By default, all
......@@ -332,7 +403,7 @@ class AlignmentRunner:
no_cpus = cpu_count()
self.jackhmmer_uniref90_runner = None
if(jackhmmer_binary_path is not None and
if(jackhmmer_binary_path is not None and
uniref90_database_path is not None
):
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
......@@ -340,7 +411,7 @@ class AlignmentRunner:
database_path=uniref90_database_path,
n_cpu=no_cpus,
)
self.jackhmmer_small_bfd_runner = None
self.hhblits_bfd_uniclust_runner = None
if(bfd_database_path is not None):
......@@ -375,7 +446,7 @@ class AlignmentRunner:
database_path=uniprot_database_path
)
if(template_searcher is not None and
if(template_searcher is not None and
self.jackhmmer_uniref90_runner is None
):
raise ValueError(
......@@ -383,7 +454,7 @@ class AlignmentRunner:
)
self.template_searcher = template_searcher
def run(
self,
fasta_path: str,
......@@ -457,9 +528,9 @@ class AlignmentRunner:
if(self.jackhmmer_uniprot_runner is not None):
uniprot_out_path = os.path.join(output_dir, 'uniprot_hits.sto')
result = run_msa_tool(
self.jackhmmer_uniprot_runner,
fasta_path=fasta_path,
msa_out_path=uniprot_out_path,
self.jackhmmer_uniprot_runner,
fasta_path=fasta_path,
msa_out_path=uniprot_out_path,
msa_format='sto',
max_sto_sequences=self.uniprot_max_hits,
)
......@@ -529,10 +600,10 @@ def convert_monomer_features(
def int_id_to_str_id(num: int) -> str:
"""Encodes a number as a string, using reverse spreadsheet style naming.
Args:
num: A positive integer.
Returns:
A string that encodes the positive integer using reverse spreadsheet style,
naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
......@@ -540,7 +611,7 @@ def int_id_to_str_id(num: int) -> str:
"""
if num <= 0:
raise ValueError(f'Only positive integers allowed, got {num}.')
num = num - 1 # 1-based indexing.
output = []
while num >= 0:
......@@ -553,11 +624,11 @@ def add_assembly_features(
all_chain_features: MutableMapping[str, FeatureDict],
) -> MutableMapping[str, FeatureDict]:
"""Add features to distinguish between chains.
Args:
all_chain_features: A dictionary which maps chain_id to a dictionary of
features for each chain.
Returns:
all_chain_features: A dictionary which maps strings of the form
`<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
......@@ -572,7 +643,7 @@ def add_assembly_features(
if seq not in seq_to_entity_id:
seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
grouped_chains[seq_to_entity_id[seq]].append(chain_features)
new_all_chain_features = {}
chain_id = 1
for entity_id, group_chain_features in grouped_chains.items():
......@@ -590,7 +661,7 @@ def add_assembly_features(
entity_id * np.ones(seq_length)
).astype(np.int64)
chain_id += 1
return new_all_chain_features
......@@ -617,39 +688,37 @@ class DataPipeline:
def _parse_msa_data(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None,
alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]:
msas = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
msa_data = {}
if(alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size):
fp.seek(start)
msa = fp.read(size).decode("utf-8")
return msa
for (name, start, size) in _alignment_index["files"]:
for (name, start, size) in alignment_index["files"]:
filename, ext = os.path.splitext(name)
if(ext == ".a3m"):
msa, deletion_matrix = parsers.parse_a3m(
msa = parsers.parse_a3m(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
# The "hmm_output" exception is a crude way to exclude
# multimer template hits.
elif(ext == ".sto" and not "hmm_output" == filename):
msa, deletion_matrix, _ = parsers.parse_stockholm(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
msa = parsers.parse_stockholm(read_msa(start, size))
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
else:
continue
msa_data[name] = data
fp.close()
else:
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
filename, ext = os.path.splitext(f)
......@@ -657,33 +726,35 @@ class DataPipeline:
if(ext == ".a3m"):
with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
elif(ext == ".sto" and not "hmm_output" == filename):
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
data = {"msa": msa, "deletion_matrix": msa.deletion_matrix}
else:
continue
msas[f] = msa
return msas
msa_data[f] = data
return msa_data
def _parse_template_hit_files(
self,
alignment_dir: str,
input_sequence: str,
_alignment_index: Optional[Any] = None
alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
all_hits = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')
if(alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in _alignment_index["files"]:
for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".hhr"):
......@@ -716,15 +787,46 @@ class DataPipeline:
return all_hits
def _process_msa_feats(
self,
def _parse_template_hits(
self,
alignment_dir: str,
alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
all_hits = {}
if (alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if (ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if (ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
return all_hits
def _get_msas(self,
alignment_dir: str,
input_sequence: Optional[str] = None,
_alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msas = self._parse_msa_data(alignment_dir, _alignment_index)
if(len(msas) == 0):
alignment_index: Optional[str] = None,
):
msa_data = self._parse_msa_data(alignment_dir, alignment_index)
if(len(msa_data) == 0):
if(input_sequence is None):
raise ValueError(
"""
......@@ -732,13 +834,31 @@ class DataPipeline:
must be provided.
"""
)
msa_data["dummy"] = Msa(
[input_sequence],
[[0 for _ in input_sequence]],
["dummy"]
)
msa_features = make_msa_features(list(msas.values()))
deletion_matrix = [[0 for _ in input_sequence]]
msa_data["dummy"] = {
"msa": parsers.Msa(sequences=input_sequence, deletion_matrix=deletion_matrix, descriptions=None),
"deletion_matrix": deletion_matrix,
}
msas, deletion_matrices = zip(*[
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
])
return msas, deletion_matrices
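# --- Illustrative sketch (not part of this commit): the zip(*...) above
# splits the per-file records into two parallel tuples (filenames are
# hypothetical):
msa_data = {
    "uniref.a3m": {"msa": "MSA_A", "deletion_matrix": "DM_A"},
    "bfd.sto": {"msa": "MSA_B", "deletion_matrix": "DM_B"},
}
msas, deletion_matrices = zip(
    *[(v["msa"], v["deletion_matrix"]) for v in msa_data.values()]
)
assert msas == ("MSA_A", "MSA_B")
assert deletion_matrices == ("DM_A", "DM_B")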
def _process_msa_feats(
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msas, deletion_matrices = self._get_msas(
alignment_dir, input_sequence, alignment_index
)
msa_features = make_msa_features(
msas=msas
)
return msa_features
......@@ -746,9 +866,9 @@ class DataPipeline:
self,
fasta_path: str,
alignment_dir: str,
_alignment_index: Optional[str] = None,
alignment_index: Optional[str] = None,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
......@@ -761,11 +881,11 @@ class DataPipeline:
num_res = len(input_sequence)
hits = self._parse_template_hit_files(
alignment_dir,
alignment_dir,
input_sequence,
_alignment_index,
alignment_index,
)
template_features = make_template_features(
input_sequence,
hits,
......@@ -778,11 +898,11 @@ class DataPipeline:
num_res=num_res,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {
**sequence_features,
**msa_features,
**msa_features,
**template_features
}
......@@ -791,7 +911,7 @@ class DataPipeline:
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None,
alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.
......@@ -812,15 +932,16 @@ class DataPipeline:
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index)
alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
query_release_date=to_date(mmcif.header["release_date"])
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**mmcif_feats, **template_features, **msa_features}
......@@ -831,7 +952,7 @@ class DataPipeline:
is_distillation: bool = True,
chain_id: Optional[str] = None,
_structure_index: Optional[str] = None,
_alignment_index: Optional[str] = None,
alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.
......@@ -850,26 +971,27 @@ class DataPipeline:
pdb_str = f.read()
protein_object = protein.from_pdb_string(pdb_str, chain_id)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
pdb_feats = make_pdb_features(
protein_object,
description,
protein_object,
description,
is_distillation=is_distillation
)
hits = self._parse_template_hits(
alignment_dir,
alignment_dir,
input_sequence,
_alignment_index
alignment_index
)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**pdb_feats, **template_features, **msa_features}
......@@ -877,7 +999,7 @@ class DataPipeline:
self,
core_path: str,
alignment_dir: str,
_alignment_index: Optional[str] = None,
alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a ProteinNet .core file.
......@@ -886,15 +1008,15 @@ class DataPipeline:
core_str = f.read()
protein_object = protein.from_proteinnet_string(core_str)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index
alignment_dir,
alignment_index
)
template_features = make_template_features(
input_sequence,
hits,
......@@ -905,15 +1027,106 @@ class DataPipeline:
return {**core_feats, **template_features, **msa_features}
def process_multiseq_fasta(self,
fasta_path: str,
super_alignment_dir: str,
ri_gap: int = 200,
) -> FeatureDict:
"""
Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
hack from Twitter (a.k.a. AlphaFold-Gap).
"""
with open(fasta_path, 'r') as f:
fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
# No whitespace allowed
input_descs = [i.split()[0] for i in input_descs]
# Stitch all of the sequences together
input_sequence = ''.join(input_seqs)
input_description = '-'.join(input_descs)
num_res = len(input_sequence)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res,
)
seq_lens = [len(s) for s in input_seqs]
total_offset = 0
for sl in seq_lens:
total_offset += sl
sequence_features["residue_index"][total_offset:] += ri_gap
msa_list = []
deletion_mat_list = []
for seq, desc in zip(input_seqs, input_descs):
alignment_dir = os.path.join(
super_alignment_dir, desc
)
msas, deletion_mats = self._get_msas(
alignment_dir, seq, None
)
msa_list.append(msas)
deletion_mat_list.append(deletion_mats)
final_msa = []
final_deletion_mat = []
final_msa_obj = []
msa_it = enumerate(zip(msa_list, deletion_mat_list))
for i, (msas, deletion_mats) in msa_it:
prec, post = sum(seq_lens[:i]), sum(seq_lens[i + 1:])
msas = [
[prec * '-' + seq + post * '-' for seq in msa] for msa in msas
]
deletion_mats = [
[prec * [0] + dml + post * [0] for dml in deletion_mat]
for deletion_mat in deletion_mats
]
assert (len(msas[0][-1]) == len(input_sequence))
final_msa.extend(msas)
final_deletion_mat.extend(deletion_mats)
final_msa_obj.extend([parsers.Msa(sequences=msas[k], deletion_matrix=deletion_mats[k], descriptions=None)
for k in range(len(msas))])
msa_features = make_msa_features(
msas=final_msa_obj
)
template_feature_list = []
for seq, desc in zip(input_seqs, input_descs):
alignment_dir = os.path.join(
super_alignment_dir, desc
)
hits = self._parse_template_hits(alignment_dir, alignment_index=None)
template_features = make_template_features(
seq,
hits,
self.template_featurizer,
)
template_feature_list.append(template_features)
template_features = unify_template_features(template_feature_list)
return {
**sequence_features,
**msa_features,
**template_features,
}
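# --- Illustrative sketch (not part of this commit): the AlphaFold-Gap trick
# above offsets residue_index by ri_gap between chains so the stitched
# sequence is treated as separate chains. For chains of lengths 3 and 2:
import numpy as np

seq_lens, ri_gap = [3, 2], 200
residue_index = np.arange(sum(seq_lens))
total_offset = 0
for sl in seq_lens:
    total_offset += sl
    residue_index[total_offset:] += ri_gap
assert residue_index.tolist() == [0, 1, 2, 203, 204]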
class DataPipelineMultimer:
"""Runs the alignment tools and assembles the input features."""
def __init__(self,
monomer_data_pipeline: DataPipeline,
):
monomer_data_pipeline: DataPipeline,
):
"""Initializes the data pipeline.
Args:
monomer_data_pipeline: An instance of pipeline.DataPipeline that runs
the data pipeline for the monomer AlphaFold system.
......@@ -926,39 +1139,41 @@ class DataPipelineMultimer:
self._monomer_data_pipeline = monomer_data_pipeline
def _process_single_chain(
self,
chain_id: str,
sequence: str,
description: str,
chain_alignment_dir: str,
is_homomer_or_monomer: bool
self,
chain_id: str,
sequence: str,
description: str,
chain_alignment_dir: str,
is_homomer_or_monomer: bool
) -> FeatureDict:
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str = f'>{chain_id}\n{sequence}\n'
if not os.path.exists(chain_alignment_dir):
raise ValueError(f"Alignments for {chain_id} not found...")
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
chain_features = self._monomer_data_pipeline.process_fasta(
fasta_path=chain_fasta_path,
alignment_dir=chain_alignment_dir
)
# We only construct the pairing features if there are 2 or more unique
# sequences.
if not is_homomer_or_monomer:
all_seq_msa_features = self._all_seq_msa_features(
chain_fasta_path,
chain_alignment_dir
chain_features = self._monomer_data_pipeline.process_fasta(
fasta_path=chain_fasta_path,
alignment_dir=chain_alignment_dir
)
chain_features.update(all_seq_msa_features)
# We only construct the pairing features if there are 2 or more unique
# sequences.
if not is_homomer_or_monomer:
all_seq_msa_features = self._all_seq_msa_features(
chain_fasta_path,
chain_alignment_dir
)
chain_features.update(all_seq_msa_features)
return chain_features
def _all_seq_msa_features(self, fasta_path, alignment_dir):
"""Get MSA features for unclustered uniprot, for pairing."""
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
# TODO: Quick fix; change back to .sto once Stockholm parsing is fixed
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.a3m")
with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read()
msa = parsers.parse_stockholm(uniprot_msa_string)
msa = parsers.parse_a3m(uniprot_msa_string)
#msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers',
......@@ -968,17 +1183,17 @@ class DataPipelineMultimer:
if k in valid_feats
}
return feats
def process_fasta(self,
fasta_path: str,
alignment_dir: str,
) -> FeatureDict:
fasta_path: str,
alignment_dir: str,
) -> FeatureDict:
"""Creates features."""
with open(fasta_path) as f:
input_fasta_str = f.read()
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1
......@@ -988,7 +1203,7 @@ class DataPipelineMultimer:
sequence_features[seq]
)
continue
chain_features = self._process_single_chain(
chain_id=desc,
sequence=seq,
......@@ -996,21 +1211,21 @@ class DataPipelineMultimer:
chain_alignment_dir=os.path.join(alignment_dir, desc),
is_homomer_or_monomer=is_homomer_or_monomer
)
chain_features = convert_monomer_features(
chain_features,
chain_id=desc
)
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
return np_example
\ No newline at end of file
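# --- Illustrative sketch (not part of this commit): pad_msa itself is not
# shown in this hunk. A minimal stand-in, assuming it zero-pads the MSA
# feature up to min_num_seq so extra_msa slices are never empty:
import numpy as np

def pad_msa_sketch(np_example, min_num_seq):
    np_example = dict(np_example)
    num_seq = np_example["msa"].shape[0]
    if num_seq < min_num_seq:
        pad = min_num_seq - num_seq
        np_example["msa"] = np.pad(np_example["msa"], ((0, pad), (0, 0)))
    return np_example

padded = pad_msa_sketch({"msa": np.zeros((10, 7), dtype=np.int32)}, 512)
assert padded["msa"].shape == (512, 7)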
......@@ -23,6 +23,9 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
......@@ -669,7 +672,7 @@ def make_atom14_masks(protein):
def make_atom14_masks_np(batch):
batch = tree_map(
lambda n: torch.tensor(n, device="cpu"),
lambda n: torch.tensor(n, device="cpu"),
batch,
np.ndarray
)
......@@ -736,6 +739,7 @@ def make_atom14_positions(protein):
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.0
all_matrices[resname] = renaming_matrix
renaming_matrices = torch.stack(
[all_matrices[restype] for restype in restype_3]
)
......@@ -781,10 +785,14 @@ def make_atom14_positions(protein):
def atom37_to_frames(protein, eps=1e-8):
is_multimer = "asym_id" in protein
aatype = protein["aatype"]
all_atom_positions = protein["all_atom_positions"]
all_atom_mask = protein["all_atom_mask"]
if is_multimer:
all_atom_positions = Vec3Array.from_array(all_atom_positions)
batch_dims = len(aatype.shape[:-1])
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
......@@ -831,19 +839,37 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims=batch_dims,
)
base_atom_pos = batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
dim=-2,
no_batch_dims=len(all_atom_positions.shape[:-2]),
)
if is_multimer:
base_atom_pos = [batched_gather(
pos,
residx_rigidgroup_base_atom37_idx,
dim=-1,
no_batch_dims=len(all_atom_positions.shape[:-1]),
) for pos in all_atom_positions]
base_atom_pos = Vec3Array.from_array(torch.stack(base_atom_pos, dim=-1))
else:
base_atom_pos = batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
dim=-2,
no_batch_dims=len(all_atom_positions.shape[:-2]),
)
gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :],
eps=eps,
)
if is_multimer:
point_on_neg_x_axis = base_atom_pos[:, :, 0]
origin = base_atom_pos[:, :, 1]
point_on_xy_plane = base_atom_pos[:, :, 2]
gt_rotation = Rot3Array.from_two_vectors(
origin - point_on_neg_x_axis, point_on_xy_plane - origin)
gt_frames = Rigid3Array(gt_rotation, origin)
else:
gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :],
eps=eps,
)
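# --- Illustrative sketch (not part of this commit): Rot3Array.from_two_vectors
# is assumed here to build a right-handed frame by Gram-Schmidt, as in this
# numpy stand-in:
import numpy as np

def frame_from_two_vectors(a, b):
    e0 = a / np.linalg.norm(a)   # x-axis along the first vector
    c = b - np.dot(b, e0) * e0   # drop b's component along e0
    e1 = c / np.linalg.norm(c)   # y-axis in the plane of a and b
    e2 = np.cross(e0, e1)        # z-axis completes the frame
    return np.stack([e0, e1, e2], axis=-1)

R = frame_from_two_vectors(np.array([1.0, 0.0, 0.0]), np.array([1.0, 1.0, 0.0]))
assert np.allclose(R, np.eye(3))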
group_exists = batched_gather(
restype_rigidgroup_mask,
......@@ -864,9 +890,13 @@ def atom37_to_frames(protein, eps=1e-8):
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(Rigid(rots, None))
if is_multimer:
gt_frames = gt_frames.compose_rotation(
Rot3Array.from_array(rots))
else:
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(Rigid(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8
......@@ -900,12 +930,18 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims=batch_dims,
)
residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot
)
alt_gt_frames = gt_frames.compose(
Rigid(residx_rigidgroup_ambiguity_rot, None)
)
if is_multimer:
ambiguity_rot = Rot3Array.from_array(residx_rigidgroup_ambiguity_rot)
# Create the alternative ground truth frames.
alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
else:
residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot
)
alt_gt_frames = gt_frames.compose(
Rigid(residx_rigidgroup_ambiguity_rot, None)
)
gt_frames_tensor = gt_frames.to_tensor_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
......
......@@ -103,6 +103,21 @@ def np_example_to_features(
cfg[mode],
)
if mode == "train":
p = torch.rand(1).item()
use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
features["use_clamped_fape"] = torch.full(
size=[cfg.common.max_recycling_iters + 1],
fill_value=use_clamped_fape_value,
dtype=torch.float32,
)
else:
features["use_clamped_fape"] = torch.full(
size=[cfg.common.max_recycling_iters + 1],
fill_value=0.0,
dtype=torch.float32,
)
return {k: v for k, v in features.items()}
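# --- Illustrative sketch (not part of this commit): with probability
# clamp_prob a training example uses clamped FAPE for all recycling
# iterations; eval/predict examples never do. Hypothetical values:
import torch

clamp_prob, max_recycling_iters = 0.9, 3
p = torch.rand(1).item()
use_clamped_fape = torch.full(
    size=[max_recycling_iters + 1],
    fill_value=float(p < clamp_prob),
    dtype=torch.float32,
)
assert use_clamped_fape.shape[0] == 4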
......
......@@ -84,7 +84,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa
max_extra_msa = mode_cfg.max_extra_msa
msa_seed = None
if(not common_cfg.resample_msa_in_recycling):
......@@ -137,7 +137,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
data_transforms.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
mode_cfg.max_extra_msa,
mode_cfg.crop_size,
mode_cfg.max_templates,
)
......
......@@ -46,7 +46,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa
max_extra_msa = mode_cfg.max_extra_msa
msa_seed = None
if(not common_cfg.resample_msa_in_recycling):
......
......@@ -434,7 +434,7 @@ def _is_set(data: str) -> bool:
def get_atom_coords(
mmcif_object: MmcifObject,
chain_id: str,
_zero_center_positions: bool = True
_zero_center_positions: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain
chains = list(mmcif_object.structure.get_chains())
......
......@@ -89,6 +89,8 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append("")
continue
elif line.startswith("#"):
continue
elif not line:
continue # Skip blank lines.
sequences[index] += line
......
......@@ -128,6 +128,22 @@ def _is_after_cutoff(
return False
def _replace_obsolete_references(obsolete_mapping) -> Mapping[str, str]:
"""Generates a new obsolete by tracing all cross-references and store the latest leaf to all referencing nodes"""
obsolete_new = {}
obsolete_keys = obsolete_mapping.keys()
def _new_target(k):
v = obsolete_mapping[k]
if v in obsolete_keys:
return _new_target(v)
return v
for k in obsolete_keys:
obsolete_new[k] = _new_target(k)
return obsolete_new
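# --- Illustrative sketch (not part of this commit): chains of obsolete
# cross-references collapse to their final replacement (IDs hypothetical):
obsolete = {"1abc": "2def", "2def": "3ghi"}

def new_target(k):
    v = obsolete[k]
    return new_target(v) if v in obsolete else v

assert {k: new_target(k) for k in obsolete} == {"1abc": "3ghi", "2def": "3ghi"}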
def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
"""Parses the data file from PDB that lists which PDB ids are obsolete."""
with open(obsolete_file_path) as f:
......@@ -141,7 +157,7 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
from_id = line[20:24].lower()
to_id = line[29:33].lower()
result[from_id] = to_id
return result
return _replace_obsolete_references(result)
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
......@@ -495,7 +511,7 @@ def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str,
max_ca_ca_distance: float,
_zero_center_positions: bool = True,
_zero_center_positions: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords(
......@@ -912,6 +928,56 @@ def _process_single_hit(
return SingleHitResult(features=None, error=error, warning=None)
def get_custom_template_features(
mmcif_path: str,
query_sequence: str,
pdb_id: str,
chain_id: str,
kalign_binary_path: str):
with open(mmcif_path, "r") as mmcif_path:
cif_string = mmcif_path.read()
mmcif_parse_result = mmcif_parsing.parse(
file_id=pdb_id, mmcif_string=cif_string
)
template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]
mapping = {x: x for x, _ in enumerate(query_sequence)}
features, warnings = _extract_template_features(
mmcif_object=mmcif_parse_result.mmcif_object,
pdb_id=pdb_id,
mapping=mapping,
template_sequence=template_sequence,
query_sequence=query_sequence,
template_chain_id=chain_id,
kalign_binary_path=kalign_binary_path,
_zero_center_positions=True
)
features["template_sum_probs"] = [1.0]
# TODO: clean up this logic
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
for k in template_features:
template_features[k].append(features[k])
for name in template_features:
template_features[name] = np.stack(
template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
return TemplateSearchResult(
features=template_features, errors=None, warnings=warnings
)
@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
features: Mapping[str, Any]
......@@ -1041,6 +1107,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
filtered = list(
sorted(filtered, key=lambda x: x.sum_probs, reverse=True)
)
idx = list(range(len(filtered)))
if(self._shuffle_top_k_prefiltered):
stk = self._shuffle_top_k_prefiltered
......
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
......@@ -17,7 +17,7 @@ from functools import partial
import torch
import torch.nn as nn
from typing import Tuple
from typing import Tuple, Optional
from openfold.utils import all_atom_multimer
from openfold.utils.feats import (
......@@ -32,7 +32,7 @@ from openfold.model.template import (
TemplatePointwiseAttention,
)
from openfold.utils import geometry
from openfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap
from openfold.utils.tensor_utils import add, one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module):
......@@ -95,11 +95,22 @@ class InputEmbedder(nn.Module):
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
)
reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
d = d[..., None] - reshaped_bins
d = torch.abs(d)
d = torch.argmin(d, dim=-1)
d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
d = d.to(ri.dtype)
return self.linear_relpos(d)
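# --- Illustrative sketch (not part of this commit): each pairwise offset is
# snapped to the nearest bin in [-relpos_k, relpos_k] before one-hot
# encoding, so distant residues share the boundary bin:
import torch

relpos_k = 2
ri = torch.tensor([0.0, 1.0, 5.0])
d = ri[..., None] - ri[..., None, :]                # pairwise offsets
boundaries = torch.arange(-relpos_k, relpos_k + 1)  # [-2, -1, 0, 1, 2]
bins = boundaries.view((1,) * d.dim() + (len(boundaries),))
idx = torch.argmin(torch.abs(d[..., None] - bins), dim=-1)
# The offset 5 - 0 = 5 lies outside the window and snaps to the last bin.
assert idx[2, 0].item() == 4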
def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
batch: Dict containing
......@@ -116,17 +127,20 @@ class InputEmbedder(nn.Module):
[*, N_res, N_res, C_z] pair embedding
"""
tf = batch["target_feat"]
ri = batch["residue_index"]
msa = batch["msa_feat"]
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
pair_emb = add(pair_emb,
tf_emb_i[..., None, :],
inplace=inplace_safe
)
pair_emb = add(pair_emb,
tf_emb_j[..., None, :, :],
inplace=inplace_safe
)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
......@@ -302,7 +316,6 @@ class RecyclingEmbedder(nn.Module):
Implements Algorithm 32.
"""
def __init__(
self,
c_m: int,
......@@ -344,6 +357,7 @@ class RecyclingEmbedder(nn.Module):
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
......@@ -359,6 +373,19 @@ class RecyclingEmbedder(nn.Module):
z:
[*, N_res, N_res, C_z] pair embedding update
"""
# [*, N, C_m]
m_update = self.layer_norm_m(m)
if(inplace_safe):
m.copy_(m_update)
m_update = m
# [*, N, N, C_z]
z_update = self.layer_norm_z(z)
if(inplace_safe):
z.copy_(z_update)
z_update = z
# This squared method might become problematic in FP16 mode.
bins = torch.linspace(
self.min_bin,
self.max_bin,
......@@ -367,13 +394,6 @@ class RecyclingEmbedder(nn.Module):
device=x.device,
requires_grad=False,
)
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = bins ** 2
upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
......@@ -387,7 +407,7 @@ class RecyclingEmbedder(nn.Module):
# [*, N, N, C_z]
d = self.linear(d)
z_update = d + self.layer_norm_z(z)
z_update = add(z_update, d, inplace_safe)
return m_update, z_update
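# --- Illustrative sketch (not part of this commit): comparing squared
# distances against squared bin edges avoids a sqrt. With AlphaFold-style
# bin settings (assumed here), each distance lands in exactly one bin:
import torch

min_bin, max_bin, no_bins, inf = 3.25, 20.75, 15, 1e8
squared_bins = torch.linspace(min_bin, max_bin, no_bins) ** 2
upper = torch.cat([squared_bins[1:], squared_bins.new_tensor([inf])], dim=-1)
d2 = torch.tensor(100.0)  # squared CB-CB distance, i.e. 10 Angstroms
one_hot = ((d2 > squared_bins) * (d2 < upper)).float()
assert one_hot.sum().item() == 1.0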
......@@ -485,7 +505,6 @@ class ExtraMSAEmbedder(nn.Module):
Implements Algorithm 2, line 15
"""
def __init__(
self,
c_in: int,
......@@ -544,30 +563,31 @@ class TemplateEmbedder(nn.Module):
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True
_mask_trans=True,
use_lma=False,
inplace_safe=False
):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
pair_embeds = []
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
if (inplace_safe):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair = z.new_zeros(
z.shape[:-3] +
(n_templ, n, n, self.config.template_pair_embedder.c_t)
)
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch,
)
single_template_embeds = {}
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
# [*, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
......@@ -577,38 +597,64 @@ class TemplateEmbedder(nn.Module):
).to(z.dtype)
t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t})
if (inplace_safe):
t_pair[..., i, :, :, :] = t
else:
pair_embeds.append(t)
template_embeds.append(single_template_embeds)
del t
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
if (not inplace_safe):
t_pair = torch.stack(pair_embeds, dim=templ_dim)
del pair_embeds
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
del t_pair
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size,
use_lma=use_lma,
)
t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
# Append singletons
t_mask = t_mask.reshape(
*t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
)
t = t * (torch.sum(batch["template_mask"]) > 0)
if (inplace_safe):
t *= t_mask
else:
t = t * t_mask
ret = {}
if self.config.embed_angles:
ret["template_single_embedding"] = template_embeds["angle"]
ret.update({"template_pair_embedding": t})
del t
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
batch
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
ret["template_single_embedding"] = a
return ret
......@@ -751,6 +797,8 @@ class TemplateEmbedderMultimer(nn.Module):
templ_dim,
chunk_size,
multichain_mask_2d,
use_lma=False,
inplace_safe=False
):
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
......