Commit 15850092 authored by Christina Floristean's avatar Christina Floristean

Added multimer inference to README

parent d7c11537
......@@ -10,10 +10,10 @@ A faithful but trainable PyTorch reproduction of DeepMind's
## Features
OpenFold carefully reproduces (almost) all of the features of the original open
source monomer (v2.0.1) and multimer (v2.3.2) inference code. The sole exception is
model ensembling, which fared poorly in DeepMind's own ablation testing and is being
phased out in future DeepMind experiments. It is omitted here for the sake of reducing
clutter. In cases where the *Nature* paper differs from the source, we always defer to the
latter.
OpenFold is trainable in full precision, half precision, or `bfloat16` with or without DeepSpeed,
......@@ -142,14 +142,14 @@ python3 run_pretrained_openfold.py \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_ptm" \
--model_device "cuda:0" \
--output_dir ./ \
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt
```
......@@ -187,13 +187,6 @@ To enable it, add `--trace_model` to the inference command.
To get a speedup during inference, enable [FlashAttention](https://github.com/HazyResearch/flash-attention)
in the config. Note that it appears to work best for sequences with < 1000 residues.
To minimize memory usage during inference on long sequences, consider the
following changes:
......@@ -232,6 +225,74 @@ efficient AlphaFold-Multimer more than double the time. Use the
at once. The `run_pretrained_openfold.py` script can enable this config option with the
`--long_sequence_inference` command line option.
Input FASTA files containing multiple sequences are treated as complexes. In
this case, the inference script runs AlphaFold-Gap, a hack proposed
[here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using
the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer).
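The gap hack works by concatenating the chains into a single sequence and inserting a large jump in the `residue_index` feature at each chain boundary, so the model's relative-position features treat the chains as disconnected. A minimal sketch of the index construction (the gap size of 200 is the value popularized by the linked hack, used here as an assumption):

```python
def gapped_residue_index(chain_lengths, gap=200):
    # Build a concatenated residue_index with a large jump at each chain
    # boundary; relative-position features then see the chains as far
    # apart, mimicking separate chains with stock monomer weights.
    index = []
    offset = 0
    for length in chain_lengths:
        index.extend(range(offset, offset + length))
        offset += length + gap
    return index
```

For two chains of lengths 2 and 3, this yields indices `[0, 1, 202, 203, 204]`, i.e. a 200-residue "gap" between the chains.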
#### Multimer Inference
To run inference on a complex or multiple complexes using a set of DeepMind's pretrained parameters, run e.g.:
```bash
python3 run_pretrained_openfold.py \
fasta_dir \
data/pdb_mmcif/mmcif_files/ \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
--pdb_seqres_database_path data/pdb_seqres/pdb_seqres.txt \
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
--uniprot_database_path data/uniprot/uniprot.fasta \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hmmsearch_binary_path lib/conda/envs/openfold_venv/bin/hmmsearch \
--hmmbuild_binary_path lib/conda/envs/openfold_venv/bin/hmmbuild \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_multimer_v3" \
--model_device "cuda:0" \
--output_dir ./
```
As with monomer inference, if you've already computed alignments for the query, you can use
the `--use_precomputed_alignments` option. Note that template searching in the multimer pipeline
uses HMMSearch with the PDB SeqRes database, replacing HHSearch and PDB70 used in the monomer pipeline.
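As a sketch of what `--use_precomputed_alignments` expects (the layout mirrors the `local_alignment_dir` logic in `run_pretrained_openfold.py`: one subdirectory of alignment files per query tag):

```python
import os
import tempfile

def has_precomputed_alignments(alignment_dir, tags):
    # The inference script looks for alignments for each query tag
    # under alignment_dir/<tag>/.
    return all(os.path.isdir(os.path.join(alignment_dir, tag)) for tag in tags)

# Toy demonstration with a temporary directory standing in for the
# directory passed to --use_precomputed_alignments:
with tempfile.TemporaryDirectory() as d:
    os.makedirs(os.path.join(d, "tag_A"))
    ready = has_precomputed_alignments(d, ["tag_A"])
    missing_some = has_precomputed_alignments(d, ["tag_A", "tag_B"])
```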
##### Upgrades
The above command requires several upgrades to existing OpenFold installations.
1. Re-download the alphafold parameters to get the latest
AlphaFold-Multimer v3 weights:
```bash
bash scripts/download_alphafold_params.sh openfold/resources
```
2. Download the [UniProt](https://www.uniprot.org/uniprotkb/)
and [PDB SeqRes](https://www.rcsb.org/) databases:
```bash
bash scripts/download_uniprot.sh data/
```
The PDB SeqRes and PDB databases must be from the same date to avoid potential
errors during template searching. Remove the existing `data/pdb_mmcif` directory
and download both databases:
```bash
bash scripts/download_pdb_mmcif.sh data/
bash scripts/download_pdb_seqres.sh data/
```
3. Additionally, AlphaFold-Multimer uses upgraded versions of the [MGnify](https://www.ebi.ac.uk/metagenomics)
and [UniRef30](https://uniclust.mmseqs.com/) (previously UniClust30) databases. To download the upgraded databases, run:
```bash
bash scripts/download_uniref30.sh data/
bash scripts/download_mgnify.sh data/
```
Multimer inference can also run with the older database versions if desired.
#### SoloSeq Inference
To run inference for a sequence using the SoloSeq single-sequence model, you can either precompute ESM-1b embeddings in bulk, or you can generate them during inference.
......
......@@ -714,9 +714,7 @@ config = mlc.ConfigDict(
multimer_config_update = mlc.ConfigDict({
"globals": {
"is_multimer": True
},
"data": {
"common": {
......@@ -766,7 +764,7 @@ multimer_config_update = mlc.ConfigDict({
],
"true_msa": [NUM_MSA_SEQ, NUM_RES]
},
"max_recycling_iters": 20, # For training, value is 3
"unsupervised_features": [
"aatype",
"residue_index",
......@@ -860,7 +858,7 @@ multimer_config_update = mlc.ConfigDict({
"c_out": 22
},
},
"recycle_early_stop_tolerance": 0.5 # For training, value is -1.
},
"loss": {
"fape": {
......
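The `recycle_early_stop_tolerance` above controls early exit from recycling: once the change between consecutive recycling iterations falls below the tolerance, iteration stops (a negative value, as used in training, disables early stopping). A toy scalar sketch of that control flow — the real criterion compares predicted structures between iterations, not a single number:

```python
def recycle(step, x0, max_iters=20, early_stop_tolerance=0.5):
    # Iterate `step` until the change between iterations drops below
    # the tolerance (early stop) or max_iters is reached. A negative
    # tolerance never triggers, i.e. early stopping is disabled.
    x = x0
    for i in range(max_iters):
        x_new = step(x)
        if abs(x_new - x) < early_stop_tolerance:
            return x_new, i + 1
        x = x_new
    return x, max_iters

# Toy stand-in for a recycling iteration: repeated halving converges
# and triggers the early stop well before max_iters.
result, iters = recycle(lambda v: v / 2, 16.0)
```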
......@@ -556,6 +556,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
cache_entry: Any,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
*args, **kwargs
) -> bool:
# Hard filters
resolution = cache_entry.get("resolution", None)
......@@ -569,6 +570,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
@staticmethod
def get_stochastic_train_filter_prob(
cache_entry: Any,
*args, **kwargs
) -> float:
# Stochastic filters
probabilities = []
......@@ -677,9 +679,11 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
@staticmethod
def deterministic_train_filter(
cache_entry: Any,
is_distillation: bool,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
minimum_number_of_residues: int = 200,
*args, **kwargs
) -> bool:
"""
Implement multimer training filtering criteria described in
......@@ -692,12 +696,13 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
max_resolution=max_resolution),
all_seq_len_filter(seqs=seqs,
minimum_number_of_residues=minimum_number_of_residues),
(is_distillation and aa_count_filter(seqs=seqs,
max_single_aa_prop=max_single_aa_prop))])
@staticmethod
def get_stochastic_train_filter_prob(
cache_entry: Any,
*args, **kwargs
) -> list:
# Stochastic filters
cluster_sizes = cache_entry.get("cluster_sizes")
......@@ -710,6 +715,7 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
def looped_samples(self, dataset_idx):
max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
is_distillation = dataset.treat_pdb_as_distillation
idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
mmcif_data_cache = dataset.mmcif_data_cache
while True:
......@@ -719,7 +725,8 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
candidate_idx = next(idx_iter)
mmcif_id = dataset.idx_to_mmcif_id(candidate_idx)
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if not self.deterministic_train_filter(cache_entry=mmcif_data_cache_entry,
is_distillation=is_distillation):
continue
chain_probs = self.get_stochastic_train_filter_prob(
......
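`get_stochastic_train_filter_prob` above derives per-chain acceptance probabilities from cluster sizes. A common recipe for this (an assumption here, not necessarily the exact formula used) is to weight each chain inversely to its sequence-cluster size:

```python
def inverse_cluster_prob(cluster_sizes):
    # Weight each chain inversely to its sequence-cluster size so that
    # heavily redundant folds are sampled less often during training.
    return [1.0 / max(1, size) for size in cluster_sizes]

probs = inverse_cluster_prob([1, 2, 4])
```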
......@@ -17,24 +17,19 @@ import logging
import math
import numpy as np
import os
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
import pickle
import random
import time
import torch
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if (
torch_major_version > 1 or
(torch_major_version == 1 and torch_minor_version >= 12)
):
......@@ -44,20 +39,17 @@ if(
torch.set_grad_enabled(False)
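The split-based version check above works for standard releases, but a slightly more defensive parse tolerates local version segments (e.g. `2.0.1+cu118`) and pre-release suffixes (e.g. `1.13.0a0`). An illustrative sketch, not part of the script:

```python
def parse_major_minor(version):
    # Extract (major, minor) as ints, stripping any "+local" segment
    # and any non-digit suffix within a component.
    parts = version.split("+")[0].split(".")
    nums = []
    for part in parts[:2]:
        digits = ""
        for ch in part:
            if not ch.isdigit():
                break
            digits += ch
        nums.append(int(digits) if digits else 0)
    return tuple(nums)
```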
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.data.tools import hhsearch, hmmsearch
from openfold.np import protein
from openfold.utils.script_utils import (load_models_from_command_line, parse_fasta, run_model,
prep_output, relax_protein)
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.trace_utils import (
pad_feature_dict_seq,
trace_model_,
)
from scripts.precompute_embeddings import EmbeddingGenerator
from scripts.utils import add_data_args
......@@ -65,18 +57,18 @@ from scripts.utils import add_data_args
TRACING_INTERVAL = 50
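`TRACING_INTERVAL` buckets padded sequence lengths so a traced model can be reused across queries of similar length instead of re-tracing per query. A plausible sketch of the rounding helper used later in the script (assuming round-up-to-multiple semantics; the actual implementation is in this file):

```python
import math

TRACING_INTERVAL = 50

def round_up_seqlen(seqlen, interval=TRACING_INTERVAL):
    # Round up to the next multiple of the tracing interval so traced
    # models are shared across queries of similar (padded) length.
    return int(math.ceil(seqlen / interval)) * interval
```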
def precompute_alignments(tags, seqs, alignment_dir, args):
for tag, seq in zip(tags, seqs):
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join(alignment_dir, tag)
if args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir):
logger.info(f"Generating alignments for {tag}...")
os.makedirs(local_alignment_dir)
......@@ -91,6 +83,19 @@ def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
embedding_generator = EmbeddingGenerator()
embedding_generator.run(tmp_fasta_path, alignment_dir)
else:
is_multimer = "multimer" in args.config_preset
if is_multimer:
template_searcher = hmmsearch.Hmmsearch(
binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
database_path=args.pdb_seqres_database_path,
)
else:
template_searcher = hhsearch.HHSearch(
binary_path=args.hhsearch_binary_path,
databases=[args.pdb70_database_path],
)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
......@@ -100,7 +105,9 @@ def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
uniref30_database_path=args.uniref30_database_path,
uniclust30_database_path=args.uniclust30_database_path,
uniprot_database_path=args.uniprot_database_path,
template_searcher=template_searcher,
use_small_bfd=args.bfd_database_path is None,
no_cpus=args.cpus_per_task
)
alignment_runner.run(
......@@ -161,12 +168,13 @@ def generate_feature_dict(
return feature_dict
def list_files_with_extensions(dir, extensions):
return [f for f in os.listdir(dir) if f.endswith(extensions)]
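This helper works because `str.endswith` accepts a tuple of suffixes as well as a single string:

```python
names = ["seq1.fasta", "seq2.fa", "notes.txt"]

# endswith with a tuple matches any of the given suffixes
fasta_like = [n for n in names if n.endswith((".fasta", ".fa"))]
```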
def main(args):
# Create the output directory
os.makedirs(args.output_dir, exist_ok=True)
if args.config_preset.startswith("seq"):
......@@ -174,24 +182,15 @@ def main(args):
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
if args.trace_model:
if not config.data.predict.fixed_size:
raise ValueError(
"Tracing requires that fixed_size mode be enabled in the config"
)
is_multimer = "multimer" in args.config_preset
if is_multimer:
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
......@@ -201,14 +200,6 @@ def main(args):
obsolete_pdbs_path=args.obsolete_pdbs_path
)
else:
template_featurizer = templates.HhsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
......@@ -218,28 +209,11 @@ def main(args):
obsolete_pdbs_path=args.obsolete_pdbs_path
)
data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
if is_multimer:
data_processor = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=data_processor,
)
......@@ -247,7 +221,7 @@ def main(args):
output_dir_base = args.output_dir
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(2 ** 32)
np.random.seed(random_seed)
torch.manual_seed(random_seed + 1)
......@@ -292,6 +266,7 @@ def main(args):
args.openfold_checkpoint_path,
args.jax_param_path,
args.output_dir)
for model, output_directory in model_generator:
cur_tracing_interval = 0
for (tag, tags), seqs in sorted_targets:
......@@ -300,10 +275,10 @@ def main(args):
output_name = f'{output_name}_{args.output_postfix}'
# Does nothing if the alignments have already been computed
precompute_alignments(tags, seqs, alignment_dir, args)
feature_dict = feature_dicts.get(tag, None)
if feature_dict is None:
feature_dict = generate_feature_dict(
tags,
seqs,
......@@ -312,7 +287,7 @@ def main(args):
args,
)
if args.trace_model:
n = feature_dict["aatype"].shape[-2]
rounded_seqlen = round_up_seqlen(n)
feature_dict = pad_feature_dict_seq(
......@@ -326,12 +301,12 @@ def main(args):
)
processed_feature_dict = {
k: torch.as_tensor(v, device=args.model_device)
for k, v in processed_feature_dict.items()
}
if args.trace_model:
if rounded_seqlen > cur_tracing_interval:
logger.info(
f"Tracing model at {rounded_seqlen} residues..."
)
......@@ -380,7 +355,8 @@ def main(args):
if not args.skip_relaxation:
# Relax the prediction.
logger.info(f"Running relaxation on {unrelaxed_output_path}...")
relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name,
args.cif_output)
if args.save_outputs:
output_dict_path = os.path.join(
......@@ -482,13 +458,13 @@ if __name__ == "__main__":
add_data_args(parser)
args = parser.parse_args()
if args.jax_param_path is None and args.openfold_checkpoint_path is None:
args.jax_param_path = os.path.join(
"openfold", "resources", "params",
"params_" + args.config_preset + ".npz"
)
if args.model_device == "cpu" and torch.cuda.is_available():
logging.warning(
"""The model is being run on CPU. Consider specifying
--model_device for better performance"""
......
......@@ -116,13 +116,13 @@ def parse_and_align(files, alignment_runner, args):
def main(args):
# Build the alignment tool runner
if args.hmmsearch_binary_path is not None and args.pdb_seqres_database_path is not None:
template_searcher = hmmsearch.Hmmsearch(
binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
database_path=args.pdb_seqres_database_path,
)
elif args.hhsearch_binary_path is not None and args.pdb70_database_path is not None:
template_searcher = hhsearch.HHSearch(
binary_path=args.hhsearch_binary_path,
databases=[args.pdb70_database_path],
......