Commit 15850092 authored by Christina Floristean

Added multimer inference to README

parent d7c11537
@@ -10,10 +10,10 @@ A faithful but trainable PyTorch reproduction of DeepMind's
## Features

OpenFold carefully reproduces (almost) all of the features of the original open
source monomer (v2.0.1) and multimer (v2.3.2) inference code. The sole exception is
model ensembling, which fared poorly in DeepMind's own ablation testing and is being
phased out in future DeepMind experiments. It is omitted here for the sake of reducing
clutter. In cases where the *Nature* paper differs from the source, we always defer to the
latter.

OpenFold is trainable in full precision, half precision, or `bfloat16` with or without DeepSpeed,
@@ -142,14 +142,14 @@ python3 run_pretrained_openfold.py \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_ptm" \
--model_device "cuda:0" \
--output_dir ./ \
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt
```
@@ -187,13 +187,6 @@ To enable it, add `--trace_model` to the inference command.
To get a speedup during inference, enable [FlashAttention](https://github.com/HazyResearch/flash-attention)
in the config. Note that it appears to work best for sequences with < 1000 residues.
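A minimal sketch of toggling it programmatically (the exact flag name under
`globals` is an assumption here; confirm it against `openfold/config.py` in
your checkout):

```python
from openfold.config import model_config

# Build an inference config, then opt into FlashAttention.
# NOTE: `use_flash` is an assumed flag name; verify it in openfold/config.py.
config = model_config("model_1_ptm")
config.globals.use_flash = True
```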
To minimize memory usage during inference on long sequences, consider the
following changes:
@@ -232,6 +225,74 @@ efficent AlphaFold-Multimer more than double the time. Use the
at once. The `run_pretrained_openfold.py` script can enable this config option with the
`--long_sequence_inference` command line option.
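Programmatically, the same settings can be applied when building a config,
mirroring what `run_pretrained_openfold.py` does with that flag:

```python
from openfold.config import model_config

# Apply the long-sequence inference settings to an inference preset,
# as run_pretrained_openfold.py does when --long_sequence_inference is set.
config = model_config("model_1_ptm", long_sequence_inference=True)
```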
Input FASTA files containing multiple sequences are treated as complexes. In
this case, the inference script runs AlphaFold-Gap, a hack proposed
[here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using
the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer).
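For instance, a single FASTA file with two records is folded as one two-chain
complex (a hypothetical input; the sequences below are arbitrary placeholders):

```python
# Both records in one FASTA file are interpreted as chains of a single complex.
complex_fasta = (
    ">chain_A\n"
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQA\n"
    ">chain_B\n"
    "MLEDKFGRSSLVQQLNNAAGDDWRTYAQHKLPGV\n"
)
with open("fasta_dir/complex.fasta", "w") as f:
    f.write(complex_fasta)
```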
#### Multimer Inference
To run inference on a complex or multiple complexes using a set of DeepMind's pretrained parameters, run e.g.:
```bash
python3 run_pretrained_openfold.py \
fasta_dir \
data/pdb_mmcif/mmcif_files/ \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
--pdb_seqres_database_path data/pdb_seqres/pdb_seqres.txt \
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
--uniprot_database_path data/uniprot/uniprot.fasta \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hmmsearch_binary_path lib/conda/envs/openfold_venv/bin/hmmsearch \
--hmmbuild_binary_path lib/conda/envs/openfold_venv/bin/hmmbuild \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_multimer_v3" \
--model_device "cuda:0" \
--output_dir ./
```
As with monomer inference, if you've already computed alignments for the query, you can use
the `--use_precomputed_alignments` option. Note that template searching in the multimer pipeline
uses HMMSearch with the PDB SeqRes database, replacing HHSearch and PDB70 used in the monomer pipeline.
##### Upgrades
The above command requires several upgrades to existing OpenFold installations.
1. Re-download the AlphaFold parameters to get the latest
AlphaFold-Multimer v3 weights:
```bash
bash scripts/download_alphafold_params.sh openfold/resources
```
2. Download the [UniProt](https://www.uniprot.org/uniprotkb/)
and [PDB SeqRes](https://www.rcsb.org/) databases:
```bash
bash scripts/download_uniprot.sh data/
```
The PDB SeqRes and PDB databases must be from the same date to avoid potential
errors during template searching. Remove the existing `data/pdb_mmcif` directory
and download both databases:
```bash
bash scripts/download_pdb_mmcif.sh data/
bash scripts/download_pdb_seqres.sh data/
```
3. Additionally, AlphaFold-Multimer uses upgraded versions of the [MGnify](https://www.ebi.ac.uk/metagenomics)
and [UniRef30](https://uniclust.mmseqs.com/) (previously UniClust30) databases. To download the upgraded databases, run:
```bash
bash scripts/download_uniref30.sh data/
bash scripts/download_mgnify.sh data/
```
Multimer inference can also run with the older database versions if desired.
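Before launching a long multimer run, it can be worth sanity-checking that the
databases referenced in the command above are in place. A minimal sketch (paths
mirror the example command; `UniRef30_2021_03` is a shared file prefix rather
than a single file, so it is matched by glob, an assumption):

```python
import glob
import os

# Databases used by the multimer inference command above.
required = [
    "data/pdb_seqres/pdb_seqres.txt",
    "data/uniprot/uniprot.fasta",
    "data/mgnify/mgy_clusters_2022_05.fa",
]
missing = [p for p in required if not os.path.exists(p)]
# UniRef30 is a multi-file database sharing one prefix.
if not glob.glob("data/uniref30/UniRef30_2021_03*"):
    missing.append("data/uniref30/UniRef30_2021_03*")
if missing:
    print("Missing databases:\n  " + "\n  ".join(missing))
```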
#### SoloSeq Inference

To run inference for a sequence using the SoloSeq single-sequence model, you can either precompute ESM-1b embeddings in bulk, or you can generate them during inference.
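For bulk precomputation, the same `EmbeddingGenerator` the inference script
uses internally can be driven directly (file and directory names below are
illustrative):

```python
from scripts.precompute_embeddings import EmbeddingGenerator

# Generate ESM-1b embeddings for every record in a FASTA file, writing them
# under the given directory -- the same run() call run_pretrained_openfold.py
# makes per sequence during inference.
embedding_generator = EmbeddingGenerator()
embedding_generator.run("my_sequences.fasta", "embeddings_dir/")
```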
......
@@ -714,9 +714,7 @@ config = mlc.ConfigDict(
multimer_config_update = mlc.ConfigDict({
    "globals": {
        "is_multimer": True
    },
    "data": {
        "common": {
@@ -766,7 +764,7 @@ multimer_config_update = mlc.ConfigDict({
                ],
                "true_msa": [NUM_MSA_SEQ, NUM_RES]
            },
            "max_recycling_iters": 20,  # For training, value is 3
            "unsupervised_features": [
                "aatype",
                "residue_index",
@@ -860,7 +858,7 @@ multimer_config_update = mlc.ConfigDict({
                "c_out": 22
            },
        },
        "recycle_early_stop_tolerance": 0.5  # For training, value is -1.
    },
    "loss": {
        "fape": {
......
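For context, `max_recycling_iters` caps the number of recycling passes at
inference time, and `recycle_early_stop_tolerance` ends recycling early once
consecutive passes agree closely enough; the negative training value disables
the early exit. A rough sketch of the idea, with a hypothetical
`run_one_recycle` callable standing in for the model's actual forward pass:

```python
import torch

def recycle_until_converged(run_one_recycle, feats, max_recycling_iters=20,
                            recycle_early_stop_tolerance=0.5):
    """Run recycling passes, exiting early once predicted CA positions move
    less than the tolerance between consecutive passes (a negative tolerance
    never triggers the early exit)."""
    prev_ca = None
    ca = None
    for _ in range(max_recycling_iters + 1):
        ca = run_one_recycle(feats, prev_ca)  # hypothetical; returns [N_res, 3]
        if prev_ca is not None:
            # Mean per-residue displacement between consecutive passes
            delta = (ca - prev_ca).norm(dim=-1).mean()
            if delta < recycle_early_stop_tolerance:
                break
        prev_ca = ca
    return ca
```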
@@ -556,6 +556,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
        cache_entry: Any,
        max_resolution: float = 9.,
        max_single_aa_prop: float = 0.8,
        *args, **kwargs
    ) -> bool:
        # Hard filters
        resolution = cache_entry.get("resolution", None)
@@ -569,6 +570,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
    @staticmethod
    def get_stochastic_train_filter_prob(
        cache_entry: Any,
        *args, **kwargs
    ) -> float:
        # Stochastic filters
        probabilities = []
@@ -677,9 +679,11 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
    @staticmethod
    def deterministic_train_filter(
        cache_entry: Any,
        is_distillation: bool,
        max_resolution: float = 9.,
        max_single_aa_prop: float = 0.8,
        minimum_number_of_residues: int = 200,
        *args, **kwargs
    ) -> bool:
        """
        Implement multimer training filtering criteria described in
@@ -692,12 +696,13 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
                max_resolution=max_resolution),
            all_seq_len_filter(seqs=seqs,
                minimum_number_of_residues=minimum_number_of_residues),
            (is_distillation and aa_count_filter(seqs=seqs,
                max_single_aa_prop=max_single_aa_prop))])
    @staticmethod
    def get_stochastic_train_filter_prob(
        cache_entry: Any,
        *args, **kwargs
    ) -> list:
        # Stochastic filters
        cluster_sizes = cache_entry.get("cluster_sizes")
@@ -710,6 +715,7 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
    def looped_samples(self, dataset_idx):
        max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
        dataset = self.datasets[dataset_idx]
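        # Whether this dataset's entries are treated as distillation data;
        # forwarded to the deterministic train filter below.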
        is_distillation = dataset.treat_pdb_as_distillation
        idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
        mmcif_data_cache = dataset.mmcif_data_cache
        while True:
@@ -719,7 +725,8 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
            candidate_idx = next(idx_iter)
            mmcif_id = dataset.idx_to_mmcif_id(candidate_idx)
            mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
            if not self.deterministic_train_filter(cache_entry=mmcif_data_cache_entry,
                                                   is_distillation=is_distillation):
                continue
            chain_probs = self.get_stochastic_train_filter_prob(
......
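The `*args, **kwargs` widening in the filter signatures above lets the monomer
and multimer filters be invoked with one uniform keyword set even though only
the multimer variants take `is_distillation`, as `looped_samples` now does. A
self-contained sketch of the pattern (illustrative class and field names, not
OpenFold's own):

```python
from typing import Any, Dict

class Base:
    @staticmethod
    def train_filter(cache_entry: Dict[str, Any], max_resolution: float = 9.,
                     *args, **kwargs) -> bool:
        # Extra keywords (e.g. is_distillation) land in **kwargs and are ignored.
        return cache_entry.get("resolution", 10.) <= max_resolution

class Multimer(Base):
    @staticmethod
    def train_filter(cache_entry: Dict[str, Any], is_distillation: bool = False,
                     max_resolution: float = 9., *args, **kwargs) -> bool:
        keep = Base.train_filter(cache_entry, max_resolution)
        # Multimer-only criteria gated on is_distillation would be added here.
        return keep

# Both variants accept the same keyword call, as looped_samples does above:
for cls in (Base, Multimer):
    assert cls.train_filter(cache_entry={"resolution": 3.1}, is_distillation=True)
```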
@@ -17,24 +17,19 @@ import logging
import math
import numpy as np
import os
import pickle
import random
import time

logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

import torch
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if (
    torch_major_version > 1 or
    (torch_major_version == 1 and torch_minor_version >= 12)
):
@@ -44,20 +39,17 @@ if(
torch.set_grad_enabled(False)

from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.data.tools import hhsearch, hmmsearch
from openfold.np import protein
from openfold.utils.script_utils import (load_models_from_command_line, parse_fasta, run_model,
                                         prep_output, relax_protein)
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.trace_utils import (
    pad_feature_dict_seq,
    trace_model_,
)

from scripts.precompute_embeddings import EmbeddingGenerator
from scripts.utils import add_data_args
@@ -65,18 +57,18 @@ from scripts.utils import add_data_args
TRACING_INTERVAL = 50
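# Traced-model inputs are padded up to multiples of this interval (see
# round_up_seqlen), so one traced graph can be reused for nearby lengths.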
def precompute_alignments(tags, seqs, alignment_dir, args):
    for tag, seq in zip(tags, seqs):
        tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
        with open(tmp_fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")

        local_alignment_dir = os.path.join(alignment_dir, tag)

        if args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir):
            logger.info(f"Generating alignments for {tag}...")

            os.makedirs(local_alignment_dir)
@@ -91,6 +83,19 @@ def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
                embedding_generator = EmbeddingGenerator()
                embedding_generator.run(tmp_fasta_path, alignment_dir)
            else:
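                # Multimer presets search templates with hmmsearch against the
                # PDB SeqRes database; monomer presets keep hhsearch with PDB70.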
                is_multimer = "multimer" in args.config_preset
                if is_multimer:
                    template_searcher = hmmsearch.Hmmsearch(
                        binary_path=args.hmmsearch_binary_path,
                        hmmbuild_binary_path=args.hmmbuild_binary_path,
                        database_path=args.pdb_seqres_database_path,
                    )
                else:
                    template_searcher = hhsearch.HHSearch(
                        binary_path=args.hhsearch_binary_path,
                        databases=[args.pdb70_database_path],
                    )
                alignment_runner = data_pipeline.AlignmentRunner(
                    jackhmmer_binary_path=args.jackhmmer_binary_path,
                    hhblits_binary_path=args.hhblits_binary_path,
@@ -100,7 +105,9 @@ def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
                    uniref30_database_path=args.uniref30_database_path,
                    uniclust30_database_path=args.uniclust30_database_path,
                    uniprot_database_path=args.uniprot_database_path,
                    template_searcher=template_searcher,
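                    # The absence of a full BFD path implies the small BFD
                    # database, which is searched with jackhmmer alone.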
                    use_small_bfd=args.bfd_database_path is None,
                    no_cpus=args.cpus_per_task
                )

                alignment_runner.run(
@@ -161,12 +168,13 @@ def generate_feature_dict(
    return feature_dict


def list_files_with_extensions(dir, extensions):
    return [f for f in os.listdir(dir) if f.endswith(extensions)]


def main(args):
    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    if args.config_preset.startswith("seq"):
@@ -174,24 +182,15 @@ def main(args):
    config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)

    if args.trace_model:
        if not config.data.predict.fixed_size:
            raise ValueError(
                "Tracing requires that fixed_size mode be enabled in the config"
            )

    is_multimer = "multimer" in args.config_preset

    if is_multimer:
        template_featurizer = templates.HmmsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
@@ -201,14 +200,6 @@ obsolete_pdbs_path=args.obsolete_pdbs_path
            obsolete_pdbs_path=args.obsolete_pdbs_path
        )
    else:
        template_featurizer = templates.HhsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
@@ -218,28 +209,11 @@ obsolete_pdbs_path=args.obsolete_pdbs_path
            obsolete_pdbs_path=args.obsolete_pdbs_path
        )
    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )

    if is_multimer:
        data_processor = data_pipeline.DataPipelineMultimer(
            monomer_data_pipeline=data_processor,
        )
@@ -247,7 +221,7 @@ def main(args):
    output_dir_base = args.output_dir
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2 ** 32)

    np.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)
@@ -292,6 +266,7 @@ def main(args):
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir)

    for model, output_directory in model_generator:
        cur_tracing_interval = 0
        for (tag, tags), seqs in sorted_targets:
@@ -300,10 +275,10 @@ def main(args):
                output_name = f'{output_name}_{args.output_postfix}'

            # Does nothing if the alignments have already been computed
            precompute_alignments(tags, seqs, alignment_dir, args)

            feature_dict = feature_dicts.get(tag, None)
            if feature_dict is None:
                feature_dict = generate_feature_dict(
                    tags,
                    seqs,
@@ -312,7 +287,7 @@ def main(args):
                    args,
                )

                if args.trace_model:
                    n = feature_dict["aatype"].shape[-2]
                    rounded_seqlen = round_up_seqlen(n)
                    feature_dict = pad_feature_dict_seq(
@@ -326,12 +301,12 @@ def main(args):
            )

            processed_feature_dict = {
                k: torch.as_tensor(v, device=args.model_device)
                for k, v in processed_feature_dict.items()
            }

            if args.trace_model:
                if rounded_seqlen > cur_tracing_interval:
                    logger.info(
                        f"Tracing model at {rounded_seqlen} residues..."
                    )
@@ -380,7 +355,8 @@ def main(args):
            if not args.skip_relaxation:
                # Relax the prediction.
                logger.info(f"Running relaxation on {unrelaxed_output_path}...")
                relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name,
                              args.cif_output)

            if args.save_outputs:
                output_dict_path = os.path.join(
@@ -482,13 +458,13 @@ if __name__ == "__main__":
    add_data_args(parser)
    args = parser.parse_args()

    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logging.warning(
            """The model is being run on CPU. Consider specifying
            --model_device for better performance"""
        )
......
@@ -116,13 +116,13 @@ def parse_and_align(files, alignment_runner, args):
def main(args):
    # Build the alignment tool runner
    if args.hmmsearch_binary_path is not None and args.pdb_seqres_database_path is not None:
        template_searcher = hmmsearch.Hmmsearch(
            binary_path=args.hmmsearch_binary_path,
            hmmbuild_binary_path=args.hmmbuild_binary_path,
            database_path=args.pdb_seqres_database_path,
        )
    elif args.hhsearch_binary_path is not None and args.pdb70_database_path is not None:
        template_searcher = hhsearch.HHSearch(
            binary_path=args.hhsearch_binary_path,
            databases=[args.pdb70_database_path],
        )
......