# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run a pretrained OpenFold model on a single FASTA file.

Steps: run the alignment tools on the input sequence, featurize the results
(including template hits), run AlphaFold inference, relax the prediction with
``relax.AmberRelaxation`` and write the relaxed structure to
``<output_dir>/relaxed_<model_name>.pdb``.
"""

import argparse
from datetime import date
import logging
import numpy as np
import os

# A hack to get OpenMM and PyTorch to peacefully coexist.
# NOTE: this must stay above the imports below so it is set before they load.
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"

import pickle
import random
import sys
import time
import torch

from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.model.model import AlphaFold
from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax
from openfold.utils.import_weights import (
    import_jax_weights_,
)
from openfold.utils.tensor_utils import (
    tensor_tree_map,
)

from scripts.utils import add_data_args


def main(args):
    """Predict, relax and save the structure of the sequence in ``args.fasta_path``.

    Args:
        args: Parsed command-line namespace. Besides the options declared in
            the ``__main__`` block of this script, it must carry the
            database/binary paths read below (e.g. ``template_mmcif_dir``,
            ``jackhmmer_binary_path``), which are presumably registered by
            ``scripts.utils.add_data_args`` -- those definitions are not
            visible here.

    Side effects:
        Creates ``args.output_dir`` and ``args.output_dir/alignments`` if they
        are missing, runs the alignment tools with ``alignments`` as their
        output directory, and writes ``relaxed_<model_name>.pdb`` into
        ``args.output_dir``.
    """
    config = model_config(args.model_name)

    # Build the network in inference mode, load the converted JAX weights,
    # then move it to the requested device.
    model = AlphaFold(config.model)
    model = model.eval()
    import_jax_weights_(model, args.param_path)
    model = model.to(args.model_device)

    # FEATURE COLLECTION AND PROCESSING
    num_ensemble = 1

    template_featurizer = templates.TemplateHitFeaturizer(
        mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
        max_hits=args.max_template_hits,
        kalign_binary_path=args.kalign_binary_path,
        release_dates_path=None,
        obsolete_pdbs_path=args.obsolete_pdbs_path,
    )

    # Use the reduced ("small") BFD whenever the full BFD path is absent.
    # NOTE(review): the --preset flag is parsed but not consulted here.
    use_small_bfd = args.bfd_database_path is None

    alignment_runner = data_pipeline.AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        hhsearch_binary_path=args.hhsearch_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        small_bfd_database_path=args.small_bfd_database_path,
        pdb70_database_path=args.pdb70_database_path,
        use_small_bfd=use_small_bfd,
        no_cpus=args.cpus,
    )

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
        use_small_bfd=use_small_bfd,
    )

    output_dir_base = args.output_dir

    # The seed used to be drawn and then discarded, so --data_random_seed had
    # no effect. Seed torch's global RNG with it and log it so a run can be
    # reproduced. (Assumes the stochastic feature-processing steps draw from
    # torch's global RNG -- TODO confirm against feature_pipeline.)
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(sys.maxsize)
    torch.manual_seed(random_seed)
    logging.info(f"Data random seed: {random_seed}")

    config.data.predict.num_ensemble = num_ensemble
    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    # exist_ok=True replaces the racy "exists() then makedirs()" pattern.
    os.makedirs(output_dir_base, exist_ok=True)
    alignment_dir = os.path.join(output_dir_base, "alignments")
    os.makedirs(alignment_dir, exist_ok=True)

    logging.info("Generating features...")
    alignment_runner.run(
        args.fasta_path, alignment_dir
    )

    feature_dict = data_processor.process_fasta(
        fasta_path=args.fasta_path, alignment_dir=alignment_dir
    )

    processed_feature_dict = feature_processor.process_features(
        feature_dict, mode='predict',
    )

    logging.info("Executing model...")
    batch = processed_feature_dict
    with torch.no_grad():
        batch = {
            k: torch.as_tensor(v, device=args.model_device)
            for k, v in batch.items()
        }

        t = time.time()
        out = model(batch)
        logging.info(f"Inference time: {time.time() - t}")

    # Toss out the recycling dimensions --- we don't need them anymore
    batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
    out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

    plddt = out["plddt"]
    mean_plddt = np.mean(plddt)
    logging.info(f"Mean pLDDT: {mean_plddt}")

    # Repeat the per-position pLDDT along a new trailing axis of size
    # atom_type_num; the result is handed to protein.from_prediction as
    # b_factors.
    plddt_b_factors = np.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    unrelaxed_protein = protein.from_prediction(
        features=batch,
        result=out,
        b_factors=plddt_b_factors,
    )

    amber_relaxer = relax.AmberRelaxation(
        **config.relax
    )

    # Relax the prediction.
    t = time.time()
    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
    logging.info(f"Relaxation time: {time.time() - t}")

    # Save the relaxed PDB.
    relaxed_output_path = os.path.join(
        args.output_dir, f'relaxed_{args.model_name}.pdb'
    )
    with open(relaxed_output_path, 'w') as f:
        f.write(relaxed_pdb_str)


if __name__ == "__main__":
    # Without this, every logging.info(...) progress/timing message in main()
    # is dropped, because the root logger defaults to WARNING.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_path", type=str,
    )
    add_data_args(parser)
    # Previously also marked required=True, which made the default unusable.
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help=(
            "Name of the device on which to run the model. Any valid torch "
            'device name is accepted (e.g. "cpu", "cuda:0")'
        ),
    )
    parser.add_argument(
        "--model_name", type=str, default="model_1",
        help=(
            "Name of a model config. Choose one of model_{1-5} or "
            "model_{1-5}_ptm, as defined on the AlphaFold GitHub."
        ),
    )
    parser.add_argument(
        "--param_path", type=str, default=None,
        help=(
            "Path to model parameters. If None, parameters are selected "
            "automatically according to the model name from "
            "openfold/resources/params"
        ),
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="""Number of CPUs to use to run alignment tools""",
    )
    parser.add_argument(
        '--preset', type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs'),
    )
    # Parsed as int (was str) so it can actually be used to seed torch.
    parser.add_argument(
        '--data_random_seed', type=int, default=None,
        help="Seed for torch's global RNG; drawn at random if omitted",
    )

    args = parser.parse_args()

    if args.param_path is None:
        args.param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.model_name + ".npz",
        )

    if (
        args.bfd_database_path is None
        and args.small_bfd_database_path is None
    ):
        raise ValueError(
            "At least one of --bfd_database_path or --small_bfd_database_path "
            "must be specified"
        )

    main(args)