OpenDAS / FastFold · Commit d3df8e69 (unverified)

Authored Sep 09, 2022 by Fazzie-Maqianli; committed by GitHub Sep 09, 2022
Parent: 444c548a

add multimer inference (#59)

* add multimer inference
* add data pipeline
Showing 4 changed files with 372 additions and 24 deletions:

    fastfold/data/data_pipeline.py     +242  −12
    fastfold/data/templates.py           +5   −5
    fastfold/data/tools/jackhmmer.py    +27   −6
    inference.py                        +98   −1
fastfold/data/data_pipeline.py
@@ -19,6 +19,7 @@ import contextlib
 import dataclasses
 import datetime
 import json
+import copy
 from multiprocessing import cpu_count
 import tempfile
 from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
@@ -56,7 +57,7 @@ def empty_template_feats(n_res) -> FeatureDict:
 def make_template_features(
     input_sequence: str,
     hits: Sequence[Any],
-    template_featurizer: Any,
+    template_featurizer: Union[hhsearch.HHSearch, hmmsearch.Hmmsearch],
     query_pdb_code: Optional[str] = None,
     query_release_date: Optional[str] = None,
 ) -> FeatureDict:
@@ -64,12 +65,18 @@ def make_template_features(
     if (len(hits_cat) == 0 or template_featurizer is None):
         template_features = empty_template_feats(len(input_sequence))
     else:
-        templates_result = template_featurizer.get_templates(
-            query_sequence=input_sequence,
-            query_pdb_code=query_pdb_code,
-            query_release_date=query_release_date,
-            hits=hits_cat,
-        )
+        if type(template_featurizer) == hhsearch.HHSearch:
+            templates_result = template_featurizer.get_templates(
+                query_sequence=input_sequence,
+                query_pdb_code=query_pdb_code,
+                query_release_date=query_release_date,
+                hits=hits_cat,
+            )
+        else:
+            templates_result = template_featurizer.get_templates(
+                query_sequence=input_sequence,
+                hits=hits_cat,
+            )
         template_features = templates_result.features

     # The template featurizer doesn't format empty template features
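The branch above exists because hhsearch.HHSearch.get_templates accepts PDB-code and release-date keywords that the hmmsearch-based featurizer does not. Below is a self-contained sketch of that dispatch pattern, with stub classes standing in for the real featurizers; nothing in it is from this commit.

    from typing import Any, Sequence

    class HHSearchStub:        # stands in for hhsearch.HHSearch
        def get_templates(self, query_sequence, query_pdb_code,
                          query_release_date, hits):
            return f"hhsearch templates for {query_sequence}"

    class HmmsearchStub:       # stands in for hmmsearch.Hmmsearch
        def get_templates(self, query_sequence, hits):
            return f"hmmsearch templates for {query_sequence}"

    def featurize(featurizer: Any, seq: str, hits: Sequence[Any]):
        # Mirror the type check above: only the HHSearch path
        # receives the extra metadata keywords.
        if type(featurizer) == HHSearchStub:
            return featurizer.get_templates(
                query_sequence=seq,
                query_pdb_code=None,
                query_release_date=None,
                hits=hits,
            )
        return featurizer.get_templates(query_sequence=seq, hits=hits)

    print(featurize(HmmsearchStub(), "MKTAYIAK", hits=[]))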
@@ -242,7 +249,7 @@ def run_msa_tool(
     if (msa_format == "sto" and max_sto_sequences is not None):
         result = msa_runner.query(fasta_path, max_sto_sequences)[0]
     else:
-        result = msa_runner.query(fasta_path)[0]
+        result = msa_runner.query(fasta_path)
     with open(msa_out_path, "w") as f:
         f.write(result[msa_format])
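For orientation, this is roughly how the multimer runner added further down invokes run_msa_tool. The runner object and all paths below are placeholders, so treat this as a sketch rather than runnable code.

    result = run_msa_tool(
        msa_runner=jackhmmer_uniref90_runner,     # a jackhmmer.Jackhmmer instance
        fasta_path="chain_A.fasta",               # placeholder input
        msa_out_path="alignments/uniref90_hits.sto",
        msa_format="sto",
        max_sto_sequences=10000,                  # triggers the truncating query path
    )
    sto_msa = result["sto"]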
@@ -262,7 +269,6 @@ class AlignmentRunner:
         bfd_database_path: Optional[str] = None,
         uniclust30_database_path: Optional[str] = None,
         pdb70_database_path: Optional[str] = None,
-        template_searcher: Optional[TemplateSearcher] = None,
         use_small_bfd: Optional[bool] = None,
         no_cpus: Optional[int] = None,
         uniref_max_hits: int = 10000,
@@ -447,6 +453,225 @@ class AlignmentRunner:
             f.write(hhblits_bfd_uniclust_result["a3m"])


+class AlignmentRunnerMultimer(AlignmentRunner):
+    """Runs alignment tools and saves the results"""
+    def __init__(
+        self,
+        jackhmmer_binary_path: Optional[str] = None,
+        hhblits_binary_path: Optional[str] = None,
+        uniref90_database_path: Optional[str] = None,
+        mgnify_database_path: Optional[str] = None,
+        bfd_database_path: Optional[str] = None,
+        uniclust30_database_path: Optional[str] = None,
+        uniprot_database_path: Optional[str] = None,
+        template_searcher: Optional[TemplateSearcher] = None,
+        use_small_bfd: Optional[bool] = None,
+        no_cpus: Optional[int] = None,
+        uniref_max_hits: int = 10000,
+        mgnify_max_hits: int = 5000,
+        uniprot_max_hits: int = 50000,
+    ):
+        """
+        Args:
+            jackhmmer_binary_path:
+                Path to jackhmmer binary
+            hhblits_binary_path:
+                Path to hhblits binary
+            uniref90_database_path:
+                Path to uniref90 database. If provided, jackhmmer_binary_path
+                must also be provided
+            mgnify_database_path:
+                Path to mgnify database. If provided, jackhmmer_binary_path
+                must also be provided
+            bfd_database_path:
+                Path to BFD database. Depending on the value of use_small_bfd,
+                one of hhblits_binary_path or jackhmmer_binary_path must be
+                provided.
+            uniclust30_database_path:
+                Path to uniclust30. Searched alongside BFD if use_small_bfd is
+                false.
+            use_small_bfd:
+                Whether to search the BFD database alone with jackhmmer or
+                in conjunction with uniclust30 with hhblits.
+            no_cpus:
+                The number of CPUs available for alignment. By default, all
+                CPUs are used.
+            uniref_max_hits:
+                Max number of uniref hits
+            mgnify_max_hits:
+                Max number of mgnify hits
+        """
+        # super().__init__()
+        db_map = {
+            "jackhmmer": {
+                "binary": jackhmmer_binary_path,
+                "dbs": [
+                    uniref90_database_path,
+                    mgnify_database_path,
+                    bfd_database_path if use_small_bfd else None,
+                    uniprot_database_path,
+                ],
+            },
+            "hhblits": {
+                "binary": hhblits_binary_path,
+                "dbs": [
+                    bfd_database_path if not use_small_bfd else None,
+                ],
+            },
+        }
+
+        for name, dic in db_map.items():
+            binary, dbs = dic["binary"], dic["dbs"]
+            if (binary is None and not all([x is None for x in dbs])):
+                raise ValueError(
+                    f"{name} DBs provided but {name} binary is None"
+                )
+
+        self.uniref_max_hits = uniref_max_hits
+        self.mgnify_max_hits = mgnify_max_hits
+        self.uniprot_max_hits = uniprot_max_hits
+        self.use_small_bfd = use_small_bfd
+
+        if (no_cpus is None):
+            no_cpus = cpu_count()
+
+        self.jackhmmer_uniref90_runner = None
+        if (jackhmmer_binary_path is not None and
+            uniref90_database_path is not None
+        ):
+            self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
+                binary_path=jackhmmer_binary_path,
+                database_path=uniref90_database_path,
+                n_cpu=no_cpus,
+            )
+
+        self.jackhmmer_small_bfd_runner = None
+        self.hhblits_bfd_uniclust_runner = None
+        if (bfd_database_path is not None):
+            if use_small_bfd:
+                self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
+                    binary_path=jackhmmer_binary_path,
+                    database_path=bfd_database_path,
+                    n_cpu=no_cpus,
+                )
+            else:
+                dbs = [bfd_database_path]
+                if (uniclust30_database_path is not None):
+                    dbs.append(uniclust30_database_path)
+                self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
+                    binary_path=hhblits_binary_path,
+                    databases=dbs,
+                    n_cpu=no_cpus,
+                )
+
+        self.jackhmmer_mgnify_runner = None
+        if (mgnify_database_path is not None):
+            self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
+                binary_path=jackhmmer_binary_path,
+                database_path=mgnify_database_path,
+                n_cpu=no_cpus,
+            )
+
+        self.jackhmmer_uniprot_runner = None
+        if (uniprot_database_path is not None):
+            self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
+                binary_path=jackhmmer_binary_path,
+                database_path=uniprot_database_path
+            )
+
+        if (template_searcher is not None and
+            self.jackhmmer_uniref90_runner is None
+        ):
+            raise ValueError(
+                "Uniref90 runner must be specified to run template search"
+            )
+        self.template_searcher = template_searcher
+
+    def run(
+        self,
+        fasta_path: str,
+        output_dir: str,
+    ):
+        """Runs alignment tools on a sequence"""
+        if (self.jackhmmer_uniref90_runner is not None):
+            uniref90_out_path = os.path.join(output_dir, "uniref90_hits.sto")
+
+            jackhmmer_uniref90_result = run_msa_tool(
+                msa_runner=self.jackhmmer_uniref90_runner,
+                fasta_path=fasta_path,
+                msa_out_path=uniref90_out_path,
+                msa_format='sto',
+                max_sto_sequences=self.uniref_max_hits,
+            )
+
+            template_msa = jackhmmer_uniref90_result["sto"]
+            template_msa = parsers.deduplicate_stockholm_msa(template_msa)
+            template_msa = parsers.remove_empty_columns_from_stockholm_msa(
+                template_msa
+            )
+
+            if (self.template_searcher is not None):
+                if (self.template_searcher.input_format == "sto"):
+                    pdb_templates_result = self.template_searcher.query(
+                        template_msa,
+                        output_dir=output_dir
+                    )
+                elif (self.template_searcher.input_format == "a3m"):
+                    uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
+                        template_msa
+                    )
+                    pdb_templates_result = self.template_searcher.query(
+                        uniref90_msa_as_a3m,
+                        output_dir=output_dir
+                    )
+                else:
+                    fmt = self.template_searcher.input_format
+                    raise ValueError(
+                        f"Unrecognized template input format: {fmt}"
+                    )
+
+        if (self.jackhmmer_mgnify_runner is not None):
+            mgnify_out_path = os.path.join(output_dir, "mgnify_hits.sto")
+            jackhmmer_mgnify_result = run_msa_tool(
+                msa_runner=self.jackhmmer_mgnify_runner,
+                fasta_path=fasta_path,
+                msa_out_path=mgnify_out_path,
+                msa_format='sto',
+                max_sto_sequences=self.mgnify_max_hits
+            )
+
+        if (self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
+            bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
+            jackhmmer_small_bfd_result = run_msa_tool(
+                msa_runner=self.jackhmmer_small_bfd_runner,
+                fasta_path=fasta_path,
+                msa_out_path=bfd_out_path,
+                msa_format="sto",
+            )
+        elif (self.hhblits_bfd_uniclust_runner is not None):
+            bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
+            hhblits_bfd_uniclust_result = run_msa_tool(
+                msa_runner=self.hhblits_bfd_uniclust_runner,
+                fasta_path=fasta_path,
+                msa_out_path=bfd_out_path,
+                msa_format="a3m",
+            )
+
+        if (self.jackhmmer_uniprot_runner is not None):
+            uniprot_out_path = os.path.join(output_dir, 'uniprot_hits.sto')
+            result = run_msa_tool(
+                self.jackhmmer_uniprot_runner,
+                fasta_path=fasta_path,
+                msa_out_path=uniprot_out_path,
+                msa_format='sto',
+                max_sto_sequences=self.uniprot_max_hits,
+            )
+
+
 @contextlib.contextmanager
 def temp_fasta_file(fasta_str: str):
     with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
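A hedged sketch of driving the new AlignmentRunnerMultimer directly; inference.py (below) wires it from CLI flags instead, and every path here is a placeholder.

    from fastfold.data import data_pipeline

    runner = data_pipeline.AlignmentRunnerMultimer(
        jackhmmer_binary_path="/usr/bin/jackhmmer",       # placeholder paths throughout
        hhblits_binary_path="/usr/bin/hhblits",
        uniref90_database_path="/data/uniref90/uniref90.fasta",
        mgnify_database_path="/data/mgnify/mgy_clusters.fa",
        bfd_database_path="/data/bfd/bfd_metaclust",
        uniclust30_database_path="/data/uniclust30/uniclust30_2018_08",
        uniprot_database_path="/data/uniprot/uniprot.fasta",
        template_searcher=None,      # template search skipped in this sketch
        use_small_bfd=False,
        no_cpus=8,
    )
    # One output directory per chain, matching inference_multimer_model below.
    runner.run("chain_A.fasta", "alignments/chain_A")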
@@ -722,7 +947,12 @@ class DataPipeline:
         input_description = input_descs[0]
         num_res = len(input_sequence)

-        hits = self._parse_template_hits(alignment_dir, _alignment_index)
+        hits = self._parse_template_hits(
+            alignment_dir,
+            input_sequence,
+            _alignment_index,
+        )
         template_features = make_template_features(
             input_sequence,
             hits,
@@ -893,8 +1123,8 @@ class DataPipelineMultimer:
         uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
         with open(uniprot_msa_path, "r") as fp:
             uniprot_msa_string = fp.read()
-        msa = parsers.parse_stockholm(uniprot_msa_string)
-        all_seq_features = make_msa_features([msa])
+        msa, deletion_matrix, _ = parsers.parse_stockholm(uniprot_msa_string)
+        all_seq_features = make_msa_features(msa, deletion_matrix)
         valid_feats = msa_pairing.MSA_FEATURES + (
             'msa_species_identifiers',
         )
fastfold/data/templates.py
@@ -188,9 +188,9 @@ def _assess_hhsearch_hit(
     hit: parsers.TemplateHit,
     hit_pdb_code: str,
     query_sequence: str,
-    query_pdb_code: Optional[str],
     release_dates: Mapping[str, datetime.datetime],
     release_date_cutoff: datetime.datetime,
+    query_pdb_code: Optional[str] = None,
     max_subsequence_ratio: float = 0.95,
     min_align_ratio: float = 0.1,
 ) -> bool:
@@ -752,12 +752,12 @@ class SingleHitResult:
 def _prefilter_hit(
     query_sequence: str,
-    query_pdb_code: Optional[str],
     hit: parsers.TemplateHit,
     max_template_date: datetime.datetime,
     release_dates: Mapping[str, datetime.datetime],
     obsolete_pdbs: Mapping[str, str],
     strict_error_check: bool = False,
+    query_pdb_code: Optional[str] = None,
 ):
     # Fail hard if we can't get the PDB ID and chain name from the hit.
     hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
@@ -794,7 +794,6 @@ def _prefilter_hit(
 def _process_single_hit(
     query_sequence: str,
-    query_pdb_code: Optional[str],
     hit: parsers.TemplateHit,
     mmcif_dir: str,
     max_template_date: datetime.datetime,
@@ -803,6 +802,7 @@ def _process_single_hit(
     kalign_binary_path: str,
     strict_error_check: bool = False,
     _zero_center_positions: bool = True,
+    query_pdb_code: Optional[str] = None,
 ) -> SingleHitResult:
     """Tries to extract template features from a single HHSearch hit."""
     # Fail hard if we can't get the PDB ID and chain name from the hit.
@@ -996,9 +996,9 @@ class TemplateHitFeaturizer:
     def get_templates(
         self,
         query_sequence: str,
-        query_pdb_code: Optional[str],
         query_release_date: Optional[datetime.datetime],
         hits: Sequence[parsers.TemplateHit],
+        query_pdb_code: Optional[str] = None,
     ) -> TemplateSearchResult:
         """Computes the templates for given query sequence (more details above)."""
         logging.info("Searching for template for: %s", query_pdb_code)
@@ -1155,7 +1155,7 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
         idx[:stk] = np.random.permutation(idx[:stk])

         for i in idx:
-            if (len(already_seen) >= self._max_hits):
+            if (len(already_seen) >= self.max_hits):
                 break
             hit = filtered[i]
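All the signature changes in this file follow one pattern: query_pdb_code moves from a required positional parameter to a trailing keyword with a None default. A minimal self-contained illustration (hypothetical functions, not from this commit) of why keyword call sites survive this reordering while positional ones must be updated:

    from typing import Optional

    def old_sig(query_sequence: str, query_pdb_code: Optional[str], hits: list):
        return (query_sequence, query_pdb_code, hits)

    def new_sig(query_sequence: str, hits: list,
                query_pdb_code: Optional[str] = None):
        return (query_sequence, query_pdb_code, hits)

    # Keyword callers are unaffected by the move:
    assert old_sig(query_sequence="MKT", query_pdb_code=None, hits=[]) \
        == new_sig(query_sequence="MKT", query_pdb_code=None, hits=[])
    # New callers (e.g. the hmmsearch path) can simply omit it:
    print(new_sig("MKT", hits=[]))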
fastfold/data/tools/jackhmmer.py
@@ -23,6 +23,7 @@ import subprocess
 from typing import Any, Callable, Mapping, Optional, Sequence
 from urllib import request

+from fastfold.data import parsers
 from fastfold.data.tools import utils
@@ -93,7 +94,10 @@ class Jackhmmer:
         self.streaming_callback = streaming_callback

     def _query_chunk(
-        self, input_fasta_path: str, database_path: str
+        self,
+        input_fasta_path: str,
+        database_path: str,
+        max_sequences: Optional[int] = None
     ) -> Mapping[str, Any]:
         """Queries the database chunk using Jackhmmer."""
         with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
@@ -167,8 +171,11 @@ class Jackhmmer:
             with open(tblout_path) as f:
                 tbl = f.read()

-            with open(sto_path) as f:
-                sto = f.read()
+            if (max_sequences is None):
+                with open(sto_path) as f:
+                    sto = f.read()
+            else:
+                sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)

             raw_output = dict(
                 sto=sto,
@@ -180,10 +187,16 @@ class Jackhmmer:
             return raw_output

-    def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
+    def query(
+        self, input_fasta_path: str, max_sequences: Optional[int] = None
+    ) -> Sequence[Mapping[str, Any]]:
         """Queries the database using Jackhmmer."""
         if self.num_streamed_chunks is None:
-            return [self._query_chunk(input_fasta_path, self.database_path)]
+            single_chunk_result = self._query_chunk(
+                input_fasta_path,
+                self.database_path,
+                max_sequences,
+            )
+            return [single_chunk_result]

         db_basename = os.path.basename(self.database_path)
         db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
@@ -217,12 +230,20 @@ class Jackhmmer:
                     # Run Jackhmmer with the chunk
                     future.result()
                 chunked_output.append(
-                    self._query_chunk(input_fasta_path, db_local_chunk(i))
+                    self._query_chunk(
+                        input_fasta_path,
+                        db_local_chunk(i),
+                        max_sequences
+                    )
                 )

                 # Remove the local copy of the chunk
                 os.remove(db_local_chunk(i))
-                future = next_future
+                # Do not set next_future for the last chunk so that this works
+                # even for databases with only 1 chunk
+                if (i < self.num_streamed_chunks):
+                    future = next_future
                 if self.streaming_callback:
                     self.streaming_callback(i)
         return chunked_output
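A hedged usage sketch of the extended API: the constructor arguments mirror those used by the runners in data_pipeline.py, but the binary and database paths below are placeholders.

    from fastfold.data.tools import jackhmmer

    runner = jackhmmer.Jackhmmer(
        binary_path="/usr/bin/jackhmmer",               # placeholder path
        database_path="/data/uniref90/uniref90.fasta",  # placeholder path
        n_cpu=8,
    )
    # max_sequences flows from query() into _query_chunk(), where the
    # Stockholm output is truncated via parsers.truncate_stockholm_msa.
    results = runner.query("query.fasta", max_sequences=10000)
    sto_msa = results[0]["sto"]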
inference.py
@@ -19,6 +19,8 @@ import random
 import sys
 import time
 from datetime import date
+import tempfile
+import contextlib

 import numpy as np
 import torch
@@ -39,6 +41,12 @@ from fastfold.data.parsers import parse_fasta
 from fastfold.utils.import_weights import import_jax_weights_
 from fastfold.utils.tensor_utils import tensor_tree_map

+@contextlib.contextmanager
+def temp_fasta_file(fasta_str: str):
+    with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
+        fasta_file.write(fasta_str)
+        fasta_file.seek(0)
+        yield fasta_file.name

 def add_data_args(parser: argparse.ArgumentParser):
     parser.add_argument(
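A runnable usage note for the context manager added above: the temporary FASTA exists only inside the `with` block and is cleaned up on exit.

    fasta_str = ">chain_A\nMKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ\n"
    with temp_fasta_file(fasta_str) as fasta_path:
        with open(fasta_path) as f:
            print(f.read())   # prints the FASTA string written above
    # fasta_path no longer exists here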
@@ -66,10 +74,22 @@ def add_data_args(parser: argparse.ArgumentParser):
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "--pdb_seqres_database_path",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--uniprot_database_path",
+        type=str,
+        default=None,
+    )
     parser.add_argument('--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer')
     parser.add_argument('--hhblits_binary_path', type=str, default='/usr/bin/hhblits')
     parser.add_argument('--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch')
     parser.add_argument('--kalign_binary_path', type=str, default='/usr/bin/kalign')
+    parser.add_argument("--hmmsearch_binary_path", type=str, default="hmmsearch")
+    parser.add_argument("--hmmbuild_binary_path", type=str, default="hmmbuild")
     parser.add_argument(
         '--max_template_date',
         type=str,
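A sketch of exercising the new multimer flags through argparse. Only the flags visible in this diff are used; if other options registered elsewhere in add_data_args are required, they would need to be supplied as well.

    import argparse

    parser = argparse.ArgumentParser()
    add_data_args(parser)   # the function defined above
    args = parser.parse_args([
        "--uniprot_database_path", "/data/uniprot/uniprot.fasta",
        "--pdb_seqres_database_path", "/data/pdb_seqres/pdb_seqres.txt",
        "--hmmsearch_binary_path", "hmmsearch",
        "--hmmbuild_binary_path", "hmmbuild",
        "--max_template_date", "2022-09-09",
    ])
    print(args.uniprot_database_path)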
@@ -79,6 +99,7 @@ def add_data_args(parser: argparse.ArgumentParser):
     parser.add_argument('--release_dates_path', type=str, default=None)
     parser.add_argument('--enable_workflow', default=False, action='store_true', help='run inference with ray workflow or not')

 def inference_model(rank, world_size, result_q, batch, args):
     os.environ['RANK'] = str(rank)
     os.environ['LOCAL_RANK'] = str(rank)
@@ -120,7 +141,7 @@ def main(args):
 def inference_multimer_model(args):
     print("running in multimer mode...")
     config = model_config(args.model_name)
     # feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
     predict_max_templates = 4
@@ -143,6 +164,81 @@ def inference_multimer_model(args):
         obsolete_pdbs_path=args.obsolete_pdbs_path,
     )
+    if (not args.use_precomputed_alignments):
+        alignment_runner = data_pipeline.AlignmentRunnerMultimer(
+            jackhmmer_binary_path=args.jackhmmer_binary_path,
+            hhblits_binary_path=args.hhblits_binary_path,
+            uniref90_database_path=args.uniref90_database_path,
+            mgnify_database_path=args.mgnify_database_path,
+            bfd_database_path=args.bfd_database_path,
+            uniclust30_database_path=args.uniclust30_database_path,
+            uniprot_database_path=args.uniprot_database_path,
+            template_searcher=template_searcher,
+            use_small_bfd=(args.bfd_database_path is None),
+            no_cpus=args.cpus,
+        )
+    else:
+        alignment_runner = None
+
+    monomer_data_processor = data_pipeline.DataPipeline(
+        template_featurizer=template_featurizer,
+    )
+    data_processor = data_pipeline.DataPipelineMultimer(
+        monomer_data_pipeline=monomer_data_processor,
+    )
+
+    output_dir_base = args.output_dir
+    random_seed = args.data_random_seed
+    if random_seed is None:
+        random_seed = random.randrange(sys.maxsize)
+
+    feature_processor = feature_pipeline.FeaturePipeline(config.data)
+    if not os.path.exists(output_dir_base):
+        os.makedirs(output_dir_base)
+    if (not args.use_precomputed_alignments):
+        alignment_dir = os.path.join(output_dir_base, "alignments")
+    else:
+        alignment_dir = args.use_precomputed_alignments
+
+    # Gather input sequences
+    fasta_path = args.fasta_path
+    with open(fasta_path, "r") as fp:
+        data = fp.read()
+    lines = [
+        l.replace('\n', '')
+        for prot in data.split('>') for l in prot.strip().split('\n', 1)
+    ][1:]
+    tags, seqs = lines[::2], lines[1::2]
+
+    for tag, seq in zip(tags, seqs):
+        local_alignment_dir = os.path.join(alignment_dir, tag)
+        if (args.use_precomputed_alignments is None):
+            if not os.path.exists(local_alignment_dir):
+                os.makedirs(local_alignment_dir)
+
+            chain_fasta_str = f'>chain_{tag}\n{seq}\n'
+            with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
+                alignment_runner.run(chain_fasta_path, local_alignment_dir)
+
+            print(f"Finished running alignment for {tag}")
+
+    local_alignment_dir = alignment_dir
+    feature_dict = data_processor.process_fasta(
+        fasta_path=fasta_path, alignment_dir=local_alignment_dir
+    )
+
+    processed_feature_dict = feature_processor.process_features(
+        feature_dict,
+        mode='predict',
+        is_multimer=True,
+    )
+

 def inference_monomer_model(args):
     print("running in monomer mode...")
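Returning to the FASTA parsing in inference_multimer_model above: the nested comprehension flattens ">tag\nSEQ" records into an alternating [tag1, seq1, tag2, seq2, ...] list, which the stride slices then split. A self-contained check of that behavior:

    data = ">chain_A\nMKTAYIAK\n>chain_B\nGSHMLEDP\n"
    lines = [
        l.replace('\n', '')
        for prot in data.split('>') for l in prot.strip().split('\n', 1)
    ][1:]   # drop the empty string before the first '>'
    tags, seqs = lines[::2], lines[1::2]
    print(tags)   # ['chain_A', 'chain_B']
    print(seqs)   # ['MKTAYIAK', 'GSHMLEDP']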
@@ -282,6 +378,7 @@ def inference_monomer_model(args):
         with open(relaxed_output_path, 'w') as f:
             f.write(relaxed_pdb_str)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(