OpenDAS / OpenFold · Commits

Commit 71a274d8, authored Aug 24, 2022 by Sam DeLuca (parent 29864369)

    adding a script for threading a sequence onto a structure

Showing 6 changed files with 544 additions and 261 deletions:
    openfold/data/data_pipeline.py    +36   −0
    openfold/data/templates.py        +50   −0
    openfold/model/__init__.py        +16   −16
    openfold/utils/script_utils.py    +256  −0
    run_pretrained_openfold.py        +17   −245
    thread_sequence.py                +169  −0
openfold/data/data_pipeline.py

@@ -21,6 +21,7 @@ from typing import Mapping, Optional, Sequence, Any
 import numpy as np

 from openfold.data import templates, parsers, mmcif_parsing
+from openfold.data.templates import get_custom_template_features
 from openfold.data.tools import jackhmmer, hhblits, hhsearch
 from openfold.data.tools.utils import to_date
 from openfold.np import residue_constants, protein

@@ -259,6 +260,41 @@ def make_msa_features(
     return features

+
+def make_sequence_features_with_custom_template(
+        sequence: str,
+        mmcif_path: str,
+        pdb_id: str,
+        chain_id: str,
+        kalign_binary_path: str) -> FeatureDict:
+    """
+    process a single fasta file using features derived from a single template rather than an alignment
+    """
+    num_res = len(sequence)
+
+    sequence_features = make_sequence_features(
+        sequence=sequence,
+        description=pdb_id,
+        num_res=num_res,
+    )
+
+    msa_data = [[sequence]]
+    deletion_matrix = [[[0 for _ in sequence]]]
+    msa_features = make_msa_features(msa_data, deletion_matrix)
+
+    template_features = get_custom_template_features(
+        mmcif_path=mmcif_path,
+        query_sequence=sequence,
+        pdb_id=pdb_id,
+        chain_id=chain_id,
+        kalign_binary_path=kalign_binary_path
+    )
+
+    return {
+        **sequence_features,
+        **msa_features,
+        **template_features.features
+    }
+

 class AlignmentRunner:
     """Runs alignment tools and saves the results"""
     def __init__(
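A minimal sketch (not part of this commit) of how the new feature builder might be called; the sequence, file path, PDB id, chain, and kalign location below are placeholders:

# Hypothetical call to the helper added above; all values are placeholders.
from openfold.data.data_pipeline import make_sequence_features_with_custom_template

feature_dict = make_sequence_features_with_custom_template(
    sequence="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # query to thread
    mmcif_path="templates/1abc.cif",               # template structure
    pdb_id="1abc",
    chain_id="A",
    kalign_binary_path="/usr/bin/kalign",
)
# The result merges sequence features, a single-sequence "MSA", and the custom
# template features; thread_sequence.py (below) feeds it to
# feature_pipeline.FeaturePipeline.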
openfold/data/templates.py

@@ -913,6 +913,56 @@ def _process_single_hit(
         return SingleHitResult(features=None, error=error, warning=None)

+
+def get_custom_template_features(
+        mmcif_path: str,
+        query_sequence: str,
+        pdb_id: str,
+        chain_id: str,
+        kalign_binary_path: str):
+
+    with open(mmcif_path, "r") as mmcif_path:
+        cif_string = mmcif_path.read()
+
+    mmcif_parse_result = mmcif_parsing.parse(
+        file_id=pdb_id,
+        mmcif_string=cif_string
+    )
+    template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]
+
+    mapping = {x: x for x, _ in enumerate(query_sequence)}
+
+    features, warnings = _extract_template_features(
+        mmcif_object=mmcif_parse_result.mmcif_object,
+        pdb_id=pdb_id,
+        mapping=mapping,
+        template_sequence=template_sequence,
+        query_sequence=query_sequence,
+        template_chain_id=chain_id,
+        kalign_binary_path=kalign_binary_path,
+        _zero_center_positions=True
+    )
+
+    features["template_sum_probs"] = [1.0]
+
+    # TODO: clean up this logic
+    template_features = {}
+    for template_feature_name in TEMPLATE_FEATURES:
+        template_features[template_feature_name] = []
+
+    for k in template_features:
+        template_features[k].append(features[k])
+
+    for name in template_features:
+        template_features[name] = np.stack(
+            template_features[name], axis=0
+        ).astype(TEMPLATE_FEATURES[name])
+
+    return TemplateSearchResult(
+        features=template_features, errors=None, warnings=warnings
+    )
+

 @dataclasses.dataclass(frozen=True)
 class TemplateSearchResult:
     features: Mapping[str, Any]
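The identity mapping built above pins query position i to template position i, so the query sequence is threaded straight onto the template with no alignment offsets. A small illustration (not part of this commit):

# For a 5-residue query, the mapping passed to _extract_template_features is
# simply position -> same position.
query_sequence = "MKTAY"  # placeholder sequence
mapping = {x: x for x, _ in enumerate(query_sequence)}
assert mapping == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}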
openfold/model/__init__.py

-import os
-import glob
-import importlib as importlib
-
-_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
-__all__ = [
-    os.path.basename(f)[:-3]
-    for f in _files
-    if os.path.isfile(f) and not f.endswith("__init__.py")
-]
-_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
-for _m in _modules:
-    globals()[_m[0]] = _m[1]
-
-# Avoid needlessly cluttering the global namespace
-del _files, _m, _modules
+# import os
+# import glob
+# import importlib as importlib
+
+# _files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
+# __all__ = [
+#     os.path.basename(f)[:-3]
+#     for f in _files
+#     if os.path.isfile(f) and not f.endswith("__init__.py")
+# ]
+# _modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
+# for _m in _modules:
+#     globals()[_m[0]] = _m[1]
+
+# # Avoid needlessly cluttering the global namespace
+# del _files, _m, _modules
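With the dynamic submodule loader commented out, openfold.model submodules are imported explicitly where needed; the new openfold/utils/script_utils.py in this same commit does exactly that:

# Explicit submodule import as used by openfold/utils/script_utils.py in this
# commit, replacing reliance on the auto-registered names that the commented-out
# loop above used to inject into the openfold.model namespace.
from openfold.model.model import AlphaFold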
openfold/utils/script_utils.py (new file, 0 → 100644)

import json
import logging
import os
import re
import time

import numpy
import torch

from openfold.model.model import AlphaFold
from openfold.np import residue_constants, protein
from openfold.np.relax import relax
from openfold.utils.import_weights import (
    import_jax_weights_,
)
from pytorch_lightning.utilities.deepspeed import (
    convert_zero_checkpoint_to_fp32_state_dict
)

logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)


def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
    model_count = 0
    if openfold_checkpoint_path:
        model_count += len(openfold_checkpoint_path.split(","))
    if jax_param_path:
        model_count += len(jax_param_path.split(","))
    return model_count


def get_model_basename(model_path):
    return os.path.splitext(
        os.path.basename(
            os.path.normpath(model_path)
        )
    )[0]


def make_output_directory(output_dir, model_name, multiple_model_mode):
    if multiple_model_mode:
        prediction_dir = os.path.join(output_dir, "predictions", model_name)
    else:
        prediction_dir = os.path.join(output_dir, "predictions")
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir


def load_models_from_command_line(config, model_device, openfold_checkpoint_path, jax_param_path, output_dir):
    # Create the output directory
    multiple_model_mode = count_models_to_evaluate(openfold_checkpoint_path, jax_param_path) > 1
    if multiple_model_mode:
        logger.info(f"evaluating multiple models")

    if jax_param_path:
        for path in jax_param_path.split(","):
            model_basename = get_model_basename(path)
            model_version = "_".join(model_basename.split("_")[1:])
            model = AlphaFold(config)
            model = model.eval()
            import_jax_weights_(
                model, path, version=model_version
            )
            model = model.to(model_device)
            logger.info(
                f"Successfully loaded JAX parameters at {path}..."
            )
            output_directory = make_output_directory(output_dir, model_basename, multiple_model_mode)
            yield model, output_directory

    if openfold_checkpoint_path:
        for path in openfold_checkpoint_path.split(","):
            model = AlphaFold(config)
            model = model.eval()
            checkpoint_basename = get_model_basename(path)
            if os.path.isdir(path):
                # A DeepSpeed checkpoint
                ckpt_path = os.path.join(
                    output_dir,
                    checkpoint_basename + ".pt",
                )
                if not os.path.isfile(ckpt_path):
                    convert_zero_checkpoint_to_fp32_state_dict(
                        path,
                        ckpt_path,
                    )
                d = torch.load(ckpt_path)
                model.load_state_dict(d["ema"]["params"])
            else:
                ckpt_path = path
                d = torch.load(ckpt_path)
                if "ema" in d:
                    # The public weights have had this done to them already
                    d = d["ema"]["params"]
                model.load_state_dict(d)

            model = model.to(model_device)
            logger.info(
                f"Loaded OpenFold parameters at {path}..."
            )
            output_directory = make_output_directory(output_dir, checkpoint_basename, multiple_model_mode)
            yield model, output_directory

    if not jax_param_path and not openfold_checkpoint_path:
        raise ValueError(
            "At least one of jax_param_path or openfold_checkpoint_path must "
            "be specified."
        )


def parse_fasta(data):
    data = re.sub('>$', '', data, flags=re.M)
    lines = [
        l.replace('\n', '')
        for prot in data.split('>') for l in prot.strip().split('\n', 1)
    ][1:]
    tags, seqs = lines[::2], lines[1::2]

    tags = [t.split()[0] for t in tags]

    return tags, seqs


def update_timings(timing_dict, output_file=os.path.join(os.getcwd(), "timings.json")):
    """
    Write dictionary of one or more run step times to a file
    """
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            try:
                timings = json.load(f)
            except json.JSONDecodeError:
                logger.info(f"Overwriting non-standard JSON in {output_file}.")
                timings = {}
    else:
        timings = {}
    timings.update(timing_dict)
    with open(output_file, "w") as f:
        json.dump(timings, f)
    return output_file


def run_model(model, batch, tag, output_dir):
    with torch.no_grad():
        # Temporarily disable templates if there aren't any in the batch
        template_enabled = model.config.template.enabled
        model.config.template.enabled = template_enabled and any([
            "template_" in k for k in batch
        ])

        logger.info(f"Running inference for {tag}...")
        t = time.perf_counter()
        out = model(batch)
        inference_time = time.perf_counter() - t
        logger.info(f"Inference time: {inference_time}")
        update_timings({"inference": inference_time}, os.path.join(output_dir, "timings.json"))

        model.config.template.enabled = template_enabled

    return out


def prep_output(out, batch, feature_dict, feature_processor, config_preset, multimer_ri_gap, subtract_plddt):
    plddt = out["plddt"]

    plddt_b_factors = numpy.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    if subtract_plddt:
        plddt_b_factors = 100 - plddt_b_factors

    # Prep protein metadata
    template_domain_names = []
    template_chain_index = None
    if feature_processor.config.common.use_templates and "template_domain_names" in feature_dict:
        template_domain_names = [
            t.decode("utf-8") for t in feature_dict["template_domain_names"]
        ]

        # This works because templates are not shuffled during inference
        template_domain_names = template_domain_names[
            :feature_processor.config.predict.max_templates
        ]

        if "template_chain_index" in feature_dict:
            template_chain_index = feature_dict["template_chain_index"]
            template_chain_index = template_chain_index[
                :feature_processor.config.predict.max_templates
            ]

    no_recycling = feature_processor.config.common.max_recycling_iters
    remark = ', '.join([
        f"no_recycling={no_recycling}",
        f"max_templates={feature_processor.config.predict.max_templates}",
        f"config_preset={config_preset}",
    ])

    # For multi-chain FASTAs
    ri = feature_dict["residue_index"]
    chain_index = (ri - numpy.arange(ri.shape[0])) / multimer_ri_gap
    chain_index = chain_index.astype(numpy.int64)
    cur_chain = 0
    prev_chain_max = 0
    for i, c in enumerate(chain_index):
        if c != cur_chain:
            cur_chain = c
            prev_chain_max = i + cur_chain * multimer_ri_gap

        batch["residue_index"][i] -= prev_chain_max

    unrelaxed_protein = protein.from_prediction(
        features=batch,
        result=out,
        b_factors=plddt_b_factors,
        chain_index=chain_index,
        remark=remark,
        parents=template_domain_names,
        parents_chain_index=template_chain_index,
    )

    return unrelaxed_protein


def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name):
    amber_relaxer = relax.AmberRelaxation(
        use_gpu=(model_device != "cpu"),
        **config.relax,
    )

    t = time.perf_counter()
    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
    if "cuda" in model_device:
        device_no = model_device.split(":")[-1]
        os.environ["CUDA_VISIBLE_DEVICES"] = device_no
    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
    relaxation_time = time.perf_counter() - t

    logger.info(f"Relaxation time: {relaxation_time}")
    update_timings({"relaxation": relaxation_time}, os.path.join(output_directory, "timings.json"))

    # Save the relaxed PDB.
    relaxed_output_path = os.path.join(
        output_directory, f'{output_name}_relaxed.pdb'
    )
    with open(relaxed_output_path, 'w') as fp:
        fp.write(relaxed_pdb_str)

    logger.info(f"Relaxed output written to {relaxed_output_path}...")
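A minimal sketch (not part of this commit) of how these helpers compose; the config preset, parameter path, FASTA string, and output directory are placeholders:

# Hypothetical use of the new helpers; all literal values are placeholders.
from openfold.config import model_config
from openfold.utils.script_utils import (
    load_models_from_command_line,
    parse_fasta,
)

config = model_config("model_1")
tags, seqs = parse_fasta(">query\nMKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ\n")

model_generator = load_models_from_command_line(
    config,
    "cpu",                                           # model_device
    None,                                            # openfold_checkpoint_path
    "openfold/resources/params/params_model_1.npz",  # jax_param_path
    "out",                                           # output_dir
)
for model, output_directory in model_generator:
    # A processed feature dict would be passed to
    # run_model(model, batch, tag, output_dir) here; thread_sequence.py below
    # shows the full flow.
    print(output_directory)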
run_pretrained_openfold.py

@@ -13,26 +13,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
-from copy import deepcopy
-from datetime import date
 import logging
 import math
 import numpy as np
 import os
+from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
+    update_timings, relax_protein

 logging.basicConfig()
 logger = logging.getLogger(__file__)
 logger.setLevel(level=logging.INFO)

 import pickle
-from pytorch_lightning.utilities.deepspeed import (
-    convert_zero_checkpoint_to_fp32_state_dict
-)
 import random
+import sys
 import time
 import torch
-import re

 torch_versions = torch.__version__.split(".")
 torch_major_version = int(torch_versions[0])

@@ -46,15 +43,11 @@ if(
 torch.set_grad_enabled(False)

-from openfold.config import model_config, NUM_RES
+from openfold.config import model_config
 from openfold.data import templates, feature_pipeline, data_pipeline
-from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants, protein
 import openfold.np.relax.relax as relax
-from openfold.utils.import_weights import (
-    import_jax_weights_,
-)
 from openfold.utils.tensor_utils import (
     tensor_tree_map,
 )

@@ -107,102 +100,6 @@ def round_up_seqlen(seqlen):
     return int(math.ceil(seqlen / TRACING_INTERVAL)) * TRACING_INTERVAL

-
-def run_model(model, batch, tag, args):
-    with torch.no_grad():
-        # Temporarily disable templates if there aren't any in the batch
-        template_enabled = model.config.template.enabled
-        model.config.template.enabled = template_enabled and any([
-            "template_" in k for k in batch
-        ])
-
-        logger.info(f"Running inference for {tag}...")
-        t = time.perf_counter()
-        out = model(batch)
-        inference_time = time.perf_counter() - t
-        logger.info(f"Inference time: {inference_time}")
-        update_timings({"inference": inference_time}, os.path.join(args.output_dir, "timings.json"))
-
-        model.config.template.enabled = template_enabled
-
-    return out
-
-
-def prep_output(out, batch, feature_dict, feature_processor, args):
-    plddt = out["plddt"]
-    mean_plddt = np.mean(plddt)
-
-    plddt_b_factors = np.repeat(
-        plddt[..., None], residue_constants.atom_type_num, axis=-1
-    )
-
-    if(args.subtract_plddt):
-        plddt_b_factors = 100 - plddt_b_factors
-
-    # Prep protein metadata
-    template_domain_names = []
-    template_chain_index = None
-    if(feature_processor.config.common.use_templates and "template_domain_names" in feature_dict):
-        template_domain_names = [
-            t.decode("utf-8") for t in feature_dict["template_domain_names"]
-        ]
-
-        # This works because templates are not shuffled during inference
-        template_domain_names = template_domain_names[
-            :feature_processor.config.predict.max_templates
-        ]
-
-        if("template_chain_index" in feature_dict):
-            template_chain_index = feature_dict["template_chain_index"]
-            template_chain_index = template_chain_index[
-                :feature_processor.config.predict.max_templates
-            ]
-
-    no_recycling = feature_processor.config.common.max_recycling_iters
-    remark = ', '.join([
-        f"no_recycling={no_recycling}",
-        f"max_templates={feature_processor.config.predict.max_templates}",
-        f"config_preset={args.config_preset}",
-    ])
-
-    # For multi-chain FASTAs
-    ri = feature_dict["residue_index"]
-    chain_index = (ri - np.arange(ri.shape[0])) / args.multimer_ri_gap
-    chain_index = chain_index.astype(np.int64)
-    cur_chain = 0
-    prev_chain_max = 0
-    for i, c in enumerate(chain_index):
-        if(c != cur_chain):
-            cur_chain = c
-            prev_chain_max = i + cur_chain * args.multimer_ri_gap
-
-        batch["residue_index"][i] -= prev_chain_max
-
-    unrelaxed_protein = protein.from_prediction(
-        features=batch,
-        result=out,
-        b_factors=plddt_b_factors,
-        chain_index=chain_index,
-        remark=remark,
-        parents=template_domain_names,
-        parents_chain_index=template_chain_index,
-    )
-
-    return unrelaxed_protein
-
-
-def parse_fasta(data):
-    data = re.sub('>$', '', data, flags=re.M)
-    lines = [
-        l.replace('\n', '')
-        for prot in data.split('>') for l in prot.strip().split('\n', 1)
-    ][1:]
-    tags, seqs = lines[::2], lines[1::2]
-
-    tags = [t.split()[0] for t in tags]
-
-    return tags, seqs
-
-
 def generate_feature_dict(
     tags,
     seqs,

@@ -235,98 +132,6 @@ def generate_feature_dict(
     return feature_dict

-
-def get_model_basename(model_path):
-    return os.path.splitext(
-        os.path.basename(
-            os.path.normpath(model_path)
-        )
-    )[0]
-
-
-def make_output_directory(output_dir, model_name, multiple_model_mode):
-    if multiple_model_mode:
-        prediction_dir = os.path.join(output_dir, "predictions", model_name)
-    else:
-        prediction_dir = os.path.join(output_dir, "predictions")
-    os.makedirs(prediction_dir, exist_ok=True)
-    return prediction_dir
-
-
-def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
-    model_count = 0
-    if openfold_checkpoint_path:
-        model_count += len(openfold_checkpoint_path.split(","))
-    if jax_param_path:
-        model_count += len(jax_param_path.split(","))
-    return model_count
-
-
-def load_models_from_command_line(args, config):
-    # Create the output directory
-    multiple_model_mode = count_models_to_evaluate(args.openfold_checkpoint_path, args.jax_param_path) > 1
-    if multiple_model_mode:
-        logger.info(f"evaluating multiple models")
-
-    if args.jax_param_path:
-        for path in args.jax_param_path.split(","):
-            model_basename = get_model_basename(path)
-            model_version = "_".join(model_basename.split("_")[1:])
-            model = AlphaFold(config)
-            model = model.eval()
-            import_jax_weights_(
-                model, path, version=model_version
-            )
-            model = model.to(args.model_device)
-            logger.info(
-                f"Successfully loaded JAX parameters at {path}..."
-            )
-            output_directory = make_output_directory(args.output_dir, model_basename, multiple_model_mode)
-            yield model, output_directory
-
-    if args.openfold_checkpoint_path:
-        for path in args.openfold_checkpoint_path.split(","):
-            model = AlphaFold(config)
-            model = model.eval()
-            checkpoint_basename = get_model_basename(path)
-            if os.path.isdir(path):
-                # A DeepSpeed checkpoint
-                ckpt_path = os.path.join(
-                    args.output_dir,
-                    checkpoint_basename + ".pt",
-                )
-                if not os.path.isfile(ckpt_path):
-                    convert_zero_checkpoint_to_fp32_state_dict(
-                        path,
-                        ckpt_path,
-                    )
-                d = torch.load(ckpt_path)
-                model.load_state_dict(d["ema"]["params"])
-            else:
-                ckpt_path = path
-                d = torch.load(ckpt_path)
-                if "ema" in d:
-                    # The public weights have had this done to them already
-                    d = d["ema"]["params"]
-                model.load_state_dict(d)
-
-            model = model.to(args.model_device)
-            logger.info(
-                f"Loaded OpenFold parameters at {path}..."
-            )
-            output_directory = make_output_directory(args.output_dir, checkpoint_basename, multiple_model_mode)
-            yield model, output_directory
-
-    if not args.jax_param_path and not args.openfold_checkpoint_path:
-        raise ValueError(
-            "At least one of jax_param_path or openfold_checkpoint_path must "
-            "be specified."
-        )
-
-
 def list_files_with_extensions(dir, extensions):
     return [f for f in os.listdir(dir) if f.endswith(extensions)]

@@ -389,7 +194,13 @@ def main(args):
     seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
     sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)
     feature_dicts = {}
-    for model, output_directory in load_models_from_command_line(args, config):
+    model_generator = load_models_from_command_line(
+        config,
+        args.model_device,
+        args.openfold_checkpoint_path,
+        args.jax_param_path,
+        args.output_dir
+    )
+    for model, output_directory in model_generator:
         cur_tracing_interval = 0
         for (tag, tags), seqs in sorted_targets:
             output_name = f'{tag}_{args.config_preset}'

@@ -440,7 +251,7 @@ def main(args):
                 )
                 cur_tracing_interval = rounded_seqlen

-            out = run_model(model, processed_feature_dict, tag, args)
+            out = run_model(model, processed_feature_dict, tag, args.output_dir)

             # Toss out the recycling dimensions --- we don't need them anymore
             processed_feature_dict = tensor_tree_map(

@@ -454,7 +265,8 @@ def main(args):
                 processed_feature_dict,
                 feature_dict,
                 feature_processor,
-                args
+                args.config_preset,
+                args.multimer_ri_gap
             )

             unrelaxed_output_path = os.path.join(

@@ -467,33 +279,9 @@ def main(args):
             logger.info(f"Output written to {unrelaxed_output_path}...")

             if not args.skip_relaxation:
-                amber_relaxer = relax.AmberRelaxation(
-                    use_gpu=(args.model_device != "cpu"),
-                    **config.relax,
-                )
-
                 # Relax the prediction.
                 logger.info(f"Running relaxation on {unrelaxed_output_path}...")
-                t = time.perf_counter()
-                visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
-                if "cuda" in args.model_device:
-                    device_no = args.model_device.split(":")[-1]
-                    os.environ["CUDA_VISIBLE_DEVICES"] = device_no
-                relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-                os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
-                relaxation_time = time.perf_counter() - t
-
-                logger.info(f"Relaxation time: {relaxation_time}")
-                update_timings({"relaxation": relaxation_time}, os.path.join(args.output_dir, "timings.json"))
-
-                # Save the relaxed PDB.
-                relaxed_output_path = os.path.join(
-                    output_directory, f'{output_name}_relaxed.pdb'
-                )
-                with open(relaxed_output_path, 'w') as fp:
-                    fp.write(relaxed_pdb_str)
-
-                logger.info(f"Relaxed output written to {relaxed_output_path}...")
+                relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)

             if args.save_outputs:
                 output_dict_path = os.path.join(

@@ -504,22 +292,6 @@ def main(args):
                 logger.info(f"Model output written to {output_dict_path}...")


-def update_timings(dict, output_file=os.path.join(os.getcwd(), "timings.json")):
-    """Write dictionary of one or more run step times to a file"""
-    import json
-    if os.path.exists(output_file):
-        with open(output_file, "r") as f:
-            try:
-                timings = json.load(f)
-            except json.JSONDecodeError:
-                logger.info(f"Overwriting non-standard JSON in {output_file}.")
-                timings = {}
-    else:
-        timings = {}
-    timings.update(dict)
-    with open(output_file, "w") as f:
-        json.dump(timings, f)
-    return output_file
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
thread_sequence.py (new file, 0 → 100644)

import argparse
import os

import logging
import random

import numpy
import torch

from openfold.config import model_config
from openfold.data import feature_pipeline
from openfold.data.data_pipeline import make_sequence_features_with_custom_template
from openfold.np import protein
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
    relax_protein
from openfold.utils.tensor_utils import (
    tensor_tree_map,
)
from scripts.utils import add_data_args

logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if (
        torch_major_version > 1 or
        (torch_major_version == 1 and torch_minor_version >= 12)
):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

torch.set_grad_enabled(False)


def main(args):
    os.makedirs(args.output_dir, exist_ok=True)

    config = model_config(args.config_preset)

    #TODO make configurable
    random_seed = random.randrange(2 ** 32)

    numpy.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    with open(args.input_fasta) as fasta_file:
        tags, sequences = parse_fasta(fasta_file.read())

    if len(sequences) != 1:
        raise ValueError("the threading script can only process a single sequence")

    query_sequence = sequences[0]
    query_tag = tags[0]

    feature_dict = make_sequence_features_with_custom_template(
        query_sequence,
        args.input_mmcif,
        args.template_id,
        args.chain_id,
        args.kalign_binary_path)

    processed_feature_dict = feature_processor.process_features(
        feature_dict, mode='predict',
    )

    processed_feature_dict = {
        k: torch.as_tensor(v, device=args.model_device)
        for k, v in processed_feature_dict.items()
    }

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir)

    output_name = f'{query_tag}_{args.config_preset}'

    for model, output_directory in model_generator:
        out = run_model(model, processed_feature_dict, query_tag, args.output_dir)

        # Toss out the recycling dimensions --- we don't need them anymore
        processed_feature_dict = tensor_tree_map(
            lambda x: numpy.array(x[..., -1].cpu()),
            processed_feature_dict
        )
        out = tensor_tree_map(lambda x: numpy.array(x.cpu()), out)

        unrelaxed_protein = prep_output(
            out,
            processed_feature_dict,
            feature_dict,
            feature_processor,
            args.config_preset,
            200,  # this is the ri_multimer_gap. There's no multimer sequences here, so it doesnt matter what its set to
            args.subtract_plddt)

        unrelaxed_output_path = os.path.join(
            output_directory, f'{output_name}_unrelaxed.pdb'
        )

        with open(unrelaxed_output_path, 'w') as fp:
            fp.write(protein.to_pdb(unrelaxed_protein))

        logger.info(f"Output written to {unrelaxed_output_path}...")

        logger.info(f"Running relaxation on {unrelaxed_output_path}...")
        relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_fasta", type=str,
        help="the path to a fasta file containing a single sequence to thread"
    )
    parser.add_argument(
        "input_mmcif", type=str,
        help="the path to an mmcif file to thread the sequence on to"
    )
    parser.add_argument(
        "--template_id", type=str,
        help="a PDB id or other identifier for the template"
    )
    parser.add_argument(
        "--chain_id", type=str,
        help="""The chain ID of the chain in the template to use"""
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help=""""Whether to output (100 - pLDDT) in the B-factor column instead
             of the pLDDT itself"""
    )
    add_data_args(parser)
    args = parser.parse_args()

    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logging.warning(
            """The model is being run on CPU. Consider specifying
            --model_device for better performance"""
        )

    main(args)
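A hypothetical invocation of the new script (not part of this commit); every path and identifier below is a placeholder, and add_data_args() is assumed to supply the remaining data arguments, including the kalign binary path that main() reads from args.kalign_binary_path:

import subprocess

# Hypothetical command line for thread_sequence.py; all values are placeholders.
subprocess.run(
    [
        "python", "thread_sequence.py",
        "query.fasta",      # FASTA with the single sequence to thread
        "template.cif",     # mmCIF structure to thread it onto
        "--template_id", "1abc",
        "--chain_id", "A",
        "--model_device", "cuda:0",
        "--config_preset", "model_1",
        "--output_dir", "threaded_output",
    ],
    check=True,
)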