"lib/bindings/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "c4213899fedce25f59262cfdc944d0dfa0feecea"
Commit 9337ed13 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Integrated the feature processing pipeline in the 'main' runner script.

parent c8e54318
...@@ -19,6 +19,11 @@ import pickle ...@@ -19,6 +19,11 @@ import pickle
import os import os
# A hack to get OpenMM and PyTorch to peacefully coexist # A hack to get OpenMM and PyTorch to peacefully coexist
import random
import sys
from openfold.features import data_pipeline, templates, feature_pipeline
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL" os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
import time import time
...@@ -41,6 +46,8 @@ from openfold.utils.tensor_utils import ( ...@@ -41,6 +46,8 @@ from openfold.utils.tensor_utils import (
FEAT_PATH = "tests/test_data/sample_feats.pickle" FEAT_PATH = "tests/test_data/sample_feats.pickle"
MAX_TEMPLATE_HITS = 20
def main(args): def main(args):
config = model_config(args.model_name) config = model_config(args.model_name)
model = AlphaFold(config.model) model = AlphaFold(config.model)
...@@ -48,9 +55,60 @@ def main(args): ...@@ -48,9 +55,60 @@ def main(args):
import_jax_weights_(model, args.param_path) import_jax_weights_(model, args.param_path)
model = model.to(args.device) model = model.to(args.device)
with open(FEAT_PATH, "rb") as f: # FEATURE COLLECTION AND PROCESSING
batch = pickle.load(f) use_small_bfd = args.preset == "reduced_dbs"
num_ensemble = 1
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=MAX_TEMPLATE_HITS,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=None,
obsolete_pdbs_path=args.obsolete_pdbs_path)
data_processor = data_pipeline.DataPipeline(
jackhammer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
small_bfd_database_path=args.small_bfd_database_path,
pdb70_database_path=args.pdb70_database_path,
template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd
)
output_dir_base = args.output_dir
random_seed = args.random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
config.data.eval.num_ensemble = num_ensemble
feature_processor = feature_pipeline.FeaturePipeline(config)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
msa_output_dir = os.path.join(output_dir_base, "msas")
if not os.path.exists(msa_output_dir):
os.makedirs(msa_output_dir)
print("Collecting data...")
feature_dict = data_processor.process(
input_fasta_path=args.fasta_path, msa_output_dir=msa_output_dir)
# Output the features
features_output_path = os.path.join(output_dir_base, 'features.pkl')
with open(features_output_path, 'wb') as f:
pickle.dump(feature_dict, f, protocol=4)
print("Generating features...")
processed_feature_dict = feature_processor.process_features(feature_dict, random_seed)
with open(os.path.join(output_dir_base, 'processed_feats.pkl'), 'wb') as f:
pickle.dump(processed_feature_dict, f, protocol=4)
print("Executing model...")
batch = processed_feature_dict
with torch.no_grad(): with torch.no_grad():
batch = { batch = {
k:torch.as_tensor(v, device=args.device) k:torch.as_tensor(v, device=args.device)
...@@ -117,6 +175,14 @@ def main(args): ...@@ -117,6 +175,14 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"--fasta_path", type=str, default=None, required=True
)
parser.add_argument(
"--output_dir", type=str, default=os.getcwd(),
help="""Name of the directory in which to output the prediction""",
required=True
)
parser.add_argument( parser.add_argument(
"--device", type=str, default="cpu", "--device", type=str, default="cpu",
help="""Name of the device on which to run the model. Any valid torch help="""Name of the device on which to run the model. Any valid torch
...@@ -127,16 +193,57 @@ if __name__ == "__main__": ...@@ -127,16 +193,57 @@ if __name__ == "__main__":
help="""Name of a model config. Choose one of model_{1-5} or help="""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub.""" model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
) )
parser.add_argument(
"--output_dir", type=str, default=os.getcwd(),
help="""Name of the directory in which to output the prediction"""
)
parser.add_argument( parser.add_argument(
"--param_path", type=str, default=None, "--param_path", type=str, default=None,
help="""Path to model parameters. If None, parameters are selected help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from automatically according to the model name from
openfold/resources/params""" openfold/resources/params"""
) )
parser.add_argument(
'--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer'
)
parser.add_argument(
'--hhblits_binary_path', type=str, default='/usr/bin/hhblits'
)
parser.add_argument(
'--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch'
)
parser.add_argument(
'--kalign_binary_path', type=str, default='/usr/bin/kalign'
)
parser.add_argument('--uniref90_database_path', type=str, default=None, required=True
)
parser.add_argument(
'--mgnify_database_path', type=str, default=None, required=True
)
parser.add_argument(
'--bfd_database_path', type=str, default=None, required=True
)
parser.add_argument(
'--small_bfd_database_path', type=str, default=None
)
parser.add_argument(
'--uniclust30_database_path', type=str, default=None
)
parser.add_argument(
'--pdb70_database_path', type=str, default=None, required=True
)
parser.add_argument(
'--template_mmcif_dir', type=str, default=None, required=True
)
parser.add_argument(
'--max_template_date', type=str, default=None, required=True
)
parser.add_argument(
'--obsolete_pdbs_path', type=str, default=None
)
parser.add_argument(
'--preset', type=str, default='full_dbs', required=True,
choices=('reduced_dbs', 'full_dbs')
)
parser.add_argument(
'--random_seed', type=str, default=None
)
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment