Unverified Commit bb3f51e5 authored by Christina Floristean's avatar Christina Floristean Committed by GitHub

Merge pull request #405 from aqlaboratory/multimer

Full multimer merge
parents ce211367 c33a0bd6
.vscode/
.idea/
__pycache__/
*.egg-info
build
......@@ -8,3 +9,4 @@ dist
data
openfold/resources/
tests/test_data/
cutlass/
......@@ -7,13 +7,31 @@ _Figure: Comparison of OpenFold and AlphaFold2 predictions to the experimental s
A faithful but trainable PyTorch reproduction of DeepMind's
[AlphaFold 2](https://github.com/deepmind/alphafold).
## Contents
- [OpenFold](#openfold)
- [Contents](#contents)
- [Features](#features)
- [Installation (Linux)](#installation-linux)
- [Download Alignment Databases](#download-alignment-databases)
- [Inference](#inference)
- [Monomer inference](#monomer-inference)
- [Multimer Inference](#multimer-inference)
  - [SoloSeq Inference](#soloseq-inference)
- [Training](#training)
- [Testing](#testing)
- [Building and Using the Docker Container](#building-and-using-the-docker-container)
- [Copyright Notice](#copyright-notice)
- [Contributing](#contributing)
- [Citing this Work](#citing-this-work)
## Features
OpenFold carefully reproduces (almost) all of the features of the original open
source monomer (v2.0.1) and multimer (v2.3.2) inference code. The sole exception is
model ensembling, which fared poorly in DeepMind's own ablation testing and is being
phased out in future DeepMind experiments. It is omitted here for the sake of reducing
clutter. In cases where the *Nature* paper differs from the source, we always defer to the
latter.
OpenFold is trainable in full precision, half precision, or `bfloat16` with or without DeepSpeed,
......@@ -63,7 +81,7 @@ To install:
For some systems, it may help to append the Conda environment library path to `$LD_LIBRARY_PATH`. The `install_third_party_dependencies.sh` script does this once, but you may need this for each bash instance.
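For example, for the Conda environment created by the install script, the export might look like the following (the `lib/conda/envs/openfold_venv` location matches the binary paths used elsewhere in this README; adjust it if your environment lives somewhere else):

```shell
# Append the environment's lib directory to the dynamic linker search path.
# The environment location is an assumption based on the paths used in the
# commands in this README; change it if yours differs.
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:lib/conda/envs/openfold_venv/lib"
```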
## Download Alignment Databases
If you intend to generate your own alignments, e.g. for inference, you have two
choices for downloading protein databases, depending on whether you want to use
......@@ -112,7 +130,16 @@ DeepMind's pretrained parameters, you will only be able to make changes that
do not affect the shapes of model parameters. For an example of initializing
the model, consult `run_pretrained_openfold.py`.
## Inference
OpenFold now supports three inference modes:
- [Monomer Inference](#monomer-inference): OpenFold reproduction of AlphaFold2. Inference available with either DeepMind's pretrained parameters or OpenFold trained parameters.
- [Multimer Inference](#multimer-inference): OpenFold reproduction of AlphaFold-Multimer. Inference available with DeepMind's pre-trained parameters.
- [Single Sequence Inference (SoloSeq)](#soloseq-inference): Language-model-based structure prediction using [ESM-1b](https://github.com/facebookresearch/esm) embeddings.
More instructions for each inference mode are provided below:
### Monomer inference
To run inference on a sequence or multiple sequences using a set of DeepMind's
pretrained parameters, first download the OpenFold weights e.g.:
......@@ -131,14 +158,14 @@ python3 run_pretrained_openfold.py \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_ptm" \
--model_device "cuda:0" \
--output_dir ./ \
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt
```
......@@ -176,13 +203,6 @@ To enable it, add `--trace_model` to the inference command.
To get a speedup during inference, enable [FlashAttention](https://github.com/HazyResearch/flash-attention)
in the config. Note that it appears to work best for sequences with < 1000 residues.
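When scripting over many targets, the length heuristic above can be applied per sequence. A minimal sketch of such a gate (this is an assumption about how you might wire the decision into your own driver code, not part of OpenFold's CLI):

```python
def flash_attention_recommended(sequence: str, max_len: int = 1000) -> bool:
    """Per the note above, FlashAttention tends to help most for
    sequences with fewer than ~1000 residues."""
    return len(sequence) < max_len

# Example: a short monomer qualifies, a very long sequence does not.
for seq in ("M" * 300, "M" * 2500):
    print(len(seq), flash_attention_recommended(seq))
```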
To minimize memory usage during inference on long sequences, consider the
following changes:
......@@ -221,7 +241,78 @@ efficent AlphaFold-Multimer more than double the time. Use the
at once. The `run_pretrained_openfold.py` script can enable this config option with the
`--long_sequence_inference` command line option.
Input FASTA files containing multiple sequences are treated as complexes. In
this case, the inference script runs AlphaFold-Gap, a hack proposed
[here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using
the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer).
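AlphaFold-Gap works by concatenating the chains into a single sequence while inserting a large jump in the residue index at each chain break, so the model does not infer a peptide bond between chains. A minimal sketch of that feature transformation (the 200-residue offset is the value commonly used for this hack; OpenFold's exact implementation may differ):

```python
def gap_residue_index(chain_lengths, gap=200):
    """Build a residue_index for concatenated chains, inserting a large
    offset between chains so the model treats the break as a true gap."""
    residue_index = []
    offset = 0
    for length in chain_lengths:
        residue_index.extend(range(offset, offset + length))
        offset += length + gap
    return residue_index

idx = gap_residue_index([3, 2])
print(idx)  # [0, 1, 2, 203, 204]
```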
### Multimer Inference
To run inference on a complex or multiple complexes using a set of DeepMind's pretrained parameters, run e.g.:
```bash
python3 run_pretrained_openfold.py \
fasta_dir \
data/pdb_mmcif/mmcif_files/ \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
--pdb_seqres_database_path data/pdb_seqres/pdb_seqres.txt \
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
--uniprot_database_path data/uniprot/uniprot.fasta \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hmmsearch_binary_path lib/conda/envs/openfold_venv/bin/hmmsearch \
--hmmbuild_binary_path lib/conda/envs/openfold_venv/bin/hmmbuild \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_multimer_v3" \
--model_device "cuda:0" \
--output_dir ./
```
As with monomer inference, if you've already computed alignments for the query, you can use
the `--use_precomputed_alignments` option. Note that template searching in the multimer pipeline
uses HMMSearch with the PDB SeqRes database, replacing HHSearch and PDB70 used in the monomer pipeline.
**Upgrade from an existing OpenFold installation**
The above command requires several upgrades to an existing OpenFold installation.
1. Re-download the alphafold parameters to get the latest
AlphaFold-Multimer v3 weights:
```bash
bash scripts/download_alphafold_params.sh openfold/resources
```
2. Download the [UniProt](https://www.uniprot.org/uniprotkb/)
and [PDB SeqRes](https://www.rcsb.org/) databases:
```bash
bash scripts/download_uniprot.sh data/
```
The PDB SeqRes and PDB databases must be from the same date to avoid potential
errors during template searching. Remove the existing `data/pdb_mmcif` directory
and download both databases:
```bash
bash scripts/download_pdb_mmcif.sh data/
bash scripts/download_pdb_seqres.sh data/
```
3. Additionally, AlphaFold-Multimer uses upgraded versions of the [MGnify](https://www.ebi.ac.uk/metagenomics)
and [UniRef30](https://uniclust.mmseqs.com/) (previously UniClust30) databases. To download the upgraded databases, run:
```bash
bash scripts/download_uniref30.sh data/
bash scripts/download_mgnify.sh data/
```
Multimer inference can also run with the older database versions if desired.
### SoloSeq Inference
To run inference for a sequence using the SoloSeq single-sequence model, you can either precompute ESM-1b embeddings in bulk or generate them during inference.
To generate ESM-1b embeddings in bulk, use the provided script `scripts/precompute_embeddings.py`. It takes a directory of FASTA files (one sequence per file) and generates ESM-1b embeddings in the same format and directory structure required by SoloSeq. For example:
......@@ -260,7 +351,7 @@ python3 run_pretrained_openfold.py \
--output_dir ./ \
--model_device "cuda:0" \
--config_preset "seq_model_esm1b_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt \
--uniref90_database_path data/uniref90/uniref90.fasta \
--pdb70_database_path data/pdb70/pdb70 \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
......@@ -274,7 +365,7 @@ SoloSeq allows you to use the same flags and optimizations as the MSA-based Open
**NOTE:** Due to the nature of the ESM-1b embeddings, the sequence length for inference using the SoloSeq model is limited to 1022 residues. Sequences longer than that will be truncated.
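Because of this cap, you may prefer to truncate long sequences yourself so the cut point is explicit rather than silent. A minimal sketch (the helper name is illustrative, not part of OpenFold):

```python
ESM1B_MAX_RESIDUES = 1022  # ESM-1b length limit noted above

def truncate_for_soloseq(sequence: str, limit: int = ESM1B_MAX_RESIDUES) -> str:
    """Return the sequence truncated to the SoloSeq/ESM-1b length limit,
    reporting how many C-terminal residues are dropped."""
    if len(sequence) > limit:
        print(f"Dropping {len(sequence) - limit} residues from the C-terminus")
        return sequence[:limit]
    return sequence
```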
## Training
To train the model, you will first need to precompute protein alignments.
......@@ -412,9 +503,9 @@ environment. These run components of AlphaFold and OpenFold side by side and
ensure that output activations are adequately similar. For most modules, we
target a maximum pointwise difference of `1e-4`.
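The comparison criterion the tests use amounts to a maximum absolute elementwise difference. A stdlib sketch of checking activations against the `1e-4` tolerance (the real tests operate on PyTorch tensors rather than Python lists):

```python
def max_pointwise_diff(a, b):
    """Largest absolute elementwise difference between two flat activations."""
    assert len(a) == len(b)
    return max(abs(x - y) for x, y in zip(a, b))

# Hypothetical activations from the two implementations.
openfold_out = [0.12345, -1.00002, 3.14159]
alphafold_out = [0.12340, -1.00000, 3.14162]
print(max_pointwise_diff(openfold_out, alphafold_out) < 1e-4)
```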
## Building and Using the Docker Container
**Building the Docker Image**
OpenFold can be built as a Docker container using the included Dockerfile. To build it, run the following command from the root of this repository:
......@@ -422,7 +513,7 @@ Openfold can be built as a docker container using the included dockerfile. To bu
docker build -t openfold .
```
**Running the Docker Container**
The built container contains both `run_pretrained_openfold.py` and `train_openfold.py` as well as all necessary software dependencies. It does not contain the model parameters, sequence, or structural databases. These should be downloaded to the host machine following the instructions in the sections above.
......@@ -462,7 +553,7 @@ python3 /opt/openfold/run_pretrained_openfold.py \
--openfold_checkpoint_path /database/openfold_params/finetuning_ptm_2.pt
```
## Copyright Notice
While AlphaFold's and, by extension, OpenFold's source code is licensed under
the permissive Apache Licence, Version 2.0, DeepMind's pretrained parameters
......@@ -475,7 +566,7 @@ replaces the original, more restrictive CC BY-NC 4.0 license as of January 2022.
If you encounter problems using OpenFold, feel free to create an issue! We also
welcome pull requests from the community.
## Citing this Work
Please cite our paper:
......@@ -504,4 +595,4 @@ If you use OpenProteinSet, please also cite:
primaryClass={q-bio.BM}
}
```
Any work that cites OpenFold should also cite [AlphaFold](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1) if applicable.
......@@ -14,6 +14,7 @@ dependencies:
- pytorch-lightning==1.5.10
- biopython==1.79
- numpy==1.21
- pandas==2.0
- PyYAML==5.4.1
- requests
- scipy==1.7
......
......@@ -3,15 +3,15 @@ channels:
- conda-forge
- bioconda
dependencies:
- openmm=7.7
- pdbfixer
- ml-collections
- PyYAML==5.4.1
- requests
- typing-extensions
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pip:
- biopython==1.79
- dm-tree==0.1.6
from . import model
from . import utils
from . import data
from . import np
from . import resources
......
import re
import copy
import importlib
import ml_collections as mlc
......@@ -16,7 +17,7 @@ def enforce_config_constraints(config):
path = s.split('.')
setting = config
for p in path:
setting = setting.get(p)
return setting
......@@ -161,44 +162,70 @@ def model_config(
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name.startswith("seq"): # SINGLE SEQUENCE EMBEDDING PRESETS
c.update(seq_mode_config.copy_and_resolve_references())
if name == "seqemb_initial_training":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
c.data.train.max_distillation_msa_clusters = 1
elif name == "seqemb_finetuning":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
c.data.train.max_distillation_msa_clusters = 1
c.data.train.crop_size = 384
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "seq_model_esm1b":
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.data.predict.max_msa_clusters = 1
elif name == "seq_model_esm1b_ptm":
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.data.predict.max_msa_clusters = 1
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif "multimer" in name: # MULTIMER PRESETS
c.update(multimer_config_update.copy_and_resolve_references())
# Not used in multimer
del c.model.template.template_pointwise_attention
del c.loss.fape.backbone
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
#c.model.input_embedder.num_msa = 252
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 252
c.data.eval.max_msa_clusters = 252
c.data.predict.max_msa_clusters = 252
c.data.train.max_extra_msa = 1152
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
c.model.evoformer_stack.fuse_projection_weights = False
c.model.extra_msa.extra_msa_stack.fuse_projection_weights = False
c.model.template.template_pair_stack.fuse_projection_weights = False
elif name == 'model_4_multimer_v3':
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_extra_msa = 1152
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
elif name == 'model_5_multimer_v3':
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_extra_msa = 1152
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
else:
raise ValueError("Invalid model name")
if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True
......@@ -380,6 +407,8 @@ config = mlc.ConfigDict(
"max_templates": 4,
"crop": False,
"crop_size": None,
"spatial_crop_prob": None,
"interface_threshold": None,
"supervised": False,
"uniform_recycling": False,
},
......@@ -394,6 +423,8 @@ config = mlc.ConfigDict(
"max_templates": 4,
"crop": False,
"crop_size": None,
"spatial_crop_prob": None,
"interface_threshold": None,
"supervised": True,
"uniform_recycling": False,
},
......@@ -409,6 +440,8 @@ config = mlc.ConfigDict(
"shuffle_top_k_prefiltered": 20,
"crop": True,
"crop_size": 256,
"spatial_crop_prob": 0.,
"interface_threshold": None,
"supervised": True,
"clamp_prob": 0.9,
"max_distillation_msa_clusters": 1000,
......@@ -426,7 +459,6 @@ config = mlc.ConfigDict(
},
# Recurring FieldReferences that can be changed globally here
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
# Use DeepSpeed memory-efficient attention kernel. Mutually
......@@ -446,6 +478,8 @@ config = mlc.ConfigDict(
"c_e": c_e,
"c_s": c_s,
"eps": eps,
"is_multimer": False,
"seqemb_mode_enabled": False, # Global flag for enabling seq emb mode
},
"model": {
"_mask_trans": False,
......@@ -470,7 +504,7 @@ config = mlc.ConfigDict(
"max_bin": 50.75,
"no_bins": 39,
},
"template_single_embedder": {
# DISCREPANCY: c_in is supposed to be 51.
"c_in": 57,
"c_out": c_m,
......@@ -489,6 +523,8 @@ config = mlc.ConfigDict(
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": False,
"fuse_projection_weights": False,
"blocks_per_ckpt": blocks_per_ckpt,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
......@@ -537,6 +573,8 @@ config = mlc.ConfigDict(
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": False,
"fuse_projection_weights": False,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
"inf": 1e9,
......@@ -560,6 +598,8 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"no_column_attention": False,
"opm_first": False,
"fuse_projection_weights": False,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
......@@ -607,6 +647,12 @@ config = mlc.ConfigDict(
"c_out": 37,
},
},
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `max_recycling_iters` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
"recycle_early_stop_tolerance": -1.
},
"relax": {
"max_iterations": 0, # no max
......@@ -652,6 +698,7 @@ config = mlc.ConfigDict(
"weight": 0.01,
},
"masked_msa": {
"num_classes": 23,
"eps": eps, # 1e-8,
"weight": 2.0,
},
......@@ -664,6 +711,7 @@ config = mlc.ConfigDict(
"violation": {
"violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5,
"average_clashes": False,
"eps": eps, # 1e-6,
"weight": 0.0,
},
......@@ -676,12 +724,199 @@ config = mlc.ConfigDict(
"weight": 0.,
"enabled": tm_enabled,
},
"chain_center_of_mass": {
"clamp_distance": -4.0,
"weight": 0.,
"eps": eps,
"enabled": False,
},
"eps": eps,
},
"ema": {"decay": 0.999},
}
)
multimer_config_update = mlc.ConfigDict({
"globals": {
"is_multimer": True
},
"data": {
"common": {
"feat": {
"aatype": [NUM_RES],
"all_atom_mask": [NUM_RES, None],
"all_atom_positions": [NUM_RES, None, None],
# "all_chains_entity_ids": [], # TODO: Resolve missing features, remove processed msa feats
# "all_crops_all_chains_mask": [],
# "all_crops_all_chains_positions": [],
# "all_crops_all_chains_residue_ids": [],
"assembly_num_chains": [],
"asym_id": [NUM_RES],
"atom14_atom_exists": [NUM_RES, None],
"atom37_atom_exists": [NUM_RES, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
"cluster_bias_mask": [NUM_MSA_SEQ],
"cluster_profile": [NUM_MSA_SEQ, NUM_RES, None],
"cluster_deletion_mean": [NUM_MSA_SEQ, NUM_RES],
"deletion_matrix": [NUM_MSA_SEQ, NUM_RES],
"deletion_mean": [NUM_RES],
"entity_id": [NUM_RES],
"entity_mask": [NUM_RES],
"extra_deletion_matrix": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
# "mem_peak": [],
"msa": [NUM_MSA_SEQ, NUM_RES],
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
"msa_mask": [NUM_MSA_SEQ, NUM_RES],
"msa_profile": [NUM_RES, None],
"num_alignments": [],
"num_templates": [],
# "queue_size": [],
"residue_index": [NUM_RES],
"residx_atom14_to_atom37": [NUM_RES, None],
"residx_atom37_to_atom14": [NUM_RES, None],
"resolution": [],
"seq_length": [],
"seq_mask": [NUM_RES],
"sym_id": [NUM_RES],
"target_feat": [NUM_RES, None],
"template_aatype": [NUM_TEMPLATES, NUM_RES],
"template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
"template_all_atom_positions": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"true_msa": [NUM_MSA_SEQ, NUM_RES]
},
"max_recycling_iters": 20, # For training, value is 3
"unsupervised_features": [
"aatype",
"residue_index",
"msa",
"num_alignments",
"seq_length",
"between_segment_residues",
"deletion_matrix",
"no_recycling_iters",
# Additional multimer features
"msa_mask",
"seq_mask",
"asym_id",
"entity_id",
"sym_id",
]
},
"supervised": {
"clamp_prob": 1.
},
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
# c.model.input_embedder.num_msa = 508
# c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
"predict": {
"max_msa_clusters": 508,
"max_extra_msa": 2048
},
"eval": {
"max_msa_clusters": 508,
"max_extra_msa": 2048
},
"train": {
"max_msa_clusters": 508,
"max_extra_msa": 2048,
"block_delete_msa" : False,
"crop_size": 640,
"spatial_crop_prob": 0.5,
"interface_threshold": 10.,
"clamp_prob": 1.,
},
},
"model": {
"input_embedder": {
"tf_dim": 21,
#"num_msa": 508,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True
},
"template": {
"template_single_embedder": {
"c_in": 34,
"c_out": c_m
},
"template_pair_embedder": {
"c_in": c_z,
"c_out": c_t,
"c_dgram": 39,
"c_aatype": 22
},
"template_pair_stack": {
"tri_mul_first": True,
"fuse_projection_weights": True
},
"c_t": c_t,
"c_z": c_z,
"use_unit_vector": True
},
"extra_msa": {
# "extra_msa_embedder": {
# "num_extra_msa": 2048
# },
"extra_msa_stack": {
"opm_first": True,
"fuse_projection_weights": True
}
},
"evoformer_stack": {
"opm_first": True,
"fuse_projection_weights": True
},
"structure_module": {
"trans_scale_factor": 20
},
"heads": {
"tm": {
"ptm_weight": 0.2,
"iptm_weight": 0.8,
"enabled": True
},
"masked_msa": {
"c_out": 22
},
},
"recycle_early_stop_tolerance": 0.5 # For training, value is -1.
},
"loss": {
"fape": {
"intra_chain_backbone": {
"clamp_distance": 10.0,
"loss_unit_distance": 10.0,
"weight": 0.5
},
"interface_backbone": {
"clamp_distance": 30.0,
"loss_unit_distance": 20.0,
"weight": 0.5
}
},
"masked_msa": {
"num_classes": 22
},
"violation": {
"average_clashes": True,
"weight": 0.03 # Not finetuning
},
"tm": {
"weight": 0.1,
"enabled": True
},
"chain_center_of_mass": {
"weight": 0.05,
"enabled": True
}
}
})
seq_mode_config = mlc.ConfigDict({
"data": {
"common": {
......@@ -700,12 +935,18 @@ seq_mode_config = mlc.ConfigDict({
"seqemb_mode_enabled": True,
},
"model": {
"preembedding_embedder": { # Used in sequence embedding mode
"preembedding_embedder": { # Used in sequence embedding mode
"tf_dim": 22,
"preembedding_dim": preemb_dim_size,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
},
"extra_msa": {
"enabled": False # Disable Extra MSA Stack
},
"evoformer_stack": {
"no_column_attention": True # Turn off Evoformer's column attention
},
}
})
......@@ -23,6 +23,9 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
......@@ -86,18 +89,17 @@ def make_all_atom_aatype(protein):
def fix_templates_aatype(protein):
# Map one-hot to indices
num_templates = protein["template_aatype"].shape[0]
protein["template_aatype"] = torch.argmax(
protein["template_aatype"], dim=-1
)
# Map hhsearch-aatype to our aatype.
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
new_order_list, dtype=torch.int64, device=protein["template_aatype"].device,
).expand(num_templates, -1)
protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"]
)
return protein
......@@ -447,13 +449,15 @@ def make_hhblits_profile(protein):
@curry1
def make_masked_msa(protein, config, replace_fraction):
def make_masked_msa(protein, config, replace_fraction, seed):
"""Create data for BERT on raw MSA."""
device = protein["msa"].device
# Add a random amino acid uniformly.
random_aa = torch.tensor(
[0.05] * 20 + [0.0, 0.0],
dtype=torch.float32,
device=device
)
categorical_probs = (
......@@ -473,11 +477,18 @@ def make_masked_msa(protein, config, replace_fraction):
assert mask_prob >= 0.0
categorical_probs = torch.nn.functional.pad(
categorical_probs, pad_shapes, value=mask_prob,
)
sh = protein["msa"].shape
g = None
if seed is not None:
g = torch.Generator(device=protein["msa"].device)
g.manual_seed(seed)
sample = torch.rand(sh, device=device, generator=g)
mask_position = sample < replace_fraction
bert_msa = shaped_categorical(categorical_probs)
bert_msa = torch.where(mask_position, bert_msa, protein["msa"])
......@@ -670,7 +681,7 @@ def make_atom14_masks(protein):
def make_atom14_masks_np(batch):
batch = tree_map(
lambda n: torch.tensor(n, device="cpu"),
batch,
np.ndarray
)
out = make_atom14_masks(batch)
......@@ -736,7 +747,7 @@ def make_atom14_positions(protein):
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.0
all_matrices[resname] = renaming_matrix
renaming_matrices = torch.stack(
[all_matrices[restype] for restype in restype_3]
)
......@@ -782,10 +793,14 @@ def make_atom14_positions(protein):
def atom37_to_frames(protein, eps=1e-8):
is_multimer = "asym_id" in protein
aatype = protein["aatype"]
all_atom_positions = protein["all_atom_positions"]
all_atom_mask = protein["all_atom_mask"]
if is_multimer:
all_atom_positions = Vec3Array.from_array(all_atom_positions)
batch_dims = len(aatype.shape[:-1])
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
......@@ -832,19 +847,37 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims=batch_dims,
)
if is_multimer:
base_atom_pos = [batched_gather(
pos,
residx_rigidgroup_base_atom37_idx,
dim=-1,
no_batch_dims=len(all_atom_positions.shape[:-1]),
) for pos in all_atom_positions]
base_atom_pos = Vec3Array.from_array(torch.stack(base_atom_pos, dim=-1))
else:
base_atom_pos = batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
dim=-2,
no_batch_dims=len(all_atom_positions.shape[:-2]),
)
if is_multimer:
point_on_neg_x_axis = base_atom_pos[:, :, 0]
origin = base_atom_pos[:, :, 1]
point_on_xy_plane = base_atom_pos[:, :, 2]
gt_rotation = Rot3Array.from_two_vectors(
origin - point_on_neg_x_axis, point_on_xy_plane - origin)
gt_frames = Rigid3Array(gt_rotation, origin)
else:
gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :],
eps=eps,
)
group_exists = batched_gather(
restype_rigidgroup_mask,
......@@ -865,9 +898,13 @@ def atom37_to_frames(protein, eps=1e-8):
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
if is_multimer:
gt_frames = gt_frames.compose_rotation(
Rot3Array.from_array(rots))
else:
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(Rigid(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8
......@@ -901,12 +938,18 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims=batch_dims,
)
if is_multimer:
ambiguity_rot = Rot3Array.from_array(residx_rigidgroup_ambiguity_rot)
# Create the alternative ground truth frames.
alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
else:
residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot
)
alt_gt_frames = gt_frames.compose(
Rigid(residx_rigidgroup_ambiguity_rot, None)
)
gt_frames_tensor = gt_frames.to_tensor_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
......
from typing import Sequence
import torch
from openfold.config import NUM_RES
from openfold.data.data_transforms import curry1
from openfold.np import residue_constants as rc
from openfold.utils.tensor_utils import masked_mean
def gumbel_noise(
shape: Sequence[int],
device: torch.device,
eps=1e-6,
generator=None,
) -> torch.Tensor:
"""Generate Gumbel Noise of given Shape.
This generates samples from Gumbel(0, 1).
Args:
shape: Shape of noise to return.
Returns:
Gumbel noise of given shape.
"""
uniform_noise = torch.rand(
shape, dtype=torch.float32, device=device, generator=generator
)
gumbel = -torch.log(-torch.log(uniform_noise + eps) + eps)
return gumbel
def gumbel_max_sample(logits: torch.Tensor, generator=None) -> torch.Tensor:
"""Samples from a probability distribution given by 'logits'.
This uses the Gumbel-max trick to implement the sampling in an efficient manner.
Args:
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Sample from logprobs in one-hot form.
"""
z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
return torch.nn.functional.one_hot(
torch.argmax(logits + z, dim=-1),
logits.shape[-1],
)
def gumbel_argsort_sample_idx(
logits: torch.Tensor,
generator=None
) -> torch.Tensor:
"""Samples without replacement from a distribution given by 'logits'.
This uses the Gumbel trick to implement the sampling in an efficient manner. For a
distribution over k items this samples k times without replacement, so this
is effectively sampling a random permutation with probabilities over the
permutations derived from the logprobs.
Args:
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Indices sorted in descending order of the perturbed logits, i.e. a sampled
permutation of the last dimension.
"""
z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
return torch.argsort(logits + z, dim=-1, descending=True)
@curry1
def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
# Use torch.tensor: the legacy torch.Tensor constructor does not accept
# non-CPU devices via the device kwarg.
random_aa = torch.tensor(
[0.05] * 20 + [0., 0.],
device=batch['msa'].device
)
categorical_probs = (
config.uniform_prob * random_aa +
config.profile_prob * batch['msa_profile'] +
config.same_prob * torch.nn.functional.one_hot(batch['msa'], 22)
)
# Put all remaining probability on [MASK] which is a new column.
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
categorical_probs = torch.nn.functional.pad(
categorical_probs, [0,1], value=mask_prob
)
sh = batch['msa'].shape
mask_position = torch.rand(sh, device=batch['msa'].device) < replace_fraction
mask_position *= batch['msa_mask'].to(mask_position.dtype)
logits = torch.log(categorical_probs + eps)
g = None
if seed is not None:
g = torch.Generator(device=batch["msa"].device)
g.manual_seed(seed)
bert_msa = gumbel_max_sample(logits, generator=g)
bert_msa = torch.where(
mask_position,
torch.argmax(bert_msa, dim=-1),
batch['msa']
)
bert_msa *= batch['msa_mask'].to(bert_msa.dtype)
# Mix real and masked MSA.
if 'bert_mask' in batch:
batch['bert_mask'] *= mask_position.to(torch.float32)
else:
batch['bert_mask'] = mask_position.to(torch.float32)
batch['true_msa'] = batch['msa']
batch['msa'] = bert_msa
return batch
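The masking distribution above mixes uniform, profile, and identity probabilities and puts the leftover mass on a new [MASK] column, so every row must remain a valid categorical distribution. A minimal NumPy sketch (with made-up probabilities, not the real config values):

```python
import numpy as np

# Illustrative mixture weights, not OpenFold's actual config values.
uniform_prob, profile_prob, same_prob = 0.1, 0.1, 0.1
num_classes = 22  # 20 amino acids + X + gap

rng = np.random.default_rng(1)
msa = rng.integers(0, num_classes, size=(4, 7))        # [num_seq, num_res]
profile = np.full((7, num_classes), 1.0 / num_classes)  # flat MSA profile

random_aa = np.array([0.05] * 20 + [0.0, 0.0])
same = np.eye(num_classes)[msa]                         # one-hot of the MSA
probs = (uniform_prob * random_aa
         + profile_prob * profile
         + same_prob * same)                            # [4, 7, 22]

# All remaining probability goes on a new [MASK] column.
mask_prob = 1.0 - uniform_prob - profile_prob - same_prob
probs = np.concatenate(
    [probs, np.full(probs.shape[:-1] + (1,), mask_prob)], axis=-1
)                                                       # [4, 7, 23]
row_sums = probs.sum(axis=-1)
```

Each of the three mixture terms contributes exactly its weight of probability mass per position, so the padded distribution sums to one everywhere.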
@curry1
def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
"""Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
device = batch["msa_mask"].device
# Determine how much weight we assign to each agreement. In theory, we could
# use a full blosum matrix here, but right now let's just down-weight gap
# agreement because it could be spurious.
# Never put weight on agreeing on BERT mask.
# Use torch.tensor: the legacy torch.Tensor constructor does not accept
# non-CPU devices via the device kwarg.
weights = torch.tensor(
[1.] * 21 + [gap_agreement_weight] + [0.],
device=device,
)
msa_mask = batch['msa_mask']
msa_one_hot = torch.nn.functional.one_hot(batch['msa'], 23)
extra_mask = batch['extra_msa_mask']
extra_one_hot = torch.nn.functional.one_hot(batch['extra_msa'], 23)
msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot
agreement = torch.einsum(
'mrc, nrc->nm',
extra_one_hot_masked,
weights * msa_one_hot_masked
)
cluster_assignment = torch.nn.functional.softmax(1e3 * agreement, dim=0)
cluster_assignment *= torch.einsum('mr, nr->mn', msa_mask, extra_mask)
cluster_count = torch.sum(cluster_assignment, dim=-1)
cluster_count += 1. # We always include the sequence itself.
msa_sum = torch.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked)
msa_sum += msa_one_hot_masked
cluster_profile = msa_sum / cluster_count[:, None, None]
extra_deletion_matrix = batch['extra_deletion_matrix']
deletion_matrix = batch['deletion_matrix']
del_sum = torch.einsum(
'nm, mc->nc',
cluster_assignment,
extra_mask * extra_deletion_matrix
)
del_sum += deletion_matrix # Original sequence.
cluster_deletion_mean = del_sum / cluster_count[:, None]
batch['cluster_profile'] = cluster_profile
batch['cluster_deletion_mean'] = cluster_deletion_mean
return batch
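The einsum above scores position-wise agreement between every (extra, sampled) pair, and the sharp softmax (`1e3 * agreement`) approximates a hard nearest-neighbor assignment. A deterministic toy check in NumPy, with a max-subtraction for numerical stability that the sketch needs but the torch `softmax` handles internally:

```python
import numpy as np

# Toy setup: 2 sampled sequences, 3 extra sequences, 5 residues, 23 classes.
num_classes = 23
msa = np.array([[0, 1, 2, 3, 4],
                [5, 6, 7, 8, 9]])
extra = np.stack([msa[0], msa[1], msa[1]])  # each extra copies a sampled row

msa_oh = np.eye(num_classes)[msa]           # [2, 5, 23]
extra_oh = np.eye(num_classes)[extra]       # [3, 5, 23]

# Position-wise agreement between every (extra m, sampled n) pair.
agreement = np.einsum('mrc,nrc->nm', extra_oh, msa_oh)   # [2, 3]

# Sharp softmax over the sampled dimension ~ hard nearest-neighbor assignment.
logits = 1e3 * agreement
logits -= logits.max(axis=0, keepdims=True)              # numerical stability
assignment = np.exp(logits)
assignment /= assignment.sum(axis=0, keepdims=True)
hard = assignment.argmax(axis=0)                          # cluster per extra
```

Extra 0 copies sampled row 0 and extras 1–2 copy sampled row 1, so the assignment recovers exactly that.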
def create_target_feat(batch):
"""Create the target features"""
batch["target_feat"] = torch.nn.functional.one_hot(
batch["aatype"], 21
).to(torch.float32)
return batch
def create_msa_feat(batch):
"""Create and concatenate MSA features."""
device = batch["msa"]
msa_1hot = torch.nn.functional.one_hot(batch['msa'], 23)
deletion_matrix = batch['deletion_matrix']
has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
deletion_value = (torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
deletion_mean_value = (
torch.atan(
batch['cluster_deletion_mean'] / 3.) *
(2. / pi)
)[..., None]
msa_feat = torch.cat(
[
msa_1hot,
has_deletion,
deletion_value,
batch['cluster_profile'],
deletion_mean_value
],
dim=-1,
)
batch["msa_feat"] = msa_feat
return batch
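The deletion features above squash unbounded deletion counts into [0, 1) via `atan(d / 3) * (2 / pi)`; a quick NumPy check of that transform's range and monotonicity:

```python
import numpy as np

# atan(d / 3) * (2 / pi) maps deletion counts from [0, inf) into [0, 1),
# with d = 3 landing exactly at 0.5.
d = np.array([0.0, 1.0, 3.0, 30.0, 3000.0])
squashed = np.arctan(d / 3.0) * (2.0 / np.pi)
```

This keeps rare, very deep deletion columns from dominating the feature scale while preserving ordering.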
def build_extra_msa_feat(batch):
"""Expand extra_msa into 1hot and concat with other extra msa features.
We do this as late as possible as the one_hot extra msa can be very large.
Args:
batch: a dictionary with the following keys:
* 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
centre. Note - This isn't one-hotted.
* 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
position.
num_extra_msa: Number of extra msa to use.
Returns:
Concatenated tensor of extra MSA features.
"""
# 23 = 20 amino acids + 'X' for unknown + gap + bert mask
extra_msa = batch['extra_msa']
deletion_matrix = batch['extra_deletion_matrix']
msa_1hot = torch.nn.functional.one_hot(extra_msa, 23)
has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
deletion_value = (
(torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
)
catted = torch.cat([msa_1hot, has_deletion, deletion_value], dim=-1)
return catted
@curry1
def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6):
"""Sample MSA randomly, remaining sequences are stored as `extra_*`.
Args:
batch: batch to sample msa from.
max_seq: number of sequences to sample.
Returns:
Protein with sampled msa.
"""
g = None
if seed is not None:
g = torch.Generator(device=batch["msa"].device)
g.manual_seed(seed)
# Sample uniformly among sequences with at least one non-masked position.
logits = (torch.clamp(torch.sum(batch['msa_mask'], dim=-1), 0., 1.) - 1.) * inf
# The cluster_bias_mask can be used to preserve the first row (target
# sequence) for each chain, for example.
if 'cluster_bias_mask' not in batch:
cluster_bias_mask = torch.nn.functional.pad(
batch['msa'].new_zeros(batch['msa'].shape[0] - 1),
(1, 0),
value=1.
)
else:
cluster_bias_mask = batch['cluster_bias_mask']
logits += cluster_bias_mask * inf
index_order = gumbel_argsort_sample_idx(logits, generator=g)
sel_idx = index_order[:max_seq]
extra_idx = index_order[max_seq:][:max_extra_msa_seq]
for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
if k in batch:
batch['extra_' + k] = batch[k][extra_idx]
batch[k] = batch[k][sel_idx]
return batch
def make_msa_profile(batch):
"""Compute the MSA profile."""
# Compute the profile for every residue (over all MSA sequences).
batch["msa_profile"] = masked_mean(
batch['msa_mask'][..., None],
torch.nn.functional.one_hot(batch['msa'], 22),
dim=-3,
)
return batch
def randint(lower, upper, generator, device):
return int(torch.randint(
lower,
upper + 1,
(1,),
device=device,
generator=generator,
)[0])
def get_interface_residues(positions, atom_mask, asym_id, interface_threshold):
coord_diff = positions[..., None, :, :] - positions[..., None, :, :, :]
pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float()
pair_mask = atom_mask[..., None, :] * atom_mask[..., None, :, :]
mask = (diff_chain_mask[..., None] * pair_mask).bool()
min_dist_per_res, _ = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1)
valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1)
interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0]
return interface_residues_idxs
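A minimal NumPy analogue of the function above, using one pseudo-atom per residue (the torch version minimizes over all atom pairs): a residue is an interface residue if its closest atom on a *different* chain is within the threshold.

```python
import numpy as np

# One pseudo-atom per residue; two chains with a single close contact pair.
pos = np.array([[0.0, 0.0, 0.0], [20.0, 0.0, 0.0],   # chain A
                [1.0, 0.0, 0.0], [40.0, 0.0, 0.0]])  # chain B
asym_id = np.array([0, 0, 1, 1])

# Pairwise distances, masked down to cross-chain pairs only.
dists = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
diff_chain = asym_id[:, None] != asym_id[None, :]
min_cross = np.where(diff_chain, dists, np.inf).min(axis=-1)

interface = np.nonzero(min_cross < 8.0)[0]  # threshold in Angstroms
```

Residues 0 and 2 sit 1 Å apart across the chain boundary, so only those two qualify.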
def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
positions = protein["all_atom_positions"]
atom_mask = protein["all_atom_mask"]
asym_id = protein["asym_id"]
interface_residues = get_interface_residues(positions=positions,
atom_mask=atom_mask,
asym_id=asym_id,
interface_threshold=interface_threshold)
if not torch.any(interface_residues):
return get_contiguous_crop_idx(protein, crop_size, generator)
target_res_idx = randint(lower=0,
upper=interface_residues.shape[-1] - 1,
generator=generator,
device=positions.device)
target_res = interface_residues[target_res_idx]
ca_idx = rc.atom_order["CA"]
ca_positions = positions[..., ca_idx, :]
ca_mask = atom_mask[..., ca_idx].bool()
coord_diff = ca_positions[..., None, :] - ca_positions[..., None, :, :]
ca_pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
to_target_distances = ca_pairwise_dists[target_res]
break_tie = (
torch.arange(
0, to_target_distances.shape[-1], device=positions.device
).float()
* 1e-3
)
to_target_distances = torch.where(ca_mask, to_target_distances, torch.inf) + break_tie
ret = torch.argsort(to_target_distances)[:crop_size]
return ret.sort().values
def get_contiguous_crop_idx(protein, crop_size, generator):
unique_asym_ids, chain_idxs, chain_lens = protein["asym_id"].unique(dim=-1,
return_inverse=True,
return_counts=True)
shuffle_idx = torch.randperm(chain_lens.shape[-1], device=chain_lens.device, generator=generator)
_, idx_sorted = torch.sort(chain_idxs, stable=True)
cum_sum = chain_lens.cumsum(dim=0)
cum_sum = torch.cat((cum_sum.new_zeros((1,)), cum_sum[:-1]), dim=0)
asym_offsets = idx_sorted[cum_sum]
num_budget = crop_size
num_remaining = int(protein["seq_length"])
crop_idxs = []
for idx in shuffle_idx:
chain_len = int(chain_lens[idx])
num_remaining -= chain_len
crop_size_max = min(num_budget, chain_len)
crop_size_min = min(chain_len, max(0, num_budget - num_remaining))
chain_crop_size = randint(lower=crop_size_min,
upper=crop_size_max,
generator=generator,
device=chain_lens.device)
num_budget -= chain_crop_size
chain_start = randint(lower=0,
upper=chain_len - chain_crop_size,
generator=generator,
device=chain_lens.device)
asym_offset = asym_offsets[idx]
crop_idxs.append(
torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size)
)
return torch.concat(crop_idxs).sort().values
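The per-chain budget logic above (crop_size_min/crop_size_max) guarantees the chain crops sum exactly to the budget whenever the chains are long enough. A self-contained NumPy sketch of just that invariant (names are illustrative):

```python
import numpy as np

def chain_crop_sizes(chain_lens, crop_size, rng):
    # Visit chains in random order; each takes a random share of the budget,
    # bounded so that (a) it never exceeds the budget or its own length and
    # (b) the chains still to come can always fill whatever remains.
    order = rng.permutation(len(chain_lens))
    budget, remaining = crop_size, int(sum(chain_lens))
    sizes = {}
    for idx in order:
        n = int(chain_lens[idx])
        remaining -= n
        hi = min(budget, n)                    # cannot overspend
        lo = min(n, max(0, budget - remaining))  # must leave a fillable budget
        take = int(rng.integers(lo, hi + 1))
        budget -= take
        sizes[int(idx)] = take
    return sizes

rng = np.random.default_rng(0)
sizes = chain_crop_sizes([50, 120, 80], crop_size=100, rng=rng)
total = sum(sizes.values())  # always equals crop_size here
```

At the final chain, `lo == hi == budget`, so the budget is spent exactly, which is why `random_crop_to_size` always yields `crop_size` residues for proteins longer than the crop.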
@curry1
def random_crop_to_size(
protein,
crop_size,
max_templates,
shape_schema,
spatial_crop_prob,
interface_threshold,
subsample_templates=False,
seed=None,
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
g = None
if seed is not None:
g = torch.Generator(device=protein["seq_length"].device)
g.manual_seed(seed)
use_spatial_crop = torch.rand((1,),
device=protein["seq_length"].device,
generator=g) < spatial_crop_prob
num_res = protein["aatype"].shape[0]
if num_res <= crop_size:
crop_idxs = torch.arange(num_res)
elif use_spatial_crop:
crop_idxs = get_spatial_crop_idx(protein, crop_size, interface_threshold, g)
else:
crop_idxs = get_contiguous_crop_idx(protein, crop_size, g)
if "template_mask" in protein:
num_templates = protein["template_mask"].shape[-1]
else:
num_templates = 0
# No need to subsample templates if there aren't any
subsample_templates = subsample_templates and num_templates
if subsample_templates:
templates_crop_start = randint(lower=0,
upper=num_templates,
generator=g,
device=protein["seq_length"].device)
templates_select_indices = torch.randperm(
num_templates, device=protein["seq_length"].device, generator=g
)
else:
templates_crop_start = 0
num_res_crop_size = min(int(protein["seq_length"]), crop_size)
num_templates_crop_size = min(
num_templates - templates_crop_start, max_templates
)
for k, v in protein.items():
if k not in shape_schema or (
"template" not in k and NUM_RES not in shape_schema[k]
):
continue
# randomly permute the templates before cropping them.
if k.startswith("template") and subsample_templates:
v = v[templates_select_indices]
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
is_num_res = dim_size == NUM_RES
if i == 0 and k.startswith("template"):
v = v[slice(templates_crop_start, templates_crop_start + num_templates_crop_size)]
elif is_num_res:
v = torch.index_select(v, i, crop_idxs)
protein[k] = v
protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
return protein
......@@ -20,7 +20,7 @@ import ml_collections
import numpy as np
import torch
from openfold.data import input_pipeline
from openfold.data import input_pipeline, input_pipeline_multimer
FeatureDict = Mapping[str, np.ndarray]
......@@ -80,11 +80,14 @@ def np_example_to_features(
np_example: FeatureDict,
config: ml_collections.ConfigDict,
mode: str,
is_multimer: bool = False
):
np_example = dict(np_example)
num_res = int(np_example["seq_length"][0])
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
seq_length = np_example["seq_length"]
num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length)
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
if "deletion_matrix_int" in np_example:
np_example["deletion_matrix"] = np_example.pop(
"deletion_matrix_int"
......@@ -93,12 +96,20 @@ def np_example_to_features(
tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names
)
with torch.no_grad():
if is_multimer:
features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
else:
features = input_pipeline.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
if mode == "train":
p = torch.rand(1).item()
......@@ -128,10 +139,15 @@ class FeaturePipeline:
def process_features(
self,
raw_features: FeatureDict,
mode: str = "train",
mode: str = "train",
is_multimer: bool = False,
) -> FeatureDict:
# if(is_multimer and mode != "predict"):
# raise ValueError("Multimer mode is not currently trainable")
return np_example_to_features(
np_example=raw_features,
config=self.config,
mode=mode,
is_multimer=is_multimer,
)
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data pipeline."""
from typing import Iterable, MutableMapping, List, Mapping
from openfold.data import msa_pairing
from openfold.np import residue_constants
import numpy as np
# TODO: Move this into the config
REQUIRED_FEATURES = frozenset({
'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments',
'num_templates', 'queue_size', 'residue_index', 'resolution',
'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
'template_all_atom_mask', 'template_all_atom_positions'
})
MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048
def _is_homomer_or_monomer(chains: Iterable[Mapping[str, np.ndarray]]) -> bool:
"""Checks if a list of chains represents a homomer/monomer example."""
# Note that an entity_id of 0 indicates padding.
num_unique_chains = len(np.unique(np.concatenate(
[np.unique(chain['entity_id'][chain['entity_id'] > 0]) for
chain in chains])))
return num_unique_chains == 1
def pair_and_merge(
all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]],
) -> Mapping[str, np.ndarray]:
"""Runs processing on features to augment, pair and merge.
Args:
all_chain_features: A MutableMap of dictionaries of features for each chain.
Returns:
A dictionary of features.
"""
process_unmerged_features(all_chain_features)
np_chains_list = list(all_chain_features.values())
pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)
if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features(
chains=np_chains_list
)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
np_chains_list = crop_chains(
np_chains_list,
msa_crop_size=MSA_CROP_SIZE,
pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES
)
np_example = msa_pairing.merge_chain_features(
np_chains_list=np_chains_list, pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES
)
np_example = process_final(np_example)
return np_example
def crop_chains(
chains_list: List[Mapping[str, np.ndarray]],
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int
) -> List[Mapping[str, np.ndarray]]:
"""Crops the MSAs for a set of chains.
Args:
chains_list: A list of chains to be cropped.
msa_crop_size: The total number of sequences to crop from the MSA.
pair_msa_sequences: Whether we are operating in sequence-pairing mode.
max_templates: The maximum templates to use per chain.
Returns:
The chains cropped.
"""
# Apply the cropping.
cropped_chains = []
for chain in chains_list:
cropped_chain = _crop_single_chain(
chain,
msa_crop_size=msa_crop_size,
pair_msa_sequences=pair_msa_sequences,
max_templates=max_templates)
cropped_chains.append(cropped_chain)
return cropped_chains
def _crop_single_chain(chain: Mapping[str, np.ndarray],
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int) -> Mapping[str, np.ndarray]:
"""Crops msa sequences to `msa_crop_size`."""
msa_size = chain['num_alignments']
if pair_msa_sequences:
msa_size_all_seq = chain['num_alignments_all_seq']
msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)
# We reduce the number of un-paired sequences, by the number of times a
# sequence from this chain's MSA is included in the paired MSA. This keeps
# the MSA size for each chain roughly constant.
msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
num_non_gapped_pairs = np.sum(
np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1))
num_non_gapped_pairs = np.minimum(num_non_gapped_pairs,
msa_crop_size_all_seq)
# Restrict the unpaired crop size so that paired+unpaired sequences do not
# exceed msa_seqs_per_chain for each chain.
max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
else:
msa_crop_size = np.minimum(msa_size, msa_crop_size)
include_templates = 'template_aatype' in chain and max_templates
if include_templates:
num_templates = chain['template_aatype'].shape[0]
templates_crop_size = np.minimum(num_templates, max_templates)
for k in chain:
k_split = k.split('_all_seq')[0]
if k_split in msa_pairing.TEMPLATE_FEATURES:
chain[k] = chain[k][:templates_crop_size, :]
elif k_split in msa_pairing.MSA_FEATURES:
if '_all_seq' in k and pair_msa_sequences:
chain[k] = chain[k][:msa_crop_size_all_seq, :]
else:
chain[k] = chain[k][:msa_crop_size, :]
chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
if include_templates:
chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
if pair_msa_sequences:
chain['num_alignments_all_seq'] = np.asarray(
msa_crop_size_all_seq, dtype=np.int32)
return chain
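The paired/unpaired budgeting above caps paired rows at half the crop and lets unpaired rows fill whatever the non-gapped paired rows leave over. A worked NumPy-style example with small illustrative numbers:

```python
# Illustrative sizes, not taken from a real chain.
msa_crop_size = 8
msa_size, msa_size_all_seq = 10, 6   # unpaired rows, paired rows available
num_non_gapped_pairs = 3             # paired rows that actually hit this chain

# Paired rows get at most half the total budget.
paired_crop = min(msa_size_all_seq, msa_crop_size // 2)       # -> 4
# Only non-gapped paired rows count against the budget.
used = min(num_non_gapped_pairs, paired_crop)                 # -> 3
# Unpaired rows fill the remainder.
unpaired_crop = min(msa_size, max(msa_crop_size - used, 0))   # -> 5
```

So the chain keeps roughly `used + unpaired_crop == msa_crop_size` informative rows, which is the "keeps the MSA size for each chain roughly constant" behavior the comment describes.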
def process_final(
np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
"""Final processing steps in data pipeline, after merging and pairing."""
np_example = _correct_msa_restypes(np_example)
np_example = _make_seq_mask(np_example)
np_example = _make_msa_mask(np_example)
np_example = _filter_features(np_example)
return np_example
def _correct_msa_restypes(np_example):
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0)
np_example['msa'] = np_example['msa'].astype(np.int32)
return np_example
def _make_seq_mask(np_example):
np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
return np_example
def _make_msa_mask(np_example):
"""Mask features are all ones, but will later be zero-padded."""
np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)
seq_mask = (np_example['entity_id'] > 0).astype(np.float32)
np_example['msa_mask'] *= seq_mask[None]
return np_example
def _filter_features(
np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
"""Filters features of example to only those requested."""
return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}
def process_unmerged_features(
all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]
):
"""Postprocessing stage for per-chain features before merging."""
num_chains = len(all_chain_features)
for chain_features in all_chain_features.values():
# Convert deletion matrices to float.
chain_features['deletion_matrix'] = np.asarray(
chain_features.pop('deletion_matrix_int'), dtype=np.float32
)
if 'deletion_matrix_int_all_seq' in chain_features:
chain_features['deletion_matrix_all_seq'] = np.asarray(
chain_features.pop('deletion_matrix_int_all_seq'), dtype=np.float32
)
chain_features['deletion_mean'] = np.mean(
chain_features['deletion_matrix'], axis=0
)
if 'all_atom_positions' not in chain_features:
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['aatype']]
chain_features['all_atom_mask'] = all_atom_mask.astype(dtype=np.float32)
chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])
# Add assembly_num_chains.
chain_features['assembly_num_chains'] = np.asarray(num_chains)
# Add entity_mask.
for chain_features in all_chain_features.values():
chain_features['entity_mask'] = (
chain_features['entity_id'] != 0).astype(np.int32)
......@@ -107,7 +107,8 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms.make_masked_msa(
common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction,
seed=(msa_seed + 1) if msa_seed is not None else None,
)
)
......
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch
from openfold.data import (
data_transforms,
data_transforms_multimer,
)
def groundtruth_transforms_fns():
transforms = [data_transforms.make_atom14_masks,
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles]
return transforms
def nonensembled_transform_fns():
"""Input pipeline data transformers that are not ensembled."""
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat,
data_transforms.make_atom14_masks
]
return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms = []
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = mode_cfg.max_extra_msa
msa_seed = None
if(not common_cfg.resample_msa_in_recycling):
msa_seed = ensemble_seed
transforms.append(
data_transforms_multimer.sample_msa(
max_msa_clusters,
max_extra_msa,
seed=msa_seed,
)
)
if "masked_msa" in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms_multimer.make_masked_msa(
common_cfg.masked_msa,
mode_cfg.masked_msa_replace_fraction,
seed=(msa_seed + 1) if msa_seed is not None else None,
)
)
transforms.append(data_transforms_multimer.nearest_neighbor_clusters())
transforms.append(data_transforms_multimer.create_msa_feat)
crop_feats = dict(common_cfg.feat)
if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats)))
if mode_cfg.crop:
transforms.append(
data_transforms_multimer.random_crop_to_size(
crop_size=mode_cfg.crop_size,
max_templates=mode_cfg.max_templates,
shape_schema=crop_feats,
spatial_crop_prob=mode_cfg.spatial_crop_prob,
interface_threshold=mode_cfg.interface_threshold,
subsample_templates=mode_cfg.subsample_templates,
seed=ensemble_seed + 1,
)
)
transforms.append(
data_transforms.make_fixed_size(
shape_schema=crop_feats,
msa_cluster_size=pad_msa_clusters,
extra_msa_size=mode_cfg.max_extra_msa,
num_res=mode_cfg.crop_size,
num_templates=mode_cfg.max_templates,
)
)
else:
transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates)
)
return transforms
def prepare_ground_truth_features(tensors):
"""Prepare ground truth features that are only needed for loss calculation during training"""
gt_features = ['all_atom_mask', 'all_atom_positions', 'asym_id', 'sym_id', 'entity_id']
gt_tensors = {k: v for k, v in tensors.items() if k in gt_features}
gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
gt_tensors = compose(groundtruth_transforms_fns())(gt_tensors)
return gt_tensors
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
process_gt_feats = mode_cfg.supervised
gt_tensors = {}
if process_gt_feats:
gt_tensors = prepare_ground_truth_features(tensors)
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
tensors['aatype'] = tensors['aatype'].to(torch.long)
nonensembled = nonensembled_transform_fns()
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_transform_fns(
common_cfg,
mode_cfg,
ensemble_seed,
)
fn = compose(fns)
d["ensemble_index"] = i
return fn(d)
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
)
if process_gt_feats:
tensors['gt_features'] = gt_tensors
return tensors
@data_transforms.curry1
def compose(x, fs):
for f in fs:
x = f(x)
return x
def map_fn(fun, x):
ensembles = [fun(elem) for elem in x]
features = ensembles[0].keys()
ensembled_dict = {}
for feat in features:
ensembled_dict[feat] = torch.stack(
[dict_i[feat] for dict_i in ensembles], dim=-1
)
return ensembled_dict
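`map_fn` runs the ensembled transforms once per recycling iteration and stacks each feature along a new trailing dimension. A minimal NumPy sketch of that stacking pattern (function name is illustrative):

```python
import numpy as np

def stack_ensembles(fun, xs):
    # Apply fun to each ensemble index, then stack every feature
    # along a new trailing dimension.
    outs = [fun(x) for x in xs]
    return {k: np.stack([o[k] for o in outs], axis=-1) for k in outs[0]}

# Four "recycling iterations" producing a length-3 feature each.
ensembled = stack_ensembles(
    lambda i: {"feat": np.full((3,), float(i))}, range(4)
)
```

The ensemble dimension lands last, so downstream code can index iteration `i` as `feat[..., i]`.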
......@@ -16,6 +16,7 @@
"""Parses the mmCIF file format."""
import collections
import dataclasses
import functools
import io
import json
import logging
......@@ -173,6 +174,7 @@ def mmcif_loop_to_dict(
return {entry[index]: entry for entry in entries}
@functools.lru_cache(16, typed=False)
def parse(
*, file_id: str, mmcif_string: str, catch_all_errors: bool = True
) -> ParsingResult:
......@@ -346,7 +348,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
raw_resolution = parsed_info[res_key][0]
header["resolution"] = float(raw_resolution)
except ValueError:
logging.info(
logging.debug(
"Invalid resolution format: %s", parsed_info[res_key]
)
......@@ -474,6 +476,20 @@ def get_atom_coords(
pos[residue_constants.atom_order["SD"]] = [x, y, z]
mask[residue_constants.atom_order["SD"]] = 1.0
# Fix naming errors in arginine residues where NH2 is incorrectly
# assigned to be closer to CD than NH1
cd = residue_constants.atom_order['CD']
nh1 = residue_constants.atom_order['NH1']
nh2 = residue_constants.atom_order['NH2']
if(
res.get_resname() == 'ARG' and
all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
(np.linalg.norm(pos[nh1] - pos[cd]) >
np.linalg.norm(pos[nh2] - pos[cd]))
):
pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()
all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask
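The arginine fix above swaps NH1/NH2 whenever NH1 ends up farther from CD, since by naming convention NH1 should be the closer guanidinium nitrogen. A tiny NumPy sketch of the same distance check (coordinates are made up):

```python
import numpy as np

# Hypothetical coordinates with the two nitrogens mislabeled.
cd = np.array([0.0, 0.0, 0.0])
nh1 = np.array([3.0, 0.0, 0.0])   # mislabeled: farther from CD
nh2 = np.array([1.5, 0.0, 0.0])

# Swap so that NH1 is always the nitrogen closer to CD.
if np.linalg.norm(nh1 - cd) > np.linalg.norm(nh2 - cd):
    nh1, nh2 = nh2.copy(), nh1.copy()
```

In the real code the positions *and* the atom masks are swapped together, so the two stay consistent.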
......