Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
17f24bd7
Commit
17f24bd7
authored
Feb 20, 2024
by
rostro36
Browse files
Added custom template folder
parent
bb3f51e5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
101 additions
and
52 deletions
+101
-52
README.md
README.md
+4
-1
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+16
-3
openfold/data/templates.py
openfold/data/templates.py
+68
-41
run_pretrained_openfold.py
run_pretrained_openfold.py
+13
-7
No files found.
README.md
View file @
17f24bd7
...
...
@@ -174,7 +174,10 @@ where `data` is the same directory as in the previous step. If `jackhmmer`,
`/usr/bin`
, their
`binary_path`
command-line arguments can be dropped.
If you've already computed alignments for the query, you have the option to
skip the expensive alignment computation here with
`--use_precomputed_alignments`
.
`--use_precomputed_alignments`
. If you wish to use a specific template as input,
you can use the argument
`--use_custom_template`
, which reads all .cif
files in
`template_mmcif_dir`
. Make sure the chain of interest in each file has the identifier _A_
and has the same length as the input sequence.
`--openfold_checkpoint_path`
or
`--jax_param_path`
accept comma-separated lists
of .pt/DeepSpeed OpenFold checkpoints and AlphaFold's .npz JAX parameter files,
...
...
openfold/data/data_pipeline.py
View file @
17f24bd7
...
...
@@ -23,8 +23,19 @@ import tempfile
from
typing
import
Mapping
,
Optional
,
Sequence
,
Any
,
MutableMapping
,
Union
import
numpy
as
np
import
torch
from
openfold.data
import
templates
,
parsers
,
mmcif_parsing
,
msa_identifiers
,
msa_pairing
,
feature_processing_multimer
from
openfold.data.templates
import
get_custom_template_features
,
empty_template_feats
from
openfold.data
import
(
templates
,
parsers
,
mmcif_parsing
,
msa_identifiers
,
msa_pairing
,
feature_processing_multimer
,
)
from
openfold.data.templates
import
(
get_custom_template_features
,
empty_template_feats
,
CustomHitFeaturizer
,
)
from
openfold.data.tools
import
jackhmmer
,
hhblits
,
hhsearch
,
hmmsearch
from
openfold.np
import
residue_constants
,
protein
...
...
@@ -38,7 +49,9 @@ def make_template_features(
template_featurizer
:
Any
,
)
->
FeatureDict
:
hits_cat
=
sum
(
hits
.
values
(),
[])
if
(
len
(
hits_cat
)
==
0
or
template_featurizer
is
None
):
if
template_featurizer
is
None
or
(
len
(
hits_cat
)
==
0
and
not
isinstance
(
template_featurizer
,
CustomHitFeaturizer
)
):
template_features
=
empty_template_feats
(
len
(
input_sequence
))
else
:
templates_result
=
template_featurizer
.
get_templates
(
...
...
openfold/data/templates.py
View file @
17f24bd7
...
...
@@ -22,6 +22,7 @@ import glob
import
json
import
logging
import
os
from
pathlib
import
Path
import
re
from
typing
import
Any
,
Dict
,
Mapping
,
Optional
,
Sequence
,
Tuple
...
...
@@ -947,49 +948,58 @@ def _process_single_hit(
def get_custom_template_features(
    mmcif_path: str,
    query_sequence: str,
    pdb_id: str,
    chain_id: Optional[str] = "A",
    kalign_binary_path: Optional[str] = None,
):
    """Build template features from user-supplied mmCIF file(s).

    Args:
        mmcif_path: Path to a single mmCIF file, or to a directory whose
            ``*.cif`` files are all used as templates.
        query_sequence: The query sequence. Query residue ``i`` is mapped to
            template residue ``i`` (identity mapping), so no alignment is
            computed here.
        pdb_id: Kept for backward compatibility. The identifier passed on
            for each template is the stem of that template's file name.
        chain_id: Chain looked up in ``chain_to_seqres`` of every parsed
            template. Defaults to ``"A"``.
        kalign_binary_path: Forwarded unchanged to
            ``_extract_template_features``.

    Returns:
        A ``TemplateSearchResult`` whose ``features`` holds, for every name
        in ``TEMPLATE_FEATURES``, the per-template values stacked along
        axis 0 and cast to the type registered for that name. ``errors`` is
        ``None`` and ``warnings`` has one entry per template.

    Raises:
        ValueError: If ``mmcif_path`` is neither a file nor a directory, or
            if it is a directory that contains no ``*.cif`` files.
    """
    if os.path.isfile(mmcif_path):
        template_paths = [Path(mmcif_path)]
    elif os.path.isdir(mmcif_path):
        # Sorted so the template order is reproducible rather than following
        # the filesystem's directory order.
        template_paths = sorted(Path(mmcif_path).glob("*.cif"))
    else:
        logging.error("Custom template path %s does not exist", mmcif_path)
        raise ValueError(f"Custom template path {mmcif_path} does not exist")

    if not template_paths:
        # Without this guard the final stacking step fails with an opaque
        # KeyError, because no feature was ever collected.
        raise ValueError(
            f"No .cif files found in custom template path {mmcif_path}"
        )

    warnings = []
    collected = {}
    for template_path in template_paths:
        logging.info("Featurizing template: %s", template_path)
        # The id is only used for reporting, so the file name stem is used.
        template_id = template_path.stem

        # A dedicated handle name keeps the ``mmcif_path`` argument intact.
        with open(template_path, "r") as cif_file:
            cif_string = cif_file.read()

        mmcif_parse_result = mmcif_parsing.parse(
            file_id=template_id, mmcif_string=cif_string
        )
        mmcif_object = mmcif_parse_result.mmcif_object
        template_sequence = mmcif_object.chain_to_seqres[chain_id]

        # Identity mapping: the template chain is expected to have the same
        # length as the query, so position i maps to position i.
        mapping = {i: i for i in range(len(query_sequence))}

        curr_features, curr_warnings = _extract_template_features(
            mmcif_object=mmcif_object,
            pdb_id=template_id,
            mapping=mapping,
            template_sequence=template_sequence,
            query_sequence=query_sequence,
            template_chain_id=chain_id,
            kalign_binary_path=kalign_binary_path,
            _zero_center_positions=True,
        )
        curr_features["template_sum_probs"] = [1.0]

        for feature_name, feature_value in curr_features.items():
            collected.setdefault(feature_name, []).append(feature_value)

        # list.append returns None, so its result must not be assigned back.
        warnings.append(curr_warnings)

    template_features = {
        feature_name: np.stack(collected[feature_name], axis=0).astype(
            feature_type
        )
        for feature_name, feature_type in TEMPLATE_FEATURES.items()
    }
    return TemplateSearchResult(
        features=template_features, errors=None, warnings=warnings
    )
...
...
@@ -1188,6 +1198,23 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
)
class CustomHitFeaturizer(TemplateHitFeaturizer):
    """Template featurizer that reads user-provided templates from a folder.

    The chain of interest has to be chain A and must have the same number of
    residues as the input sequence.
    """

    def get_templates(
        self,
        query_sequence: str,
        hits: Sequence[parsers.TemplateHit],
    ) -> TemplateSearchResult:
        """Featurize the custom templates for ``query_sequence``.

        ``hits`` is accepted to match the featurizer interface but is not
        read by this method.
        """
        template_source = self._mmcif_dir
        logging.info("Featurizing mmcif_dir: %s", template_source)

        # "test" is a placeholder id; the chain is fixed to "A" as stated in
        # the class docstring.
        custom_kwargs = {
            "query_sequence": query_sequence,
            "pdb_id": "test",
            "chain_id": "A",
            "kalign_binary_path": self._kalign_binary_path,
        }
        result = get_custom_template_features(template_source, **custom_kwargs)
        return result
class
HmmsearchHitFeaturizer
(
TemplateHitFeaturizer
):
def
get_templates
(
self
,
...
...
run_pretrained_openfold.py
View file @
17f24bd7
...
...
@@ -186,8 +186,15 @@ def main(args):
)
is_multimer
=
"multimer"
in
args
.
config_preset
if
is_multimer
:
is_custom_template
=
"use_custom_template"
in
args
if
is_custom_template
:
template_featurizer
=
templates
.
CustomHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
"9999-12-31"
,
# just dummy, not used
max_hits
=-
1
,
# just dummy, not used
kalign_binary_path
=
args
.
kalign_binary_path
)
elif
is_multimer
:
template_featurizer
=
templates
.
HmmsearchHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
...
...
@@ -205,11 +212,9 @@ def main(args):
release_dates_path
=
args
.
release_dates_path
,
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
)
data_processor
=
data_pipeline
.
DataPipeline
(
template_featurizer
=
template_featurizer
,
)
if
is_multimer
:
data_processor
=
data_pipeline
.
DataPipelineMultimer
(
monomer_data_pipeline
=
data_processor
,
...
...
@@ -222,7 +227,6 @@ def main(args):
np
.
random
.
seed
(
random_seed
)
torch
.
manual_seed
(
random_seed
+
1
)
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
if
not
os
.
path
.
exists
(
output_dir_base
):
os
.
makedirs
(
output_dir_base
)
...
...
@@ -292,7 +296,6 @@ def main(args):
)
feature_dicts
[
tag
]
=
feature_dict
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
mode
=
'predict'
,
is_multimer
=
is_multimer
)
...
...
@@ -379,6 +382,10 @@ if __name__ == "__main__":
help
=
"""Path to alignment directory. If provided, alignment computation
is skipped and database path arguments are ignored."""
)
parser
.
add_argument
(
"--use_custom_template"
,
action
=
"store_true"
,
default
=
False
,
help
=
"""Use mmcif given with "template_mmcif_dir" argument as template input."""
)
parser
.
add_argument
(
"--use_single_seq_mode"
,
action
=
"store_true"
,
default
=
False
,
help
=
"""Use single sequence embeddings instead of MSAs."""
...
...
@@ -466,5 +473,4 @@ if __name__ == "__main__":
"""The model is being run on CPU. Consider specifying
--model_device for better performance"""
)
main
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment