Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastFold
Commits
6835c248
Unverified
Commit
6835c248
authored
Sep 21, 2022
by
Fazzie-Maqianli
Committed by
GitHub
Sep 21, 2022
Browse files
add multimer workflow (#70)
* add multimer workflow * support multimer dataworkflow
parent
9ab281fe
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
335 additions
and
77 deletions
+335
-77
environment.yml
environment.yml
+2
-0
fastfold/data/data_pipeline.py
fastfold/data/data_pipeline.py
+20
-29
fastfold/data/tools/hhsearch.py
fastfold/data/tools/hhsearch.py
+8
-0
fastfold/data/tools/hmmsearch.py
fastfold/data/tools/hmmsearch.py
+3
-1
fastfold/workflow/factory/__init__.py
fastfold/workflow/factory/__init__.py
+2
-1
fastfold/workflow/factory/hmmsearch.py
fastfold/workflow/factory/hmmsearch.py
+41
-0
fastfold/workflow/factory/jackhmmer.py
fastfold/workflow/factory/jackhmmer.py
+14
-8
fastfold/workflow/template/__init__.py
fastfold/workflow/template/__init__.py
+2
-1
fastfold/workflow/template/fastfold_data_workflow.py
fastfold/workflow/template/fastfold_data_workflow.py
+5
-11
fastfold/workflow/template/fastfold_multimer_data_workflow.py
...fold/workflow/template/fastfold_multimer_data_workflow.py
+192
-0
inference.py
inference.py
+46
-26
No files found.
environment.yml
View file @
6835c248
...
@@ -16,6 +16,8 @@ dependencies:
...
@@ -16,6 +16,8 @@ dependencies:
-
typing-extensions==3.10.0.2
-
typing-extensions==3.10.0.2
-
einops
-
einops
-
colossalai
-
colossalai
-
ray==2.0.0
-
pyarrow
-
pandas
-
pandas
-
--find-links https://download.pytorch.org/whl/cu113/torch_stable.html torch==1.11.1+cu113
-
--find-links https://download.pytorch.org/whl/cu113/torch_stable.html torch==1.11.1+cu113
-
--find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchaudio==0.11.1+cu113
-
--find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchaudio==0.11.1+cu113
...
...
fastfold/data/data_pipeline.py
View file @
6835c248
...
@@ -42,7 +42,6 @@ from fastfold.common import residue_constants, protein
...
@@ -42,7 +42,6 @@ from fastfold.common import residue_constants, protein
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
TemplateSearcher
=
Union
[
hhsearch
.
HHSearch
,
hmmsearch
.
Hmmsearch
]
def
empty_template_feats
(
n_res
)
->
FeatureDict
:
def
empty_template_feats
(
n_res
)
->
FeatureDict
:
...
@@ -466,12 +465,14 @@ class AlignmentRunnerMultimer:
...
@@ -466,12 +465,14 @@ class AlignmentRunnerMultimer:
self
,
self
,
jackhmmer_binary_path
:
Optional
[
str
]
=
None
,
jackhmmer_binary_path
:
Optional
[
str
]
=
None
,
hhblits_binary_path
:
Optional
[
str
]
=
None
,
hhblits_binary_path
:
Optional
[
str
]
=
None
,
hmmsearch_binary_path
:
Optional
[
str
]
=
None
,
hmmbuild_binary_path
:
Optional
[
str
]
=
None
,
uniref90_database_path
:
Optional
[
str
]
=
None
,
uniref90_database_path
:
Optional
[
str
]
=
None
,
mgnify_database_path
:
Optional
[
str
]
=
None
,
mgnify_database_path
:
Optional
[
str
]
=
None
,
bfd_database_path
:
Optional
[
str
]
=
None
,
bfd_database_path
:
Optional
[
str
]
=
None
,
uniclust30_database_path
:
Optional
[
str
]
=
None
,
uniclust30_database_path
:
Optional
[
str
]
=
None
,
uniprot_database_path
:
Optional
[
str
]
=
None
,
uniprot_database_path
:
Optional
[
str
]
=
None
,
template_searcher
:
Optional
[
TemplateSearche
r
]
=
None
,
pdb_seqres_database_path
:
Optional
[
st
r
]
=
None
,
use_small_bfd
:
Optional
[
bool
]
=
None
,
use_small_bfd
:
Optional
[
bool
]
=
None
,
no_cpus
:
Optional
[
int
]
=
None
,
no_cpus
:
Optional
[
int
]
=
None
,
uniref_max_hits
:
int
=
10000
,
uniref_max_hits
:
int
=
10000
,
...
@@ -524,6 +525,12 @@ class AlignmentRunnerMultimer:
...
@@ -524,6 +525,12 @@ class AlignmentRunnerMultimer:
bfd_database_path
if
not
use_small_bfd
else
None
,
bfd_database_path
if
not
use_small_bfd
else
None
,
],
],
},
},
"hmmsearch"
:
{
"binary"
:
hmmsearch_binary_path
,
"dbs"
:
[
pdb_seqres_database_path
,
],
},
}
}
for
name
,
dic
in
db_map
.
items
():
for
name
,
dic
in
db_map
.
items
():
...
@@ -585,15 +592,14 @@ class AlignmentRunnerMultimer:
...
@@ -585,15 +592,14 @@ class AlignmentRunnerMultimer:
database_path
=
uniprot_database_path
database_path
=
uniprot_database_path
)
)
if
(
template_searcher
is
not
None
and
self
.
hmmsearch_pdb_runner
=
None
self
.
jackhmmer_uniref90_runner
is
None
if
(
pdb_seqres_database_path
is
not
None
):
):
self
.
hmmsearch_pdb_runner
=
hmmsearch
.
Hmmsearch
(
raise
ValueError
(
binary_path
=
hmmsearch_binary_path
,
"Uniref90 runner must be specified to run template search"
hmmbuild_binary_path
=
hmmbuild_binary_path
,
database_path
=
pdb_seqres_database_path
,
)
)
self
.
template_searcher
=
template_searcher
def
run
(
def
run
(
self
,
self
,
fasta_path
:
str
,
fasta_path
:
str
,
...
@@ -617,25 +623,11 @@ class AlignmentRunnerMultimer:
...
@@ -617,25 +623,11 @@ class AlignmentRunnerMultimer:
template_msa
template_msa
)
)
if
(
self
.
template_searcher
is
not
None
):
if
(
self
.
hmmsearch_pdb_runner
is
not
None
):
if
(
self
.
template_searcher
.
input_format
==
"sto"
):
pdb_templates_result
=
self
.
hmmsearch_pdb_runner
.
query
(
pdb_templates_result
=
self
.
template_searcher
.
query
(
template_msa
,
template_msa
,
output_dir
=
output_dir
output_dir
=
output_dir
)
)
elif
(
self
.
template_searcher
.
input_format
==
"a3m"
):
uniref90_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
template_msa
)
pdb_templates_result
=
self
.
template_searcher
.
query
(
uniref90_msa_as_a3m
,
output_dir
=
output_dir
)
else
:
fmt
=
self
.
template_searcher
.
input_format
raise
ValueError
(
f
"Unrecognized template input format:
{
fmt
}
"
)
if
(
self
.
jackhmmer_mgnify_runner
is
not
None
):
if
(
self
.
jackhmmer_mgnify_runner
is
not
None
):
mgnify_out_path
=
os
.
path
.
join
(
output_dir
,
"mgnify_hits.sto"
)
mgnify_out_path
=
os
.
path
.
join
(
output_dir
,
"mgnify_hits.sto"
)
...
@@ -835,7 +827,6 @@ class DataPipeline:
...
@@ -835,7 +827,6 @@ class DataPipeline:
for
f
in
os
.
listdir
(
alignment_dir
):
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
filename
,
ext
=
os
.
path
.
splitext
(
f
)
filename
,
ext
=
os
.
path
.
splitext
(
f
)
if
(
ext
==
".a3m"
):
if
(
ext
==
".a3m"
):
with
open
(
path
,
"r"
)
as
fp
:
with
open
(
path
,
"r"
)
as
fp
:
msa
=
parsers
.
parse_a3m
(
fp
.
read
())
msa
=
parsers
.
parse_a3m
(
fp
.
read
())
...
...
fastfold/data/tools/hhsearch.py
View file @
6835c248
...
@@ -87,6 +87,14 @@ class HHSearch:
...
@@ -87,6 +87,14 @@ class HHSearch:
f
"Could not find HHsearch database
{
database_path
}
"
f
"Could not find HHsearch database
{
database_path
}
"
)
)
@
property
def
output_format
(
self
)
->
str
:
return
'hhr'
@
property
def
input_format
(
self
)
->
str
:
return
'a3m'
def
query
(
self
,
a3m
:
str
,
gen_atab
:
bool
=
False
)
->
Union
[
str
,
tuple
]:
def
query
(
self
,
a3m
:
str
,
gen_atab
:
bool
=
False
)
->
Union
[
str
,
tuple
]:
"""Queries the database using HHsearch using a given a3m."""
"""Queries the database using HHsearch using a given a3m."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
...
...
fastfold/data/tools/hmmsearch.py
View file @
6835c248
...
@@ -32,6 +32,7 @@ class Hmmsearch(object):
...
@@ -32,6 +32,7 @@ class Hmmsearch(object):
binary_path
:
str
,
binary_path
:
str
,
hmmbuild_binary_path
:
str
,
hmmbuild_binary_path
:
str
,
database_path
:
str
,
database_path
:
str
,
n_cpu
:
int
=
8
,
flags
:
Optional
[
Sequence
[
str
]]
=
None
flags
:
Optional
[
Sequence
[
str
]]
=
None
):
):
"""Initializes the Python hmmsearch wrapper.
"""Initializes the Python hmmsearch wrapper.
...
@@ -49,6 +50,7 @@ class Hmmsearch(object):
...
@@ -49,6 +50,7 @@ class Hmmsearch(object):
self
.
binary_path
=
binary_path
self
.
binary_path
=
binary_path
self
.
hmmbuild_runner
=
hmmbuild
.
Hmmbuild
(
binary_path
=
hmmbuild_binary_path
)
self
.
hmmbuild_runner
=
hmmbuild
.
Hmmbuild
(
binary_path
=
hmmbuild_binary_path
)
self
.
database_path
=
database_path
self
.
database_path
=
database_path
self
.
n_cpu
=
n_cpu
if
flags
is
None
:
if
flags
is
None
:
# Default hmmsearch run settings.
# Default hmmsearch run settings.
flags
=
[
'--F1'
,
'0.1'
,
flags
=
[
'--F1'
,
'0.1'
,
...
@@ -95,7 +97,7 @@ class Hmmsearch(object):
...
@@ -95,7 +97,7 @@ class Hmmsearch(object):
cmd
=
[
cmd
=
[
self
.
binary_path
,
self
.
binary_path
,
'--noali'
,
# Don't include the alignment in stdout.
'--noali'
,
# Don't include the alignment in stdout.
'--cpu'
,
'8'
'--cpu'
,
str
(
self
.
n_cpu
)
]
]
# If adding flags, we have to do so before the output and input:
# If adding flags, we have to do so before the output and input:
if
self
.
flags
:
if
self
.
flags
:
...
...
fastfold/workflow/factory/__init__.py
View file @
6835c248
...
@@ -3,3 +3,4 @@ from .hhblits import HHBlitsFactory
...
@@ -3,3 +3,4 @@ from .hhblits import HHBlitsFactory
from
.hhsearch
import
HHSearchFactory
from
.hhsearch
import
HHSearchFactory
from
.jackhmmer
import
JackHmmerFactory
from
.jackhmmer
import
JackHmmerFactory
from
.hhfilter
import
HHfilterFactory
from
.hhfilter
import
HHfilterFactory
from
.hmmsearch
import
HmmSearchFactory
\ No newline at end of file
fastfold/workflow/factory/hmmsearch.py
0 → 100644
View file @
6835c248
from
typing
import
List
import
inspect
import
ray
from
ray.dag.function_node
import
FunctionNode
from
fastfold.data.tools
import
hmmsearch
,
hmmbuild
from
fastfold.data
import
parsers
from
fastfold.workflow.factory
import
TaskFactory
from
typing
import
Optional
class
HmmSearchFactory
(
TaskFactory
):
keywords
=
[
'binary_path'
,
'hmmbuild_binary_path'
,
'database_path'
,
'n_cpu'
]
def
gen_node
(
self
,
msa_sto_path
:
str
,
output_dir
:
Optional
[
str
]
=
None
,
after
:
List
[
FunctionNode
]
=
None
)
->
FunctionNode
:
self
.
isReady
()
params
=
{
k
:
self
.
config
.
get
(
k
)
for
k
in
inspect
.
getfullargspec
(
hmmsearch
.
Hmmsearch
.
__init__
).
kwonlyargs
if
self
.
config
.
get
(
k
)
}
# setup runner with a filtered config dict
runner
=
hmmsearch
.
Hmmsearch
(
**
params
)
# generate function node
@
ray
.
remote
def
hmmsearch_node_func
(
after
:
List
[
FunctionNode
])
->
None
:
with
open
(
msa_sto_path
,
"r"
)
as
f
:
msa_sto
=
f
.
read
()
msa_sto
=
parsers
.
deduplicate_stockholm_msa
(
msa_sto
)
msa_sto
=
parsers
.
remove_empty_columns_from_stockholm_msa
(
msa_sto
)
hmmsearch_result
=
runner
.
query
(
msa_sto
,
output_dir
=
output_dir
)
return
hmmsearch_node_func
.
bind
(
after
)
fastfold/workflow/factory/jackhmmer.py
View file @
6835c248
...
@@ -13,7 +13,7 @@ class JackHmmerFactory(TaskFactory):
...
@@ -13,7 +13,7 @@ class JackHmmerFactory(TaskFactory):
keywords
=
[
'binary_path'
,
'database_path'
,
'n_cpu'
,
'uniref_max_hits'
]
keywords
=
[
'binary_path'
,
'database_path'
,
'n_cpu'
,
'uniref_max_hits'
]
def
gen_node
(
self
,
fasta_path
:
str
,
output_path
:
str
,
after
:
List
[
FunctionNode
]
=
None
)
->
FunctionNode
:
def
gen_node
(
self
,
fasta_path
:
str
,
output_path
:
str
,
after
:
List
[
FunctionNode
]
=
None
,
output_format
:
str
=
"a3m"
)
->
FunctionNode
:
self
.
isReady
()
self
.
isReady
()
...
@@ -28,11 +28,17 @@ class JackHmmerFactory(TaskFactory):
...
@@ -28,11 +28,17 @@ class JackHmmerFactory(TaskFactory):
@
ray
.
remote
@
ray
.
remote
def
jackhmmer_node_func
(
after
:
List
[
FunctionNode
])
->
None
:
def
jackhmmer_node_func
(
after
:
List
[
FunctionNode
])
->
None
:
result
=
runner
.
query
(
fasta_path
)[
0
]
result
=
runner
.
query
(
fasta_path
)[
0
]
if
output_format
==
"a3m"
:
uniref90_msa_a3m
=
parsers
.
convert_stockholm_to_a3m
(
uniref90_msa_a3m
=
parsers
.
convert_stockholm_to_a3m
(
result
[
'sto'
],
result
[
'sto'
],
max_sequences
=
self
.
config
[
'uniref_max_hits'
]
max_sequences
=
self
.
config
[
'uniref_max_hits'
]
)
)
with
open
(
output_path
,
"w"
)
as
f
:
with
open
(
output_path
,
"w"
)
as
f
:
f
.
write
(
uniref90_msa_a3m
)
f
.
write
(
uniref90_msa_a3m
)
elif
output_format
==
"sto"
:
template_msa
=
result
[
'sto'
]
with
open
(
output_path
,
"w"
)
as
f
:
f
.
write
(
template_msa
)
return
jackhmmer_node_func
.
bind
(
after
)
return
jackhmmer_node_func
.
bind
(
after
)
fastfold/workflow/template/__init__.py
View file @
6835c248
from
.fastfold_data_workflow
import
FastFoldDataWorkFlow
from
.fastfold_data_workflow
import
FastFoldDataWorkFlow
from
.fastfold_multimer_data_workflow
import
FastFoldMultimerDataWorkFlow
\ No newline at end of file
fastfold/workflow/template/fastfold_data_workflow.py
View file @
6835c248
...
@@ -118,11 +118,12 @@ class FastFoldDataWorkFlow:
...
@@ -118,11 +118,12 @@ class FastFoldDataWorkFlow:
self
.
jackhmmer_small_bfd_factory
=
JackHmmerFactory
(
config
=
jh_config
)
self
.
jackhmmer_small_bfd_factory
=
JackHmmerFactory
(
config
=
jh_config
)
def
run
(
self
,
fasta_path
:
str
,
output_dir
:
str
,
alignment_dir
:
str
=
None
,
storage_dir
:
str
=
None
)
->
None
:
def
run
(
self
,
fasta_path
:
str
,
alignment_dir
:
str
=
None
,
storage_dir
:
str
=
None
)
->
None
:
#
storage_dir = "file:///tmp/ray/lcmql/workflow_data"
storage_dir
=
"file:///tmp/ray/lcmql/workflow_data"
if
storage_dir
is
not
None
:
if
storage_dir
is
not
None
:
if
not
os
.
path
.
exists
(
storage_dir
):
if
not
os
.
path
.
exists
(
storage_dir
):
os
.
makedirs
(
storage_dir
)
os
.
makedirs
(
storage_dir
)
if
not
ray
.
is_initialized
():
ray
.
init
(
storage
=
storage_dir
)
ray
.
init
(
storage
=
storage_dir
)
localtime
=
time
.
asctime
(
time
.
localtime
(
time
.
time
()))
localtime
=
time
.
asctime
(
time
.
localtime
(
time
.
time
()))
...
@@ -135,13 +136,6 @@ class FastFoldDataWorkFlow:
...
@@ -135,13 +136,6 @@ class FastFoldDataWorkFlow:
print
(
"Workflow not found. Clean. Skipping"
)
print
(
"Workflow not found. Clean. Skipping"
)
pass
pass
# prepare alignment directory for alignment outputs
if
alignment_dir
is
None
:
alignment_dir
=
os
.
path
.
join
(
output_dir
,
"alignment"
)
if
not
os
.
path
.
exists
(
alignment_dir
):
os
.
makedirs
(
alignment_dir
)
# Run JackHmmer on UNIREF90
# Run JackHmmer on UNIREF90
uniref90_out_path
=
os
.
path
.
join
(
alignment_dir
,
"uniref90_hits.a3m"
)
uniref90_out_path
=
os
.
path
.
join
(
alignment_dir
,
"uniref90_hits.a3m"
)
# generate the workflow with i/o path
# generate the workflow with i/o path
...
@@ -167,7 +161,7 @@ class FastFoldDataWorkFlow:
...
@@ -167,7 +161,7 @@ class FastFoldDataWorkFlow:
# Run Jackhmmer on small_bfd
# Run Jackhmmer on small_bfd
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
"bfd_uniclust_hits.a3m"
)
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
"bfd_uniclust_hits.a3m"
)
# generate workflow for STEP4_2
# generate workflow for STEP4_2
bfd_node
=
self
.
jackhmmer_small_bfd_factory
.
gen_node
(
fasta_path
,
bfd_out_path
)
bfd_node
=
self
.
jackhmmer_small_bfd_factory
.
gen_node
(
fasta_path
,
bfd_out_path
,
output_format
=
"sto"
)
# run workflow
# run workflow
batch_run
(
workflow_id
=
workflow_id
,
dags
=
[
hhs_node
,
mgnify_node
,
bfd_node
])
batch_run
(
workflow_id
=
workflow_id
,
dags
=
[
hhs_node
,
mgnify_node
,
bfd_node
])
...
...
fastfold/workflow/template/fastfold_multimer_data_workflow.py
0 → 100644
View file @
6835c248
import
os
import
time
from
multiprocessing
import
cpu_count
import
ray
from
ray
import
workflow
from
fastfold.data.tools
import
hmmsearch
from
fastfold.workflow.factory
import
JackHmmerFactory
,
HHBlitsFactory
,
HmmSearchFactory
from
fastfold.workflow
import
batch_run
from
typing
import
Optional
,
Union
class
FastFoldMultimerDataWorkFlow
:
def
__init__
(
self
,
jackhmmer_binary_path
:
Optional
[
str
]
=
None
,
hhblits_binary_path
:
Optional
[
str
]
=
None
,
hmmsearch_binary_path
:
Optional
[
str
]
=
None
,
hmmbuild_binary_path
:
Optional
[
str
]
=
None
,
uniref90_database_path
:
Optional
[
str
]
=
None
,
mgnify_database_path
:
Optional
[
str
]
=
None
,
bfd_database_path
:
Optional
[
str
]
=
None
,
uniclust30_database_path
:
Optional
[
str
]
=
None
,
uniprot_database_path
:
Optional
[
str
]
=
None
,
pdb_seqres_database_path
:
Optional
[
str
]
=
None
,
use_small_bfd
:
Optional
[
bool
]
=
None
,
no_cpus
:
Optional
[
int
]
=
None
,
uniref_max_hits
:
int
=
10000
,
mgnify_max_hits
:
int
=
5000
,
uniprot_max_hits
:
int
=
50000
,
):
db_map
=
{
"jackhmmer"
:
{
"binary"
:
jackhmmer_binary_path
,
"dbs"
:
[
uniref90_database_path
,
mgnify_database_path
,
bfd_database_path
if
use_small_bfd
else
None
,
uniprot_database_path
,
],
},
"hhblits"
:
{
"binary"
:
hhblits_binary_path
,
"dbs"
:
[
bfd_database_path
if
not
use_small_bfd
else
None
,
],
},
"hmmsearch"
:
{
"binary"
:
hmmsearch_binary_path
,
"dbs"
:
[
pdb_seqres_database_path
,
],
},
}
for
name
,
dic
in
db_map
.
items
():
binary
,
dbs
=
dic
[
"binary"
],
dic
[
"dbs"
]
if
(
binary
is
None
and
not
all
([
x
is
None
for
x
in
dbs
])):
raise
ValueError
(
f
"
{
name
}
DBs provided but
{
name
}
binary is None"
)
if
(
not
all
([
x
is
None
for
x
in
db_map
[
"hmmsearch"
][
"dbs"
]])
and
uniref90_database_path
is
None
):
raise
ValueError
(
"""uniref90_database_path must be specified in order to perform
template search"""
)
self
.
use_small_bfd
=
use_small_bfd
self
.
uniref_max_hits
=
uniref_max_hits
self
.
mgnify_max_hits
=
mgnify_max_hits
if
(
no_cpus
is
None
):
self
.
no_cpus
=
cpu_count
()
else
:
self
.
no_cpus
=
no_cpus
# create JackHmmer workflow generator
self
.
jackhmmer_uniref90_factory
=
None
if
jackhmmer_binary_path
is
not
None
and
uniref90_database_path
is
not
None
:
jh_config
=
{
"binary_path"
:
db_map
[
"jackhmmer"
][
"binary"
],
"database_path"
:
uniref90_database_path
,
"n_cpu"
:
no_cpus
,
"uniref_max_hits"
:
uniref_max_hits
,
}
self
.
jackhmmer_uniref90_factory
=
JackHmmerFactory
(
config
=
jh_config
)
# create HMMSearch workflow generator
self
.
hmmsearch_pdb_factory
=
None
if
pdb_seqres_database_path
is
not
None
:
hmm_config
=
{
"binary_path"
:
db_map
[
"hmmsearch"
][
"binary"
],
"hmmbuild_binary_path"
:
hmmbuild_binary_path
,
"database_path"
:
pdb_seqres_database_path
,
"n_cpu"
:
self
.
no_cpus
,
}
self
.
hmmsearch_pdb_factory
=
HmmSearchFactory
(
config
=
hmm_config
)
self
.
jackhmmer_mgnify_factory
=
None
if
jackhmmer_binary_path
is
not
None
and
mgnify_database_path
is
not
None
:
jh_config
=
{
"binary_path"
:
db_map
[
"jackhmmer"
][
"binary"
],
"database_path"
:
mgnify_database_path
,
"n_cpu"
:
no_cpus
,
"uniref_max_hits"
:
mgnify_max_hits
,
}
self
.
jackhmmer_mgnify_factory
=
JackHmmerFactory
(
config
=
jh_config
)
if
bfd_database_path
is
not
None
:
if
not
use_small_bfd
:
hhb_config
=
{
"binary_path"
:
db_map
[
"hhblits"
][
"binary"
],
"databases"
:
db_map
[
"hhblits"
][
"dbs"
],
"n_cpu"
:
self
.
no_cpus
,
}
self
.
hhblits_bfd_factory
=
HHBlitsFactory
(
config
=
hhb_config
)
else
:
jh_config
=
{
"binary_path"
:
db_map
[
"jackhmmer"
][
"binary"
],
"database_path"
:
bfd_database_path
,
"n_cpu"
:
no_cpus
,
}
self
.
jackhmmer_small_bfd_factory
=
JackHmmerFactory
(
config
=
jh_config
)
self
.
jackhmmer_uniprot_factory
=
None
if
jackhmmer_binary_path
is
not
None
and
uniprot_database_path
is
not
None
:
jh_config
=
{
"binary_path"
:
db_map
[
"jackhmmer"
][
"binary"
],
"database_path"
:
uniprot_database_path
,
"n_cpu"
:
no_cpus
,
"uniref_max_hits"
:
uniprot_max_hits
,
}
self
.
jackhmmer_uniprot_factory
=
JackHmmerFactory
(
config
=
jh_config
)
def
run
(
self
,
fasta_path
:
str
,
alignment_dir
:
str
=
None
,
storage_dir
:
str
=
None
)
->
None
:
storage_dir
=
"file:///tmp/ray/lcmql/workflow_data"
if
storage_dir
is
not
None
:
if
not
os
.
path
.
exists
(
storage_dir
):
os
.
makedirs
(
storage_dir
)
if
not
ray
.
is_initialized
():
ray
.
init
(
storage
=
storage_dir
)
localtime
=
time
.
asctime
(
time
.
localtime
(
time
.
time
()))
workflow_id
=
'fastfold_data_workflow '
+
str
(
localtime
)
# clearing remaining ray workflow data
try
:
workflow
.
cancel
(
workflow_id
)
workflow
.
delete
(
workflow_id
)
except
:
print
(
"Workflow not found. Clean. Skipping"
)
pass
# Run JackHmmer on UNIREF90
uniref90_out_path
=
os
.
path
.
join
(
alignment_dir
,
"uniref90_hits.sto"
)
# generate the workflow with i/o path
uniref90_node
=
self
.
jackhmmer_uniref90_factory
.
gen_node
(
fasta_path
,
uniref90_out_path
,
output_format
=
"sto"
)
#Run HmmSearch on STEP1's result with PDB"""
# generate the workflow (STEP2 depend on STEP1)
hmm_node
=
self
.
hmmsearch_pdb_factory
.
gen_node
(
uniref90_out_path
,
output_dir
=
alignment_dir
,
after
=
[
uniref90_node
])
# Run JackHmmer on MGNIFY
mgnify_out_path
=
os
.
path
.
join
(
alignment_dir
,
"mgnify_hits.sto"
)
# generate workflow for STEP3
mgnify_node
=
self
.
jackhmmer_mgnify_factory
.
gen_node
(
fasta_path
,
mgnify_out_path
,
output_format
=
"sto"
)
if
not
self
.
use_small_bfd
:
# Run HHBlits on BFD
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
"bfd_uniclust_hits.a3m"
)
# generate workflow for STEP4
bfd_node
=
self
.
hhblits_bfd_factory
.
gen_node
(
fasta_path
,
bfd_out_path
)
else
:
# Run Jackhmmer on small_bfd
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
"bfd_uniclust_hits.sto"
)
# generate workflow for STEP4_2
bfd_node
=
self
.
jackhmmer_small_bfd_factory
.
gen_node
(
fasta_path
,
bfd_out_path
,
output_format
=
"sto"
)
# Run JackHmmer on UNIPROT
uniprot_out_path
=
os
.
path
.
join
(
alignment_dir
,
"uniprot_hits.sto"
)
# generate workflow for STEP5
uniprot_node
=
self
.
jackhmmer_uniprot_factory
.
gen_node
(
fasta_path
,
uniprot_out_path
,
output_format
=
"sto"
)
# run workflow
batch_run
(
workflow_id
=
workflow_id
,
dags
=
[
hmm_node
,
mgnify_node
,
bfd_node
,
uniprot_node
])
return
\ No newline at end of file
inference.py
View file @
6835c248
...
@@ -26,6 +26,7 @@ import numpy as np
...
@@ -26,6 +26,7 @@ import numpy as np
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
import
pickle
import
pickle
import
shutil
from
fastfold.model.hub
import
AlphaFold
from
fastfold.model.hub
import
AlphaFold
import
fastfold
import
fastfold
...
@@ -35,7 +36,8 @@ from fastfold.config import model_config
...
@@ -35,7 +36,8 @@ from fastfold.config import model_config
from
fastfold.model.fastnn
import
set_chunk_size
from
fastfold.model.fastnn
import
set_chunk_size
from
fastfold.data
import
data_pipeline
,
feature_pipeline
,
templates
from
fastfold.data
import
data_pipeline
,
feature_pipeline
,
templates
from
fastfold.data.tools
import
hhsearch
,
hmmsearch
from
fastfold.data.tools
import
hhsearch
,
hmmsearch
from
fastfold.workflow.template
import
FastFoldDataWorkFlow
from
fastfold.workflow.template
import
FastFoldDataWorkFlow
,
FastFoldMultimerDataWorkFlow
from
fastfold.utils
import
inject_fastnn
from
fastfold.utils
import
inject_fastnn
from
fastfold.data.parsers
import
parse_fasta
from
fastfold.data.parsers
import
parse_fasta
from
fastfold.utils.import_weights
import
import_jax_weights_
from
fastfold.utils.import_weights
import
import_jax_weights_
...
@@ -145,15 +147,6 @@ def inference_multimer_model(args):
...
@@ -145,15 +147,6 @@ def inference_multimer_model(args):
predict_max_templates
=
4
predict_max_templates
=
4
if
not
args
.
use_precomputed_alignments
:
template_searcher
=
hmmsearch
.
Hmmsearch
(
binary_path
=
args
.
hmmsearch_binary_path
,
hmmbuild_binary_path
=
args
.
hmmbuild_binary_path
,
database_path
=
args
.
pdb_seqres_database_path
,
)
else
:
template_searcher
=
None
template_featurizer
=
templates
.
HmmsearchHitFeaturizer
(
template_featurizer
=
templates
.
HmmsearchHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
max_template_date
=
args
.
max_template_date
,
...
@@ -164,17 +157,36 @@ def inference_multimer_model(args):
...
@@ -164,17 +157,36 @@ def inference_multimer_model(args):
)
)
if
(
not
args
.
use_precomputed_alignments
):
if
(
not
args
.
use_precomputed_alignments
):
if
args
.
enable_workflow
:
print
(
"Running alignment with ray workflow..."
)
alignment_runner
=
FastFoldMultimerDataWorkFlow
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hmmsearch_binary_path
=
args
.
hmmsearch_binary_path
,
hmmbuild_binary_path
=
args
.
hmmbuild_binary_path
,
uniref90_database_path
=
args
.
uniref90_database_path
,
mgnify_database_path
=
args
.
mgnify_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
uniprot_database_path
=
args
.
uniprot_database_path
,
pdb_seqres_database_path
=
args
.
pdb_seqres_database_path
,
use_small_bfd
=
(
args
.
bfd_database_path
is
None
),
no_cpus
=
args
.
cpus
)
else
:
alignment_runner
=
data_pipeline
.
AlignmentRunnerMultimer
(
alignment_runner
=
data_pipeline
.
AlignmentRunnerMultimer
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hmmsearch_binary_path
=
args
.
hmmsearch_binary_path
,
hmmbuild_binary_path
=
args
.
hmmbuild_binary_path
,
uniref90_database_path
=
args
.
uniref90_database_path
,
uniref90_database_path
=
args
.
uniref90_database_path
,
mgnify_database_path
=
args
.
mgnify_database_path
,
mgnify_database_path
=
args
.
mgnify_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
uniprot_database_path
=
args
.
uniprot_database_path
,
uniprot_database_path
=
args
.
uniprot_database_path
,
template_searcher
=
template_searcher
,
pdb_seqres_database_path
=
args
.
pdb_seqres_database_path
,
use_small_bfd
=
(
args
.
bfd_database_path
is
None
),
use_small_bfd
=
(
args
.
bfd_database_path
is
None
),
no_cpus
=
args
.
cpus
,
no_cpus
=
args
.
cpus
)
)
else
:
else
:
alignment_runner
=
None
alignment_runner
=
None
...
@@ -221,12 +233,20 @@ def inference_multimer_model(args):
...
@@ -221,12 +233,20 @@ def inference_multimer_model(args):
if
(
args
.
use_precomputed_alignments
is
None
):
if
(
args
.
use_precomputed_alignments
is
None
):
if
not
os
.
path
.
exists
(
local_alignment_dir
):
if
not
os
.
path
.
exists
(
local_alignment_dir
):
os
.
makedirs
(
local_alignment_dir
)
os
.
makedirs
(
local_alignment_dir
)
else
:
shutil
.
rmtree
(
local_alignment_dir
)
os
.
makedirs
(
local_alignment_dir
)
chain_fasta_str
=
f
'>chain_
{
tag
}
\n
{
seq
}
\n
'
chain_fasta_str
=
f
'>chain_
{
tag
}
\n
{
seq
}
\n
'
with
temp_fasta_file
(
chain_fasta_str
)
as
chain_fasta_path
:
with
temp_fasta_file
(
chain_fasta_str
)
as
chain_fasta_path
:
alignment_runner
.
run
(
if
args
.
enable_workflow
:
chain_fasta_path
,
local_alignment_dir
print
(
"Running alignment with ray workflow..."
)
)
t
=
time
.
perf_counter
()
alignment_runner
.
run
(
chain_fasta_path
,
alignment_dir
=
local_alignment_dir
)
print
(
f
"Alignment data workflow time:
{
time
.
perf_counter
()
-
t
}
"
)
else
:
alignment_runner
.
run
(
chain_fasta_path
,
local_alignment_dir
)
print
(
f
"Finished running alignment for
{
tag
}
"
)
print
(
f
"Finished running alignment for
{
tag
}
"
)
local_alignment_dir
=
alignment_dir
local_alignment_dir
=
alignment_dir
...
@@ -351,7 +371,7 @@ def inference_monomer_model(args):
...
@@ -351,7 +371,7 @@ def inference_monomer_model(args):
no_cpus
=
args
.
cpus
,
no_cpus
=
args
.
cpus
,
)
)
t
=
time
.
perf_counter
()
t
=
time
.
perf_counter
()
alignment_data_workflow_runner
.
run
(
fasta_path
,
output_dir
=
output_dir_base
,
alignment_dir
=
local_alignment_dir
)
alignment_data_workflow_runner
.
run
(
fasta_path
,
alignment_dir
=
local_alignment_dir
)
print
(
f
"Alignment data workflow time:
{
time
.
perf_counter
()
-
t
}
"
)
print
(
f
"Alignment data workflow time:
{
time
.
perf_counter
()
-
t
}
"
)
else
:
else
:
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment