Commit 9b60d737 authored by Gustaf Ahdritz
Browse files

Refactor run script (WIP)

parent 71189e20
...@@ -14,18 +14,19 @@ ...@@ -14,18 +14,19 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
from datetime import date
import pickle import pickle
import os import os
# A hack to get OpenMM and PyTorch to peacefully coexist # A hack to get OpenMM and PyTorch to peacefully coexist
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
import random import random
import sys import sys
from openfold.features import templates, feature_pipeline from openfold.features import templates, feature_pipeline
from openfold.features.np import data_pipeline from openfold.features.np import data_pipeline
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
import time import time
import numpy as np import numpy as np
...@@ -66,7 +67,7 @@ def main(args): ...@@ -66,7 +67,7 @@ def main(args):
obsolete_pdbs_path=args.obsolete_pdbs_path obsolete_pdbs_path=args.obsolete_pdbs_path
) )
data_processor = data_pipeline.DataPipeline( alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path, jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path, hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path, hhsearch_binary_path=args.hhsearch_binary_path,
...@@ -76,6 +77,10 @@ def main(args): ...@@ -76,6 +77,10 @@ def main(args):
uniclust30_database_path=args.uniclust30_database_path, uniclust30_database_path=args.uniclust30_database_path,
small_bfd_database_path=args.small_bfd_database_path, small_bfd_database_path=args.small_bfd_database_path,
pdb70_database_path=args.pdb70_database_path, pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
)
data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd use_small_bfd=use_small_bfd
) )
...@@ -88,13 +93,18 @@ def main(args): ...@@ -88,13 +93,18 @@ def main(args):
feature_processor = feature_pipeline.FeaturePipeline(config) feature_processor = feature_pipeline.FeaturePipeline(config)
if not os.path.exists(output_dir_base): if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base) os.makedirs(output_dir_base)
msa_output_dir = os.path.join(output_dir_base, "msas") alignment_dir = os.path.join(output_dir_base, "alignments")
if not os.path.exists(msa_output_dir): if not os.path.exists(alignment_dir):
os.makedirs(msa_output_dir) os.makedirs(alignment_dir)
print("Collecting data...") print("Collecting data...")
feature_dict = data_processor.process( alignment_runner.run_from_fasta(
input_fasta_path=args.fasta_path, msa_output_dir=msa_output_dir) args.fasta_path, alignment_dir
)
feature_dict = data_processor.process_fasta(
input_fasta_path=args.fasta_path, alignment_dir=alignment_dir
)
print("Generating features...") print("Generating features...")
processed_feature_dict = feature_processor.process_features( processed_feature_dict = feature_processor.process_features(
...@@ -170,67 +180,68 @@ def main(args): ...@@ -170,67 +180,68 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--fasta_path", type=str, "fasta_path", type=str,
) )
parser.add_argument( parser.add_argument(
"--output_dir", type=str, default=os.getcwd(), 'uniref90_database_path', type=str,
help="""Name of the directory in which to output the prediction""",
required=True
) )
parser.add_argument( parser.add_argument(
"--device", type=str, default="cpu", 'mgnify_database_path', type=str,
help="""Name of the device on which to run the model. Any valid torch
device name is accepted (e.g. "cpu", "cuda:0")"""
) )
parser.add_argument( parser.add_argument(
"--model_name", type=str, default="model_1", 'pdb70_database_path', type=str,
help="""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
) )
parser.add_argument( parser.add_argument(
"--param_path", type=str, default=None, 'template_mmcif_dir', type=str,
help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from
openfold/resources/params"""
) )
parser.add_argument( parser.add_argument(
'--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer' '--bfd_database_path', type=str, default=None,
) )
parser.add_argument( parser.add_argument(
'--hhblits_binary_path', type=str, default='/usr/bin/hhblits' '--small_bfd_database_path', type=str, default=None
) )
parser.add_argument( parser.add_argument(
'--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch' '--uniclust30_database_path', type=str, default=None
) )
parser.add_argument( parser.add_argument(
'--kalign_binary_path', type=str, default='/usr/bin/kalign' '--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer'
) )
parser.add_argument( parser.add_argument(
'--uniref90_database_path', type=str, '--hhblits_binary_path', type=str, default='/usr/bin/hhblits'
) )
parser.add_argument( parser.add_argument(
'--mgnify_database_path', type=str, '--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch'
) )
parser.add_argument( parser.add_argument(
'--bfd_database_path', type=str, '--kalign_binary_path', type=str, default='/usr/bin/kalign'
) )
parser.add_argument( parser.add_argument(
'--small_bfd_database_path', type=str, default=None '--max_template_date', type=str,
default=date.today().strftime("%Y-%m-%d"),
) )
parser.add_argument( parser.add_argument(
'--uniclust30_database_path', type=str, default=None '--obsolete_pdbs_path', type=str, default=None
) )
parser.add_argument( parser.add_argument(
'--pdb70_database_path', type=str, "--output_dir", type=str, default=os.getcwd(),
help="""Name of the directory in which to output the prediction""",
required=True
) )
parser.add_argument( parser.add_argument(
'--template_mmcif_dir', type=str, "--device", type=str, default="cpu",
help="""Name of the device on which to run the model. Any valid torch
device name is accepted (e.g. "cpu", "cuda:0")"""
) )
parser.add_argument( parser.add_argument(
'--max_template_date', type=str, "--model_name", type=str, default="model_1",
help="""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
) )
parser.add_argument( parser.add_argument(
'--obsolete_pdbs_path', type=str, default=None "--param_path", type=str, default=None,
help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from
openfold/resources/params"""
) )
parser.add_argument( parser.add_argument(
'--preset', type=str, default='full_dbs', '--preset', type=str, default='full_dbs',
...@@ -248,4 +259,11 @@ if __name__ == "__main__": ...@@ -248,4 +259,11 @@ if __name__ == "__main__":
"params_" + args.model_name + ".npz" "params_" + args.model_name + ".npz"
) )
if(args.bfd_database_path is None and
args.small_bfd_database_path is None):
raise ValueError(
"At least one of --bfd_database_path or --small_bfd_database_path"
" must be specified"
)
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment