"docs/vscode:/vscode.git/clone" did not exist on "b81f709fb6c7f21ba06f6d79fc17705174d2e024"
Unverified commit d3df8e69, authored by Fazzie-Maqianli, committed by GitHub

add multimer inference (#59)

* add multimer inference

* add data pipeline
parent 444c548a
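
For orientation, here is a minimal end-to-end sketch of the multimer data path this commit introduces. Class and method names are taken from the diff below; the paths, the `None` template featurizer, and the omitted config handling are illustrative assumptions, not project defaults.

```python
from fastfold.data import data_pipeline

# Per-chain features come from the monomer DataPipeline; DataPipelineMultimer
# then assembles and pairs them for the whole complex.
monomer_pipeline = data_pipeline.DataPipeline(
    template_featurizer=None,  # assumption: templates disabled for brevity
)
multimer_pipeline = data_pipeline.DataPipelineMultimer(
    monomer_data_pipeline=monomer_pipeline,
)
feature_dict = multimer_pipeline.process_fasta(
    fasta_path="complex.fasta",   # multi-chain FASTA (placeholder)
    alignment_dir="alignments",   # one subdirectory per chain tag
)
```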
@@ -19,6 +19,7 @@ import contextlib
 import dataclasses
 import datetime
 import json
+import copy
 from multiprocessing import cpu_count
 import tempfile
 from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
@@ -56,7 +57,7 @@ def empty_template_feats(n_res) -> FeatureDict:
 def make_template_features(
     input_sequence: str,
     hits: Sequence[Any],
-    template_featurizer: Any,
+    template_featurizer: Union[hhsearch.HHSearch, hmmsearch.Hmmsearch],
     query_pdb_code: Optional[str] = None,
     query_release_date: Optional[str] = None,
 ) -> FeatureDict:
@@ -64,12 +65,18 @@ def make_template_features(
     if(len(hits_cat) == 0 or template_featurizer is None):
         template_features = empty_template_feats(len(input_sequence))
     else:
+        if type(template_featurizer) == hhsearch.HHSearch:
             templates_result = template_featurizer.get_templates(
                 query_sequence=input_sequence,
                 query_pdb_code=query_pdb_code,
                 query_release_date=query_release_date,
                 hits=hits_cat,
             )
+        else:
+            templates_result = template_featurizer.get_templates(
+                query_sequence=input_sequence,
+                hits=hits_cat,
+            )
         template_features = templates_result.features

     # The template featurizer doesn't format empty template features
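
The branch above dispatches on the featurizer type because the two paths expose different `get_templates` signatures: the HHSearch-backed monomer path accepts a PDB code and release date, while the hmmsearch-backed multimer path takes only the sequence and hits. A hedged standalone sketch of the same dispatch (the helper name and import path are assumptions, not part of the repo):

```python
from fastfold.data.tools import hhsearch

def featurize_templates(featurizer, input_sequence, hits_cat,
                        query_pdb_code=None, query_release_date=None):
    # Hypothetical helper mirroring the diff's branch.
    if type(featurizer) == hhsearch.HHSearch:
        # Monomer-style featurizer: wider keyword signature.
        return featurizer.get_templates(
            query_sequence=input_sequence,
            query_pdb_code=query_pdb_code,
            query_release_date=query_release_date,
            hits=hits_cat,
        )
    # hmmsearch-based (multimer) featurizer: sequence and hits only.
    return featurizer.get_templates(
        query_sequence=input_sequence,
        hits=hits_cat,
    )
```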
@@ -242,7 +249,7 @@ def run_msa_tool(
     if(msa_format == "sto" and max_sto_sequences is not None):
         result = msa_runner.query(fasta_path, max_sto_sequences)[0]
     else:
-        result = msa_runner.query(fasta_path)[0]
+        result = msa_runner.query(fasta_path)
     with open(msa_out_path, "w") as f:
         f.write(result[msa_format])
@@ -262,7 +269,6 @@ class AlignmentRunner:
         bfd_database_path: Optional[str] = None,
         uniclust30_database_path: Optional[str] = None,
         pdb70_database_path: Optional[str] = None,
-        template_searcher: Optional[TemplateSearcher] = None,
         use_small_bfd: Optional[bool] = None,
         no_cpus: Optional[int] = None,
         uniref_max_hits: int = 10000,
@@ -447,6 +453,225 @@ class AlignmentRunner:
             f.write(hhblits_bfd_uniclust_result["a3m"])
+
+class AlignmentRunnerMultimer(AlignmentRunner):
+    """Runs alignment tools and saves the results"""
+    def __init__(
+        self,
+        jackhmmer_binary_path: Optional[str] = None,
+        hhblits_binary_path: Optional[str] = None,
+        uniref90_database_path: Optional[str] = None,
+        mgnify_database_path: Optional[str] = None,
+        bfd_database_path: Optional[str] = None,
+        uniclust30_database_path: Optional[str] = None,
+        uniprot_database_path: Optional[str] = None,
+        template_searcher: Optional[TemplateSearcher] = None,
+        use_small_bfd: Optional[bool] = None,
+        no_cpus: Optional[int] = None,
+        uniref_max_hits: int = 10000,
+        mgnify_max_hits: int = 5000,
+        uniprot_max_hits: int = 50000,
+    ):
+        """
+        Args:
+            jackhmmer_binary_path:
+                Path to jackhmmer binary
+            hhblits_binary_path:
+                Path to hhblits binary
+            uniref90_database_path:
+                Path to uniref90 database. If provided, jackhmmer_binary_path
+                must also be provided
+            mgnify_database_path:
+                Path to mgnify database. If provided, jackhmmer_binary_path
+                must also be provided
+            bfd_database_path:
+                Path to BFD database. Depending on the value of use_small_bfd,
+                one of hhblits_binary_path or jackhmmer_binary_path must be
+                provided.
+            uniclust30_database_path:
+                Path to uniclust30. Searched alongside BFD if use_small_bfd is
+                false.
+            uniprot_database_path:
+                Path to the uniprot database, searched with jackhmmer. Used
+                for cross-chain MSA pairing
+            template_searcher:
+                Template search tool. Requires the uniref90 runner
+            use_small_bfd:
+                Whether to search the BFD database alone with jackhmmer or
+                in conjunction with uniclust30 with hhblits.
+            no_cpus:
+                The number of CPUs available for alignment. By default, all
+                CPUs are used.
+            uniref_max_hits:
+                Max number of uniref hits
+            mgnify_max_hits:
+                Max number of mgnify hits
+            uniprot_max_hits:
+                Max number of uniprot hits
+        """
+        # super().__init__()
+        db_map = {
+            "jackhmmer": {
+                "binary": jackhmmer_binary_path,
+                "dbs": [
+                    uniref90_database_path,
+                    mgnify_database_path,
+                    bfd_database_path if use_small_bfd else None,
+                    uniprot_database_path,
+                ],
+            },
+            "hhblits": {
+                "binary": hhblits_binary_path,
+                "dbs": [
+                    bfd_database_path if not use_small_bfd else None,
+                ],
+            },
+        }
+
+        for name, dic in db_map.items():
+            binary, dbs = dic["binary"], dic["dbs"]
+            if(binary is None and not all([x is None for x in dbs])):
+                raise ValueError(
+                    f"{name} DBs provided but {name} binary is None"
+                )
+
+        self.uniref_max_hits = uniref_max_hits
+        self.mgnify_max_hits = mgnify_max_hits
+        self.uniprot_max_hits = uniprot_max_hits
+        self.use_small_bfd = use_small_bfd
+
+        if(no_cpus is None):
+            no_cpus = cpu_count()
+
+        self.jackhmmer_uniref90_runner = None
+        if(jackhmmer_binary_path is not None and
+            uniref90_database_path is not None
+        ):
+            self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
+                binary_path=jackhmmer_binary_path,
+                database_path=uniref90_database_path,
+                n_cpu=no_cpus,
+            )
+        self.jackhmmer_small_bfd_runner = None
+        self.hhblits_bfd_uniclust_runner = None
+        if(bfd_database_path is not None):
+            if use_small_bfd:
+                self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
+                    binary_path=jackhmmer_binary_path,
+                    database_path=bfd_database_path,
+                    n_cpu=no_cpus,
+                )
+            else:
+                dbs = [bfd_database_path]
+                if(uniclust30_database_path is not None):
+                    dbs.append(uniclust30_database_path)
+                self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
+                    binary_path=hhblits_binary_path,
+                    databases=dbs,
+                    n_cpu=no_cpus,
+                )
+
+        self.jackhmmer_mgnify_runner = None
+        if(mgnify_database_path is not None):
+            self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
+                binary_path=jackhmmer_binary_path,
+                database_path=mgnify_database_path,
+                n_cpu=no_cpus,
+            )
+
+        self.jackhmmer_uniprot_runner = None
+        if(uniprot_database_path is not None):
+            self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
+                binary_path=jackhmmer_binary_path,
+                database_path=uniprot_database_path,
+                n_cpu=no_cpus,
+            )
+
+        if(template_searcher is not None and
+            self.jackhmmer_uniref90_runner is None
+        ):
+            raise ValueError(
+                "Uniref90 runner must be specified to run template search"
+            )
+
+        self.template_searcher = template_searcher
+
+    def run(
+        self,
+        fasta_path: str,
+        output_dir: str,
+    ):
+        """Runs alignment tools on a sequence"""
+        if(self.jackhmmer_uniref90_runner is not None):
+            uniref90_out_path = os.path.join(output_dir, "uniref90_hits.sto")
+            jackhmmer_uniref90_result = run_msa_tool(
+                msa_runner=self.jackhmmer_uniref90_runner,
+                fasta_path=fasta_path,
+                msa_out_path=uniref90_out_path,
+                msa_format='sto',
+                max_sto_sequences=self.uniref_max_hits,
+            )
+
+            template_msa = jackhmmer_uniref90_result["sto"]
+            template_msa = parsers.deduplicate_stockholm_msa(template_msa)
+            template_msa = parsers.remove_empty_columns_from_stockholm_msa(
+                template_msa
+            )
+
+            if(self.template_searcher is not None):
+                if(self.template_searcher.input_format == "sto"):
+                    pdb_templates_result = self.template_searcher.query(
+                        template_msa,
+                        output_dir=output_dir
+                    )
+                elif(self.template_searcher.input_format == "a3m"):
+                    uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
+                        template_msa
+                    )
+                    pdb_templates_result = self.template_searcher.query(
+                        uniref90_msa_as_a3m,
+                        output_dir=output_dir
+                    )
+                else:
+                    fmt = self.template_searcher.input_format
+                    raise ValueError(
+                        f"Unrecognized template input format: {fmt}"
+                    )
+
+        if(self.jackhmmer_mgnify_runner is not None):
+            mgnify_out_path = os.path.join(output_dir, "mgnify_hits.sto")
+            jackhmmer_mgnify_result = run_msa_tool(
+                msa_runner=self.jackhmmer_mgnify_runner,
+                fasta_path=fasta_path,
+                msa_out_path=mgnify_out_path,
+                msa_format='sto',
+                max_sto_sequences=self.mgnify_max_hits
+            )
+
+        if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
+            bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
+            jackhmmer_small_bfd_result = run_msa_tool(
+                msa_runner=self.jackhmmer_small_bfd_runner,
+                fasta_path=fasta_path,
+                msa_out_path=bfd_out_path,
+                msa_format="sto",
+            )
+        elif(self.hhblits_bfd_uniclust_runner is not None):
+            bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
+            hhblits_bfd_uniclust_result = run_msa_tool(
+                msa_runner=self.hhblits_bfd_uniclust_runner,
+                fasta_path=fasta_path,
+                msa_out_path=bfd_out_path,
+                msa_format="a3m",
+            )
+
+        if(self.jackhmmer_uniprot_runner is not None):
+            uniprot_out_path = os.path.join(output_dir, 'uniprot_hits.sto')
+            result = run_msa_tool(
+                self.jackhmmer_uniprot_runner,
+                fasta_path=fasta_path,
+                msa_out_path=uniprot_out_path,
+                msa_format='sto',
+                max_sto_sequences=self.uniprot_max_hits,
+            )

 @contextlib.contextmanager
 def temp_fasta_file(fasta_str: str):
     with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
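
A hedged usage sketch for the new runner; every path below is a placeholder. With `use_small_bfd=True`, BFD is searched with jackhmmer, so hhblits and uniclust30 can be omitted; the output filenames match the ones written by `run()` above.

```python
from fastfold.data import data_pipeline

runner = data_pipeline.AlignmentRunnerMultimer(
    jackhmmer_binary_path="/usr/bin/jackhmmer",
    uniref90_database_path="/data/uniref90/uniref90.fasta",
    mgnify_database_path="/data/mgnify/mgy_clusters.fa",
    bfd_database_path="/data/small_bfd/bfd-first_non_consensus_sequences.fasta",
    uniprot_database_path="/data/uniprot/uniprot.fasta",
    use_small_bfd=True,  # small BFD -> jackhmmer branch
    no_cpus=8,
)
# Writes uniref90_hits.sto, mgnify_hits.sto, small_bfd_hits.sto and
# uniprot_hits.sto under the given output directory.
runner.run(fasta_path="chain_A.fasta", output_dir="alignments/chain_A")
```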
@@ -722,7 +947,12 @@ class DataPipeline:
         input_description = input_descs[0]
         num_res = len(input_sequence)

-        hits = self._parse_template_hits(alignment_dir, _alignment_index)
+        hits = self._parse_template_hits(
+            alignment_dir,
+            input_sequence,
+            _alignment_index,
+        )
         template_features = make_template_features(
             input_sequence,
             hits,
@@ -893,8 +1123,8 @@ class DataPipelineMultimer:
         uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
         with open(uniprot_msa_path, "r") as fp:
             uniprot_msa_string = fp.read()
-        msa = parsers.parse_stockholm(uniprot_msa_string)
-        all_seq_features = make_msa_features([msa])
+        msa, deletion_matrix, _ = parsers.parse_stockholm(uniprot_msa_string)
+        all_seq_features = make_msa_features(msa, deletion_matrix)
         valid_feats = msa_pairing.MSA_FEATURES + (
             'msa_species_identifiers',
         )
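
In this hunk the uniprot MSA is unpacked before being passed to `make_msa_features`. A short sketch of the new call pattern; the exact contents of `parse_stockholm`'s third return value are not shown in the diff and are treated as unused here, and the import path for `make_msa_features` is an assumption:

```python
from fastfold.data import parsers
from fastfold.data.data_pipeline import make_msa_features  # assumed location

with open("uniprot_hits.sto") as fp:  # placeholder path
    uniprot_msa_string = fp.read()

# New unpacking: sequences and a deletion matrix feed make_msa_features
# directly; the third return value is ignored.
msa, deletion_matrix, _ = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features(msa, deletion_matrix)
```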
...
@@ -188,9 +188,9 @@ def _assess_hhsearch_hit(
     hit: parsers.TemplateHit,
     hit_pdb_code: str,
     query_sequence: str,
-    query_pdb_code: Optional[str],
     release_dates: Mapping[str, datetime.datetime],
     release_date_cutoff: datetime.datetime,
+    query_pdb_code: Optional[str] = None,
     max_subsequence_ratio: float = 0.95,
     min_align_ratio: float = 0.1,
 ) -> bool:
@@ -752,12 +752,12 @@ class SingleHitResult:
 def _prefilter_hit(
     query_sequence: str,
-    query_pdb_code: Optional[str],
     hit: parsers.TemplateHit,
     max_template_date: datetime.datetime,
     release_dates: Mapping[str, datetime.datetime],
     obsolete_pdbs: Mapping[str, str],
     strict_error_check: bool = False,
+    query_pdb_code: Optional[str] = None,
 ):
     # Fail hard if we can't get the PDB ID and chain name from the hit.
     hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
@@ -794,7 +794,6 @@ def _prefilter_hit(
 def _process_single_hit(
     query_sequence: str,
-    query_pdb_code: Optional[str],
     hit: parsers.TemplateHit,
     mmcif_dir: str,
     max_template_date: datetime.datetime,
@@ -803,6 +802,7 @@ def _process_single_hit(
     kalign_binary_path: str,
     strict_error_check: bool = False,
     _zero_center_positions: bool = True,
+    query_pdb_code: Optional[str] = None,
 ) -> SingleHitResult:
     """Tries to extract template features from a single HHSearch hit."""
     # Fail hard if we can't get the PDB ID and chain name from the hit.
@@ -996,9 +996,9 @@ class TemplateHitFeaturizer:
     def get_templates(
         self,
         query_sequence: str,
-        query_pdb_code: Optional[str],
         query_release_date: Optional[datetime.datetime],
         hits: Sequence[parsers.TemplateHit],
+        query_pdb_code: Optional[str] = None,
     ) -> TemplateSearchResult:
         """Computes the templates for given query sequence (more details above)."""
         logging.info("Searching for template for: %s", query_pdb_code)
@@ -1155,7 +1155,7 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
             idx[:stk] = np.random.permutation(idx[:stk])
         for i in idx:
-            if(len(already_seen) >= self._max_hits):
+            if(len(already_seen) >= self.max_hits):
                 break
             hit = filtered[i]
...
@@ -23,6 +23,7 @@ import subprocess
 from typing import Any, Callable, Mapping, Optional, Sequence
 from urllib import request

+from fastfold.data import parsers
 from fastfold.data.tools import utils
@@ -93,7 +94,10 @@ class Jackhmmer:
         self.streaming_callback = streaming_callback

     def _query_chunk(
-        self, input_fasta_path: str, database_path: str
+        self,
+        input_fasta_path: str,
+        database_path: str,
+        max_sequences: Optional[int] = None
     ) -> Mapping[str, Any]:
         """Queries the database chunk using Jackhmmer."""
         with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
@@ -167,8 +171,11 @@ class Jackhmmer:
             with open(tblout_path) as f:
                 tbl = f.read()

+            if(max_sequences is None):
                 with open(sto_path) as f:
                     sto = f.read()
+            else:
+                sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)

             raw_output = dict(
                 sto=sto,
@@ -180,10 +187,16 @@ class Jackhmmer:
         return raw_output

-    def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
+    def query(self,
+        input_fasta_path: str,
+        max_sequences: Optional[int] = None
+    ) -> Sequence[Mapping[str, Any]]:
         """Queries the database using Jackhmmer."""
         if self.num_streamed_chunks is None:
-            return [self._query_chunk(input_fasta_path, self.database_path)]
+            single_chunk_result = self._query_chunk(
+                input_fasta_path, self.database_path, max_sequences,
+            )
+            return [single_chunk_result]

         db_basename = os.path.basename(self.database_path)
         db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
@@ -217,12 +230,20 @@ class Jackhmmer:
                 # Run Jackhmmer with the chunk
                 future.result()
                 chunked_output.append(
-                    self._query_chunk(input_fasta_path, db_local_chunk(i))
+                    self._query_chunk(
+                        input_fasta_path,
+                        db_local_chunk(i),
+                        max_sequences
+                    )
                 )

                 # Remove the local copy of the chunk
                 os.remove(db_local_chunk(i))
-                future = next_future
+                # Do not set next_future for the last chunk so that this works
+                # even for databases with only 1 chunk
+                if(i < self.num_streamed_chunks):
+                    future = next_future
                 if self.streaming_callback:
                     self.streaming_callback(i)
         return chunked_output
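
Taken together, these Jackhmmer changes let callers cap the size of Stockholm outputs. A hedged usage sketch; the paths are placeholders:

```python
from fastfold.data.tools import jackhmmer

runner = jackhmmer.Jackhmmer(
    binary_path="/usr/bin/jackhmmer",
    database_path="/data/uniref90/uniref90.fasta",
)
# query() still returns a list of per-chunk dicts; with max_sequences set,
# the "sto" entry is truncated via parsers.truncate_stockholm_msa.
result = runner.query("query.fasta", max_sequences=10000)[0]
sto_msa = result["sto"]
```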
@@ -19,6 +19,8 @@ import random
 import sys
 import time
 from datetime import date
+import tempfile
+import contextlib

 import numpy as np
 import torch
@@ -39,6 +41,12 @@ from fastfold.data.parsers import parse_fasta
 from fastfold.utils.import_weights import import_jax_weights_
 from fastfold.utils.tensor_utils import tensor_tree_map

+@contextlib.contextmanager
+def temp_fasta_file(fasta_str: str):
+    with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
+        fasta_file.write(fasta_str)
+        fasta_file.seek(0)
+        yield fasta_file.name

 def add_data_args(parser: argparse.ArgumentParser):
     parser.add_argument(
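
Usage sketch for the helper added above: the temporary FASTA file exists only inside the with block, which is exactly how `inference_multimer_model` feeds per-chain sequences to the alignment runner further down. The sequence is a placeholder.

```python
fasta_str = ">chain_A\nMKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ\n"
with temp_fasta_file(fasta_str) as fasta_path:
    print(fasta_path)  # e.g. /tmp/tmpXXXXXXXX.fasta; removed on exit
```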
@@ -66,10 +74,22 @@ def add_data_args(parser: argparse.ArgumentParser):
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "--pdb_seqres_database_path",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--uniprot_database_path",
+        type=str,
+        default=None,
+    )
     parser.add_argument('--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer')
     parser.add_argument('--hhblits_binary_path', type=str, default='/usr/bin/hhblits')
     parser.add_argument('--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch')
     parser.add_argument('--kalign_binary_path', type=str, default='/usr/bin/kalign')
+    parser.add_argument("--hmmsearch_binary_path", type=str, default="hmmsearch")
+    parser.add_argument("--hmmbuild_binary_path", type=str, default="hmmbuild")
     parser.add_argument(
         '--max_template_date',
         type=str,
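
The new flags slot into the existing `add_data_args` parser. A quick hedged example of wiring them up; the values are placeholders, not shipped defaults:

```python
import argparse

parser = argparse.ArgumentParser()
add_data_args(parser)  # the function being extended in this hunk
args = parser.parse_args([
    "--uniprot_database_path", "/data/uniprot/uniprot.fasta",
    "--pdb_seqres_database_path", "/data/pdb_seqres/pdb_seqres.txt",
    "--hmmsearch_binary_path", "/usr/bin/hmmsearch",
    "--hmmbuild_binary_path", "/usr/bin/hmmbuild",
])
print(args.uniprot_database_path)
```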
@@ -79,6 +99,7 @@ def add_data_args(parser: argparse.ArgumentParser):
     parser.add_argument('--release_dates_path', type=str, default=None)
     parser.add_argument('--enable_workflow', default=False, action='store_true', help='run inference with ray workflow or not')

 def inference_model(rank, world_size, result_q, batch, args):
     os.environ['RANK'] = str(rank)
     os.environ['LOCAL_RANK'] = str(rank)
@@ -120,7 +141,7 @@ def main(args):
 def inference_multimer_model(args):
     print("running in multimer mode...")
+    config = model_config(args.model_name)
     # feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
     predict_max_templates = 4
@@ -143,6 +164,81 @@ def inference_multimer_model(args):
         obsolete_pdbs_path=args.obsolete_pdbs_path,
     )
+
+    if(not args.use_precomputed_alignments):
+        alignment_runner = data_pipeline.AlignmentRunnerMultimer(
+            jackhmmer_binary_path=args.jackhmmer_binary_path,
+            hhblits_binary_path=args.hhblits_binary_path,
+            uniref90_database_path=args.uniref90_database_path,
+            mgnify_database_path=args.mgnify_database_path,
+            bfd_database_path=args.bfd_database_path,
+            uniclust30_database_path=args.uniclust30_database_path,
+            uniprot_database_path=args.uniprot_database_path,
+            template_searcher=template_searcher,
+            use_small_bfd=(args.bfd_database_path is None),
+            no_cpus=args.cpus,
+        )
+    else:
+        alignment_runner = None
+
+    monomer_data_processor = data_pipeline.DataPipeline(
+        template_featurizer=template_featurizer,
+    )
+    data_processor = data_pipeline.DataPipelineMultimer(
+        monomer_data_pipeline=monomer_data_processor,
+    )
+
+    output_dir_base = args.output_dir
+    random_seed = args.data_random_seed
+    if random_seed is None:
+        random_seed = random.randrange(sys.maxsize)
+
+    feature_processor = feature_pipeline.FeaturePipeline(
+        config.data
+    )
+    if not os.path.exists(output_dir_base):
+        os.makedirs(output_dir_base)
+    if(not args.use_precomputed_alignments):
+        alignment_dir = os.path.join(output_dir_base, "alignments")
+    else:
+        alignment_dir = args.use_precomputed_alignments
+
+    # Gather input sequences: split the FASTA into alternating tag/sequence
+    # entries, then pair them up.
+    fasta_path = args.fasta_path
+    with open(fasta_path, "r") as fp:
+        data = fp.read()
+    lines = [
+        l.replace('\n', '')
+        for prot in data.split('>') for l in prot.strip().split('\n', 1)
+    ][1:]
+    tags, seqs = lines[::2], lines[1::2]
+
+    for tag, seq in zip(tags, seqs):
+        local_alignment_dir = os.path.join(alignment_dir, tag)
+        if(args.use_precomputed_alignments is None):
+            if not os.path.exists(local_alignment_dir):
+                os.makedirs(local_alignment_dir)
+            chain_fasta_str = f'>chain_{tag}\n{seq}\n'
+            with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
+                alignment_runner.run(
+                    chain_fasta_path, local_alignment_dir
+                )
+            print(f"Finished running alignment for {tag}")
+
+    local_alignment_dir = alignment_dir
+    feature_dict = data_processor.process_fasta(
+        fasta_path=fasta_path, alignment_dir=local_alignment_dir
+    )
+    processed_feature_dict = feature_processor.process_features(
+        feature_dict, mode='predict', is_multimer=True,
+    )

 def inference_monomer_model(args):
     print("running in monomer mode...")
@@ -282,6 +378,7 @@ def inference_monomer_model(args):
             with open(relaxed_output_path, 'w') as f:
                 f.write(relaxed_pdb_str)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
...