OpenDAS / FastFold · Commits

Commit d3df8e69 (unverified), authored Sep 09, 2022 by Fazzie-Maqianli, committed via GitHub on Sep 09, 2022

add multimer inference (#59)

* add multimer inference
* add data pipeline
parent 444c548a
Showing 4 changed files with 372 additions and 24 deletions (+372, −24)
fastfold/data/data_pipeline.py    +242 −12
fastfold/data/templates.py        +5 −5
fastfold/data/tools/jackhmmer.py  +27 −6
inference.py                      +98 −1
fastfold/data/data_pipeline.py
@@ -19,6 +19,7 @@ import contextlib
 import dataclasses
 import datetime
 import json
+import copy
 from multiprocessing import cpu_count
 import tempfile
 from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
@@ -56,7 +57,7 @@ def empty_template_feats(n_res) -> FeatureDict:
 def make_template_features(
     input_sequence: str,
     hits: Sequence[Any],
-    template_featurizer: Any,
+    template_featurizer: Union[hhsearch.HHSearch, hmmsearch.Hmmsearch],
     query_pdb_code: Optional[str] = None,
     query_release_date: Optional[str] = None,
 ) -> FeatureDict:
@@ -64,12 +65,18 @@ def make_template_features(
if
(
len
(
hits_cat
)
==
0
or
template_featurizer
is
None
):
template_features
=
empty_template_feats
(
len
(
input_sequence
))
else
:
templates_result
=
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
query_pdb_code
=
query_pdb_code
,
query_release_date
=
query_release_date
,
hits
=
hits_cat
,
)
if
type
(
template_featurizer
)
==
hhsearch
.
HHSearch
:
templates_result
=
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
query_pdb_code
=
query_pdb_code
,
query_release_date
=
query_release_date
,
hits
=
hits_cat
,
)
else
:
templates_result
=
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
hits
=
hits_cat
,
)
template_features
=
templates_result
.
features
# The template featurizer doesn't format empty template features
...
...
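Note: callers of make_template_features are unchanged by the dispatch above; only the featurizer type decides which keyword set is forwarded. A minimal sketch of a call, assuming hypothetical hits_cat and featurizer objects:

# Sketch only (not part of the commit); hits_cat and featurizer are assumed.
features = make_template_features(
    input_sequence="MKTAYIAK",       # hypothetical query sequence
    hits=hits_cat,                   # concatenated template hits
    template_featurizer=featurizer,  # HHSearch- or Hmmsearch-backed
)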
@@ -242,7 +249,7 @@ def run_msa_tool(
     if(msa_format == "sto" and max_sto_sequences is not None):
         result = msa_runner.query(fasta_path, max_sto_sequences)[0]
     else:
-        result = msa_runner.query(fasta_path)[0]
+        result = msa_runner.query(fasta_path)
     with open(msa_out_path, "w") as f:
         f.write(result[msa_format])
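Note: with the Stockholm branch above, an MSA search can now be capped at query time. A minimal sketch, assuming a configured jackhmmer.Jackhmmer runner and placeholder paths:

# Sketch only; the runner object and paths are illustrative.
result = run_msa_tool(
    msa_runner=uniref90_runner,        # assumed jackhmmer.Jackhmmer instance
    fasta_path="query.fasta",          # placeholder input
    msa_out_path="uniref90_hits.sto",
    msa_format="sto",
    max_sto_sequences=10000,           # takes the truncating branch
)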
@@ -262,7 +269,6 @@ class AlignmentRunner:
         bfd_database_path: Optional[str] = None,
         uniclust30_database_path: Optional[str] = None,
-        pdb70_database_path: Optional[str] = None,
         template_searcher: Optional[TemplateSearcher] = None,
         use_small_bfd: Optional[bool] = None,
         no_cpus: Optional[int] = None,
         uniref_max_hits: int = 10000,
@@ -447,6 +453,225 @@ class AlignmentRunner:
             f.write(hhblits_bfd_uniclust_result["a3m"])


+class AlignmentRunnerMultimer(AlignmentRunner):
+    """Runs alignment tools and saves the results"""
+    def __init__(
+        self,
+        jackhmmer_binary_path: Optional[str] = None,
+        hhblits_binary_path: Optional[str] = None,
+        uniref90_database_path: Optional[str] = None,
+        mgnify_database_path: Optional[str] = None,
+        bfd_database_path: Optional[str] = None,
+        uniclust30_database_path: Optional[str] = None,
+        uniprot_database_path: Optional[str] = None,
+        template_searcher: Optional[TemplateSearcher] = None,
+        use_small_bfd: Optional[bool] = None,
+        no_cpus: Optional[int] = None,
+        uniref_max_hits: int = 10000,
+        mgnify_max_hits: int = 5000,
+        uniprot_max_hits: int = 50000,
+    ):
+        """
+        Args:
+            jackhmmer_binary_path:
+                Path to jackhmmer binary
+            hhblits_binary_path:
+                Path to hhblits binary
+            uniref90_database_path:
+                Path to uniref90 database. If provided, jackhmmer_binary_path
+                must also be provided
+            mgnify_database_path:
+                Path to mgnify database. If provided, jackhmmer_binary_path
+                must also be provided
+            bfd_database_path:
+                Path to BFD database. Depending on the value of use_small_bfd,
+                one of hhblits_binary_path or jackhmmer_binary_path must be
+                provided.
+            uniclust30_database_path:
+                Path to uniclust30. Searched alongside BFD if use_small_bfd is
+                false.
+            use_small_bfd:
+                Whether to search the BFD database alone with jackhmmer or
+                in conjunction with uniclust30 with hhblits.
+            no_cpus:
+                The number of CPUs available for alignment. By default, all
+                CPUs are used.
+            uniref_max_hits:
+                Max number of uniref hits
+            mgnify_max_hits:
+                Max number of mgnify hits
+        """
+        # super().__init__()
+        db_map = {
+            "jackhmmer": {
+                "binary": jackhmmer_binary_path,
+                "dbs": [
+                    uniref90_database_path,
+                    mgnify_database_path,
+                    bfd_database_path if use_small_bfd else None,
+                    uniprot_database_path,
+                ],
+            },
+            "hhblits": {
+                "binary": hhblits_binary_path,
+                "dbs": [
+                    bfd_database_path if not use_small_bfd else None,
+                ],
+            },
+        }
+
+        for name, dic in db_map.items():
+            binary, dbs = dic["binary"], dic["dbs"]
+            if(binary is None and not all([x is None for x in dbs])):
+                raise ValueError(
+                    f"{name} DBs provided but {name} binary is None"
+                )
+
+        self.uniref_max_hits = uniref_max_hits
+        self.mgnify_max_hits = mgnify_max_hits
+        self.uniprot_max_hits = uniprot_max_hits
+        self.use_small_bfd = use_small_bfd
+
+        if(no_cpus is None):
+            no_cpus = cpu_count()
+
+        self.jackhmmer_uniref90_runner = None
+        if(jackhmmer_binary_path is not None and
+           uniref90_database_path is not None
+        ):
+            self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
+                binary_path=jackhmmer_binary_path,
+                database_path=uniref90_database_path,
+                n_cpu=no_cpus,
+            )
+
+        self.jackhmmer_small_bfd_runner = None
+        self.hhblits_bfd_uniclust_runner = None
+        if(bfd_database_path is not None):
+            if use_small_bfd:
+                self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
+                    binary_path=jackhmmer_binary_path,
+                    database_path=bfd_database_path,
+                    n_cpu=no_cpus,
+                )
+            else:
+                dbs = [bfd_database_path]
+                if(uniclust30_database_path is not None):
+                    dbs.append(uniclust30_database_path)
+                self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
+                    binary_path=hhblits_binary_path,
+                    databases=dbs,
+                    n_cpu=no_cpus,
+                )
+
+        self.jackhmmer_mgnify_runner = None
+        if(mgnify_database_path is not None):
+            self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
+                binary_path=jackhmmer_binary_path,
+                database_path=mgnify_database_path,
+                n_cpu=no_cpus,
+            )
+
+        self.jackhmmer_uniprot_runner = None
+        if(uniprot_database_path is not None):
+            self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
+                binary_path=jackhmmer_binary_path,
+                database_path=uniprot_database_path
+            )
+
+        if(template_searcher is not None and
+           self.jackhmmer_uniref90_runner is None
+        ):
+            raise ValueError(
+                "Uniref90 runner must be specified to run template search"
+            )
+
+        self.template_searcher = template_searcher
+
+    def run(
+        self,
+        fasta_path: str,
+        output_dir: str,
+    ):
+        """Runs alignment tools on a sequence"""
+        if(self.jackhmmer_uniref90_runner is not None):
+            uniref90_out_path = os.path.join(output_dir, "uniref90_hits.sto")
+
+            jackhmmer_uniref90_result = run_msa_tool(
+                msa_runner=self.jackhmmer_uniref90_runner,
+                fasta_path=fasta_path,
+                msa_out_path=uniref90_out_path,
+                msa_format='sto',
+                max_sto_sequences=self.uniref_max_hits,
+            )
+
+            template_msa = jackhmmer_uniref90_result["sto"]
+            template_msa = parsers.deduplicate_stockholm_msa(template_msa)
+            template_msa = parsers.remove_empty_columns_from_stockholm_msa(
+                template_msa
+            )
+
+            if(self.template_searcher is not None):
+                if(self.template_searcher.input_format == "sto"):
+                    pdb_templates_result = self.template_searcher.query(
+                        template_msa,
+                        output_dir=output_dir
+                    )
+                elif(self.template_searcher.input_format == "a3m"):
+                    uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
+                        template_msa
+                    )
+                    pdb_templates_result = self.template_searcher.query(
+                        uniref90_msa_as_a3m,
+                        output_dir=output_dir
+                    )
+                else:
+                    fmt = self.template_searcher.input_format
+                    raise ValueError(
+                        f"Unrecognized template input format: {fmt}"
+                    )
+
+        if(self.jackhmmer_mgnify_runner is not None):
+            mgnify_out_path = os.path.join(output_dir, "mgnify_hits.sto")
+            jackhmmer_mgnify_result = run_msa_tool(
+                msa_runner=self.jackhmmer_mgnify_runner,
+                fasta_path=fasta_path,
+                msa_out_path=mgnify_out_path,
+                msa_format='sto',
+                max_sto_sequences=self.mgnify_max_hits
+            )
+
+        if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
+            bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
+            jackhmmer_small_bfd_result = run_msa_tool(
+                msa_runner=self.jackhmmer_small_bfd_runner,
+                fasta_path=fasta_path,
+                msa_out_path=bfd_out_path,
+                msa_format="sto",
+            )
+        elif(self.hhblits_bfd_uniclust_runner is not None):
+            bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
+            hhblits_bfd_uniclust_result = run_msa_tool(
+                msa_runner=self.hhblits_bfd_uniclust_runner,
+                fasta_path=fasta_path,
+                msa_out_path=bfd_out_path,
+                msa_format="a3m",
+            )
+
+        if(self.jackhmmer_uniprot_runner is not None):
+            uniprot_out_path = os.path.join(output_dir, 'uniprot_hits.sto')
+            result = run_msa_tool(
+                self.jackhmmer_uniprot_runner,
+                fasta_path=fasta_path,
+                msa_out_path=uniprot_out_path,
+                msa_format='sto',
+                max_sto_sequences=self.uniprot_max_hits,
+            )
+

 @contextlib.contextmanager
 def temp_fasta_file(fasta_str: str):
     with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
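Note: a hedged usage sketch of the new runner, mirroring how inference.py wires it up further below; all binary and database paths here are placeholders, not defaults:

# Sketch only; paths are illustrative.
runner = AlignmentRunnerMultimer(
    jackhmmer_binary_path="/usr/bin/jackhmmer",
    hhblits_binary_path="/usr/bin/hhblits",
    uniref90_database_path="/data/uniref90.fasta",
    mgnify_database_path="/data/mgy_clusters.fa",
    bfd_database_path="/data/bfd-first_non_consensus_sequences.fasta",
    uniprot_database_path="/data/uniprot.fasta",
    use_small_bfd=True,   # small BFD is searched with jackhmmer
    no_cpus=8,
)
runner.run("chain_A.fasta", "alignments/chain_A")
# writes uniref90_hits.sto, mgnify_hits.sto, small_bfd_hits.sto and
# uniprot_hits.sto (plus template hits if a template_searcher is set)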
@@ -722,7 +947,12 @@ class DataPipeline:
         input_description = input_descs[0]
         num_res = len(input_sequence)

-        hits = self._parse_template_hits(alignment_dir, _alignment_index)
+        hits = self._parse_template_hits(
+            alignment_dir,
+            input_sequence,
+            _alignment_index,
+        )
         template_features = make_template_features(
             input_sequence,
             hits,
@@ -893,8 +1123,8 @@ class DataPipelineMultimer:
         uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
         with open(uniprot_msa_path, "r") as fp:
             uniprot_msa_string = fp.read()
-        msa = parsers.parse_stockholm(uniprot_msa_string)
-        all_seq_features = make_msa_features([msa])
+        msa, deletion_matrix, _ = parsers.parse_stockholm(uniprot_msa_string)
+        all_seq_features = make_msa_features(msa, deletion_matrix)
         valid_feats = msa_pairing.MSA_FEATURES + (
             'msa_species_identifiers',
         )
fastfold/data/templates.py
@@ -188,9 +188,9 @@ def _assess_hhsearch_hit(
     hit: parsers.TemplateHit,
     hit_pdb_code: str,
     query_sequence: str,
-    query_pdb_code: Optional[str],
     release_dates: Mapping[str, datetime.datetime],
     release_date_cutoff: datetime.datetime,
+    query_pdb_code: Optional[str] = None,
     max_subsequence_ratio: float = 0.95,
     min_align_ratio: float = 0.1,
 ) -> bool:
@@ -752,12 +752,12 @@ class SingleHitResult:
 def _prefilter_hit(
     query_sequence: str,
-    query_pdb_code: Optional[str],
     hit: parsers.TemplateHit,
     max_template_date: datetime.datetime,
     release_dates: Mapping[str, datetime.datetime],
     obsolete_pdbs: Mapping[str, str],
     strict_error_check: bool = False,
+    query_pdb_code: Optional[str] = None,
 ):
     # Fail hard if we can't get the PDB ID and chain name from the hit.
     hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
@@ -794,7 +794,6 @@ def _prefilter_hit(
 def _process_single_hit(
     query_sequence: str,
-    query_pdb_code: Optional[str],
     hit: parsers.TemplateHit,
     mmcif_dir: str,
     max_template_date: datetime.datetime,
@@ -803,6 +802,7 @@ def _process_single_hit(
     kalign_binary_path: str,
     strict_error_check: bool = False,
     _zero_center_positions: bool = True,
+    query_pdb_code: Optional[str] = None,
 ) -> SingleHitResult:
     """Tries to extract template features from a single HHSearch hit."""
     # Fail hard if we can't get the PDB ID and chain name from the hit.
@@ -996,9 +996,9 @@ class TemplateHitFeaturizer:
     def get_templates(
         self,
         query_sequence: str,
-        query_pdb_code: Optional[str],
         query_release_date: Optional[datetime.datetime],
         hits: Sequence[parsers.TemplateHit],
+        query_pdb_code: Optional[str] = None,
     ) -> TemplateSearchResult:
         """Computes the templates for given query sequence (more details above)."""
         logging.info("Searching for template for: %s", query_pdb_code)
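Note: since query_pdb_code now trails with a default, existing keyword callers keep working and new callers may omit it. A one-line sketch with assumed objects:

# Sketch only; featurizer, input_sequence and hits are assumed.
result = featurizer.get_templates(
    query_sequence=input_sequence,
    query_release_date=None,
    hits=hits,
)  # query_pdb_code defaults to None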
@@ -1155,7 +1155,7 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
         idx[:stk] = np.random.permutation(idx[:stk])

         for i in idx:
-            if(len(already_seen) >= self._max_hits):
+            if(len(already_seen) >= self.max_hits):
                 break
             hit = filtered[i]
fastfold/data/tools/jackhmmer.py
@@ -23,6 +23,7 @@ import subprocess
 from typing import Any, Callable, Mapping, Optional, Sequence
 from urllib import request

+from fastfold.data import parsers
 from fastfold.data.tools import utils
@@ -93,7 +94,10 @@ class Jackhmmer:
         self.streaming_callback = streaming_callback

     def _query_chunk(
-        self, input_fasta_path: str, database_path: str
+        self,
+        input_fasta_path: str,
+        database_path: str,
+        max_sequences: Optional[int] = None
     ) -> Mapping[str, Any]:
         """Queries the database chunk using Jackhmmer."""
         with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
@@ -167,8 +171,11 @@ class Jackhmmer:
             with open(tblout_path) as f:
                 tbl = f.read()

-            with open(sto_path) as f:
-                sto = f.read()
+            if(max_sequences is None):
+                with open(sto_path) as f:
+                    sto = f.read()
+            else:
+                sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)

             raw_output = dict(
                 sto=sto,
@@ -180,10 +187,16 @@ class Jackhmmer:
         return raw_output

-    def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
+    def query(
+        self,
+        input_fasta_path: str,
+        max_sequences: Optional[int] = None
+    ) -> Sequence[Mapping[str, Any]]:
         """Queries the database using Jackhmmer."""
         if self.num_streamed_chunks is None:
-            return [self._query_chunk(input_fasta_path, self.database_path)]
+            single_chunk_result = self._query_chunk(
+                input_fasta_path,
+                self.database_path,
+                max_sequences,
+            )
+            return [single_chunk_result]

         db_basename = os.path.basename(self.database_path)
         db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
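Note: a short sketch of the extended query() for a single-chunk (unstreamed) database; paths are placeholders:

# Sketch only; construction arguments are illustrative.
runner = Jackhmmer(
    binary_path="/usr/bin/jackhmmer",
    database_path="/data/uniref90.fasta",
)
results = runner.query("query.fasta", max_sequences=10000)
sto = results[0]["sto"]  # query() still returns a one-element sequence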
@@ -217,12 +230,20 @@ class Jackhmmer:
                 # Run Jackhmmer with the chunk
                 future.result()
             chunked_output.append(
-                self._query_chunk(input_fasta_path, db_local_chunk(i))
+                self._query_chunk(
+                    input_fasta_path,
+                    db_local_chunk(i),
+                    max_sequences
+                )
             )

             # Remove the local copy of the chunk
             os.remove(db_local_chunk(i))
-            future = next_future
+            # Do not set next_future for the last chunk so that this works
+            # even for databases with only 1 chunk
+            if(i < self.num_streamed_chunks):
+                future = next_future
             if self.streaming_callback:
                 self.streaming_callback(i)
         return chunked_output
inference.py
@@ -19,6 +19,8 @@ import random
 import sys
 import time
 from datetime import date
+import tempfile
+import contextlib

 import numpy as np
 import torch
@@ -39,6 +41,12 @@ from fastfold.data.parsers import parse_fasta
 from fastfold.utils.import_weights import import_jax_weights_
 from fastfold.utils.tensor_utils import tensor_tree_map

+@contextlib.contextmanager
+def temp_fasta_file(fasta_str: str):
+    with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
+        fasta_file.write(fasta_str)
+        fasta_file.seek(0)
+        yield fasta_file.name

 def add_data_args(parser: argparse.ArgumentParser):
     parser.add_argument(
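Note: a small usage sketch of the helper above; the FASTA content is illustrative:

# The yielded path is valid only inside the with-block; the temporary
# file is removed on exit.
with temp_fasta_file(">chain_A\nMKTAYIAK\n") as fasta_path:
    print(fasta_path)  # e.g. /tmp/tmpXXXXXXXX.fasta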
@@ -66,10 +74,22 @@ def add_data_args(parser: argparse.ArgumentParser):
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "--pdb_seqres_database_path", type=str, default=None,
+    )
+    parser.add_argument(
+        "--uniprot_database_path", type=str, default=None,
+    )
     parser.add_argument(
         '--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer'
     )
     parser.add_argument(
         '--hhblits_binary_path', type=str, default='/usr/bin/hhblits'
     )
     parser.add_argument(
         '--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch'
     )
     parser.add_argument(
         '--kalign_binary_path', type=str, default='/usr/bin/kalign'
     )
+    parser.add_argument(
+        "--hmmsearch_binary_path", type=str, default="hmmsearch"
+    )
+    parser.add_argument(
+        "--hmmbuild_binary_path", type=str, default="hmmbuild"
+    )
     parser.add_argument(
         '--max_template_date', type=str,
@@ -79,6 +99,7 @@ def add_data_args(parser: argparse.ArgumentParser):
     parser.add_argument('--release_dates_path', type=str, default=None)
+    parser.add_argument('--enable_workflow', default=False, action='store_true', help='run inference with ray workflow or not')

 def inference_model(rank, world_size, result_q, batch, args):
     os.environ['RANK'] = str(rank)
     os.environ['LOCAL_RANK'] = str(rank)
@@ -120,7 +141,7 @@ def main(args):
 def inference_multimer_model(args):
     print("running in multimer mode...")
     config = model_config(args.model_name)
-    # feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
+    predict_max_templates = 4
@@ -143,6 +164,81 @@ def inference_multimer_model(args):
         obsolete_pdbs_path=args.obsolete_pdbs_path,
     )

+    if(not args.use_precomputed_alignments):
+        alignment_runner = data_pipeline.AlignmentRunnerMultimer(
+            jackhmmer_binary_path=args.jackhmmer_binary_path,
+            hhblits_binary_path=args.hhblits_binary_path,
+            uniref90_database_path=args.uniref90_database_path,
+            mgnify_database_path=args.mgnify_database_path,
+            bfd_database_path=args.bfd_database_path,
+            uniclust30_database_path=args.uniclust30_database_path,
+            uniprot_database_path=args.uniprot_database_path,
+            template_searcher=template_searcher,
+            use_small_bfd=(args.bfd_database_path is None),
+            no_cpus=args.cpus,
+        )
+    else:
+        alignment_runner = None
+
+    monomer_data_processor = data_pipeline.DataPipeline(
+        template_featurizer=template_featurizer,
+    )
+    data_processor = data_pipeline.DataPipelineMultimer(
+        monomer_data_pipeline=monomer_data_processor,
+    )
+
+    output_dir_base = args.output_dir
+    random_seed = args.data_random_seed
+    if random_seed is None:
+        random_seed = random.randrange(sys.maxsize)
+
+    feature_processor = feature_pipeline.FeaturePipeline(config.data)
+    if not os.path.exists(output_dir_base):
+        os.makedirs(output_dir_base)
+    if(not args.use_precomputed_alignments):
+        alignment_dir = os.path.join(output_dir_base, "alignments")
+    else:
+        alignment_dir = args.use_precomputed_alignments
+
+    # Gather input sequences
+    fasta_path = args.fasta_path
+    with open(fasta_path, "r") as fp:
+        data = fp.read()
+    lines = [
+        l.replace('\n', '')
+        for prot in data.split('>') for l in prot.strip().split('\n', 1)
+    ][1:]
+    tags, seqs = lines[::2], lines[1::2]
+
+    for tag, seq in zip(tags, seqs):
+        local_alignment_dir = os.path.join(alignment_dir, tag)
+        if(args.use_precomputed_alignments is None):
+            if not os.path.exists(local_alignment_dir):
+                os.makedirs(local_alignment_dir)
+
+            chain_fasta_str = f'>chain_{tag}\n{seq}\n'
+            with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
+                alignment_runner.run(chain_fasta_path, local_alignment_dir)
+
+        print(f"Finished running alignment for {tag}")
+
+    local_alignment_dir = alignment_dir
+    feature_dict = data_processor.process_fasta(
+        fasta_path=fasta_path, alignment_dir=local_alignment_dir
+    )
+    processed_feature_dict = feature_processor.process_features(
+        feature_dict,
+        mode='predict',
+        is_multimer=True,
+    )

 def inference_monomer_model(args):
     print("running in monomer mode...")
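Note: the tag/sequence split in the hunk above is terse; a worked example with a hypothetical two-chain FASTA string:

# Sketch only; the input string is illustrative.
data = ">A\nMKTAYIAK\n>B\nGSHMLEDP\n"
lines = [
    l.replace('\n', '')
    for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
tags, seqs = lines[::2], lines[1::2]
# tags == ['A', 'B'], seqs == ['MKTAYIAK', 'GSHMLEDP']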
@@ -282,6 +378,7 @@ def inference_monomer_model(args):
         with open(relaxed_output_path, 'w') as f:
             f.write(relaxed_pdb_str)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(