Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
71a274d8
Commit
71a274d8
authored
Aug 24, 2022
by
Sam DeLuca
Browse files
adding a script for threading a sequence onto a structure
parent
29864369
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
544 additions
and
261 deletions
+544
-261
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+36
-0
openfold/data/templates.py
openfold/data/templates.py
+50
-0
openfold/model/__init__.py
openfold/model/__init__.py
+16
-16
openfold/utils/script_utils.py
openfold/utils/script_utils.py
+256
-0
run_pretrained_openfold.py
run_pretrained_openfold.py
+17
-245
thread_sequence.py
thread_sequence.py
+169
-0
No files found.
openfold/data/data_pipeline.py
View file @
71a274d8
...
...
@@ -21,6 +21,7 @@ from typing import Mapping, Optional, Sequence, Any
import
numpy
as
np
from
openfold.data
import
templates
,
parsers
,
mmcif_parsing
from
openfold.data.templates
import
get_custom_template_features
from
openfold.data.tools
import
jackhmmer
,
hhblits
,
hhsearch
from
openfold.data.tools.utils
import
to_date
from
openfold.np
import
residue_constants
,
protein
...
...
@@ -259,6 +260,41 @@ def make_msa_features(
return
features
def make_sequence_features_with_custom_template(
    sequence: str,
    mmcif_path: str,
    pdb_id: str,
    chain_id: str,
    kalign_binary_path: str,
) -> FeatureDict:
    """Build a feature dict for one sequence using a single structural template.

    Instead of running alignment tools, this fabricates a trivial one-sequence
    MSA (the query aligned to itself with no deletions) and derives template
    features directly from the structure at `mmcif_path`.

    Args:
        sequence: Query amino-acid sequence.
        mmcif_path: Path to the template mmCIF file.
        pdb_id: Identifier used as the feature description and template id.
        chain_id: Chain of the template structure to thread onto.
        kalign_binary_path: Path to the kalign executable.

    Returns:
        Merged sequence, MSA, and template features.
    """
    seq_len = len(sequence)
    sequence_features = make_sequence_features(
        sequence=sequence,
        description=pdb_id,
        num_res=seq_len,
    )

    # Dummy single-sequence "alignment": the query itself, zero deletions.
    msa = [[sequence]]
    deletions = [[[0] * seq_len]]
    msa_features = make_msa_features(msa, deletions)

    template_result = get_custom_template_features(
        mmcif_path=mmcif_path,
        query_sequence=sequence,
        pdb_id=pdb_id,
        chain_id=chain_id,
        kalign_binary_path=kalign_binary_path,
    )

    merged = dict(sequence_features)
    merged.update(msa_features)
    merged.update(template_result.features)
    return merged
class
AlignmentRunner
:
"""Runs alignment tools and saves the results"""
def
__init__
(
...
...
openfold/data/templates.py
View file @
71a274d8
...
...
@@ -913,6 +913,56 @@ def _process_single_hit(
return
SingleHitResult
(
features
=
None
,
error
=
error
,
warning
=
None
)
def get_custom_template_features(
    mmcif_path: str,
    query_sequence: str,
    pdb_id: str,
    chain_id: str,
    kalign_binary_path: str,
):
    """Build template features by threading `query_sequence` onto one mmCIF structure.

    Args:
        mmcif_path: Path to the template mmCIF file.
        query_sequence: Query amino-acid sequence being threaded.
        pdb_id: Identifier passed to the mmCIF parser and template extractor.
        chain_id: Chain in the template structure to use.
        kalign_binary_path: Path to the kalign executable used during
            template feature extraction.

    Returns:
        A TemplateSearchResult whose `features` contains a single stacked
        template entry for every name in TEMPLATE_FEATURES.
    """
    # Use a distinct name for the file handle: the original code shadowed the
    # `mmcif_path` parameter, making the path unusable after the `with`.
    with open(mmcif_path, "r") as mmcif_file:
        cif_string = mmcif_file.read()
    mmcif_parse_result = mmcif_parsing.parse(
        file_id=pdb_id, mmcif_string=cif_string
    )
    template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]

    # Identity mapping: query position i is assumed to align to template
    # position i (no search, direct threading).
    mapping = {x: x for x, _ in enumerate(query_sequence)}

    features, warnings = _extract_template_features(
        mmcif_object=mmcif_parse_result.mmcif_object,
        pdb_id=pdb_id,
        mapping=mapping,
        template_sequence=template_sequence,
        query_sequence=query_sequence,
        template_chain_id=chain_id,
        kalign_binary_path=kalign_binary_path,
        _zero_center_positions=True
    )
    # There is no real template search here, so report full confidence.
    features["template_sum_probs"] = [1.0]

    # Stack each feature along a new leading "num_templates" axis (size 1)
    # and cast to the dtype expected downstream. This replaces the original
    # three-pass dict construction that the author flagged with
    # "TODO: clean up this logic".
    template_features = {
        name: np.stack([features[name]], axis=0).astype(dtype)
        for name, dtype in TEMPLATE_FEATURES.items()
    }

    return TemplateSearchResult(
        features=template_features, errors=None, warnings=warnings
    )
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
TemplateSearchResult
:
features
:
Mapping
[
str
,
Any
]
...
...
openfold/model/__init__.py
View file @
71a274d8
import glob
import importlib
import os

# Treat every *.py file in this directory (other than __init__.py itself)
# as a public submodule of the package.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]

# Import each submodule and expose it as an attribute of this package.
for _name in __all__:
    globals()[_name] = importlib.import_module("." + _name, __name__)

# Avoid needlessly cluttering the global namespace. pop() is safe even when
# the directory held no submodules and the loop never bound _name (the
# original `del _files, _m, _modules` raised NameError in that case).
del _files
globals().pop("_name", None)
#
import os
#
import glob
#
import importlib as importlib
#
#
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
#
__all__ = [
#
os.path.basename(f)[:-3]
#
for f in _files
#
if os.path.isfile(f) and not f.endswith("__init__.py")
#
]
#
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
#
for _m in _modules:
#
globals()[_m[0]] = _m[1]
#
#
# Avoid needlessly cluttering the global namespace
#
del _files, _m, _modules
openfold/utils/script_utils.py
0 → 100644
View file @
71a274d8
import
json
import
logging
import
os
import
re
import
time
import
numpy
import
torch
from
openfold.model.model
import
AlphaFold
from
openfold.np
import
residue_constants
,
protein
from
openfold.np.relax
import
relax
from
openfold.utils.import_weights
import
(
import_jax_weights_
,
)
from
pytorch_lightning.utilities.deepspeed
import
(
convert_zero_checkpoint_to_fp32_state_dict
)
logging
.
basicConfig
()
logger
=
logging
.
getLogger
(
__file__
)
logger
.
setLevel
(
level
=
logging
.
INFO
)
def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
    """Return the total number of checkpoints named in the two path arguments.

    Each argument is either falsy (None/empty) or a comma-separated list of
    checkpoint paths; every list entry counts as one model.
    """
    candidate_lists = (openfold_checkpoint_path, jax_param_path)
    return sum(len(paths.split(",")) for paths in candidate_lists if paths)
def get_model_basename(model_path):
    """Return the file name of `model_path` without its extension.

    Trailing path separators are normalized away first, so a DeepSpeed
    checkpoint directory like "ckpt/model_1/" yields "model_1".
    """
    normalized = os.path.normpath(model_path)
    filename = os.path.basename(normalized)
    stem, _extension = os.path.splitext(filename)
    return stem
def make_output_directory(output_dir, model_name, multiple_model_mode):
    """Create (if necessary) and return the prediction output directory.

    In multiple-model mode each model gets its own subdirectory under
    "predictions"; otherwise a single shared "predictions" directory is used.
    """
    components = [output_dir, "predictions"]
    if multiple_model_mode:
        components.append(model_name)
    prediction_dir = os.path.join(*components)
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir
def load_models_from_command_line(
    config, model_device, openfold_checkpoint_path, jax_param_path, output_dir
):
    """Yield an (AlphaFold model, output directory) pair per named checkpoint.

    Both `openfold_checkpoint_path` and `jax_param_path` may be
    comma-separated lists; every entry produces one model, loaded, switched
    to eval mode, and moved to `model_device`.

    Raises:
        ValueError: if neither checkpoint argument is provided.
    """
    # Create the output directory
    # "Multiple model mode" puts each model's predictions in its own
    # subdirectory (see make_output_directory).
    multiple_model_mode = count_models_to_evaluate(
        openfold_checkpoint_path, jax_param_path
    ) > 1
    if multiple_model_mode:
        logger.info(f"evaluating multiple models")

    if jax_param_path:
        for path in jax_param_path.split(","):
            model_basename = get_model_basename(path)
            # Parameter files are named like "params_model_1*.npz"; everything
            # after the first underscore is taken as the model version tag.
            model_version = "_".join(model_basename.split("_")[1:])
            model = AlphaFold(config)
            model = model.eval()
            import_jax_weights_(model, path, version=model_version)
            model = model.to(model_device)
            logger.info(
                f"Successfully loaded JAX parameters at {path}..."
            )
            output_directory = make_output_directory(
                output_dir, model_basename, multiple_model_mode
            )
            yield model, output_directory

    if openfold_checkpoint_path:
        for path in openfold_checkpoint_path.split(","):
            model = AlphaFold(config)
            model = model.eval()
            checkpoint_basename = get_model_basename(path)
            if os.path.isdir(path):
                # A DeepSpeed checkpoint
                # Convert it once to a consolidated fp32 .pt file cached under
                # output_dir; subsequent runs reuse the converted file.
                ckpt_path = os.path.join(
                    output_dir,
                    checkpoint_basename + ".pt",
                )
                if not os.path.isfile(ckpt_path):
                    convert_zero_checkpoint_to_fp32_state_dict(
                        path,
                        ckpt_path,
                    )
                d = torch.load(ckpt_path)
                model.load_state_dict(d["ema"]["params"])
            else:
                ckpt_path = path
                d = torch.load(ckpt_path)
                if "ema" in d:
                    # The public weights have had this done to them already
                    d = d["ema"]["params"]
                model.load_state_dict(d)
            model = model.to(model_device)
            logger.info(
                f"Loaded OpenFold parameters at {path}..."
            )
            output_directory = make_output_directory(
                output_dir, checkpoint_basename, multiple_model_mode
            )
            yield model, output_directory

    if not jax_param_path and not openfold_checkpoint_path:
        raise ValueError(
            "At least one of jax_param_path or openfold_checkpoint_path must "
            "be specified."
        )
def parse_fasta(data):
    """Parse FASTA text into parallel lists of tags and sequences.

    Each record's header is truncated at its first whitespace; sequence
    lines are concatenated with newlines removed.

    Returns:
        (tags, seqs): lists of the same length, one entry per record.
    """
    cleaned = re.sub('>$', '', data, flags=re.M)
    pieces = []
    for record in cleaned.split('>'):
        # At most two pieces per record: the header line and everything else.
        for part in record.strip().split('\n', 1):
            pieces.append(part.replace('\n', ''))
    # Drop whatever preceded the first '>' (normally the empty string).
    pieces = pieces[1:]
    tags = [header.split()[0] for header in pieces[::2]]
    seqs = pieces[1::2]
    return tags, seqs
def update_timings(timing_dict, output_file=None):
    """
    Write dictionary of one or more run step times to a file

    Merges `timing_dict` into any timings already stored in `output_file`
    (a JSON object) and rewrites the file. A file containing invalid JSON
    is discarded with a log message rather than raising.

    Args:
        timing_dict: Mapping of step name -> elapsed time.
        output_file: Path of the JSON file to update. Defaults to
            "timings.json" in the current working directory, resolved at
            call time (the original evaluated os.getcwd() in the default
            expression, freezing the directory at import time).

    Returns:
        The path that was written.
    """
    if output_file is None:
        output_file = os.path.join(os.getcwd(), "timings.json")
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            try:
                timings = json.load(f)
            except json.JSONDecodeError:
                logger.info(f"Overwriting non-standard JSON in {output_file}.")
                timings = {}
    else:
        timings = {}
    timings.update(timing_dict)
    with open(output_file, "w") as f:
        json.dump(timings, f)
    return output_file
def run_model(model, batch, tag, output_dir):
    """Run one forward pass of `model` on `batch` with gradients disabled.

    Logs the inference time and records it under "inference" in
    `<output_dir>/timings.json`.

    Args:
        model: AlphaFold module; `model.config.template.enabled` is mutated
            during the call and restored before returning.
        batch: Processed feature dict fed directly to the model.
        tag: Name of the target, used only for logging.
        output_dir: Directory whose timings.json receives the timing entry.

    Returns:
        The model's raw output.
    """
    with torch.no_grad():
        # Temporarily disable templates if there aren't any in the batch
        template_enabled = model.config.template.enabled
        model.config.template.enabled = template_enabled and any([
            "template_" in k for k in batch
        ])

        logger.info(f"Running inference for {tag}...")
        t = time.perf_counter()
        out = model(batch)
        inference_time = time.perf_counter() - t
        logger.info(f"Inference time: {inference_time}")
        update_timings(
            {"inference": inference_time},
            os.path.join(output_dir, "timings.json")
        )

        # Restore the caller-visible template setting.
        model.config.template.enabled = template_enabled

    return out
def prep_output(
    out,
    batch,
    feature_dict,
    feature_processor,
    config_preset,
    multimer_ri_gap,
    subtract_plddt,
):
    """Convert raw model output into an unrelaxed `protein` object.

    Args:
        out: Model output dict; must contain "plddt".
        batch: Processed feature dict used for the forward pass. NOTE: its
            "residue_index" entry is modified in place below.
        feature_dict: Raw (pre-pipeline) feature dict supplying template
            metadata and the original "residue_index".
        feature_processor: Feature pipeline whose config controls template
            usage and recycling counts.
        config_preset: Config preset name, recorded in the PDB remark.
        multimer_ri_gap: Residue-index gap separating chains in
            multi-chain FASTAs.
        subtract_plddt: If True, store (100 - pLDDT) in the B-factor column
            instead of the pLDDT itself.

    Returns:
        The unrelaxed protein built by `protein.from_prediction`.
    """
    plddt = out["plddt"]
    # Broadcast the per-residue pLDDT across atoms so it can be written to
    # the per-atom B-factor column.
    plddt_b_factors = numpy.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )
    if subtract_plddt:
        plddt_b_factors = 100 - plddt_b_factors

    # Prep protein metadata
    template_domain_names = []
    template_chain_index = None
    if (
        feature_processor.config.common.use_templates
        and "template_domain_names" in feature_dict
    ):
        template_domain_names = [
            t.decode("utf-8") for t in feature_dict["template_domain_names"]
        ]
        # This works because templates are not shuffled during inference
        template_domain_names = template_domain_names[
            :feature_processor.config.predict.max_templates
        ]
        if "template_chain_index" in feature_dict:
            template_chain_index = feature_dict["template_chain_index"]
            template_chain_index = template_chain_index[
                :feature_processor.config.predict.max_templates
            ]

    no_recycling = feature_processor.config.common.max_recycling_iters
    remark = ', '.join([
        f"no_recycling={no_recycling}",
        f"max_templates={feature_processor.config.predict.max_templates}",
        f"config_preset={config_preset}",
    ])

    # For multi-chain FASTAs
    # Chains were offset by multiples of multimer_ri_gap in "residue_index";
    # recover each residue's chain id from that offset, then subtract the
    # accumulated offsets back out of the batch's residue indices in place.
    ri = feature_dict["residue_index"]
    chain_index = (ri - numpy.arange(ri.shape[0])) / multimer_ri_gap
    chain_index = chain_index.astype(numpy.int64)
    cur_chain = 0
    prev_chain_max = 0
    for i, c in enumerate(chain_index):
        if c != cur_chain:
            cur_chain = c
            prev_chain_max = i + cur_chain * multimer_ri_gap
        batch["residue_index"][i] -= prev_chain_max

    unrelaxed_protein = protein.from_prediction(
        features=batch,
        result=out,
        b_factors=plddt_b_factors,
        chain_index=chain_index,
        remark=remark,
        parents=template_domain_names,
        parents_chain_index=template_chain_index,
    )

    return unrelaxed_protein
def relax_protein(
    config, model_device, unrelaxed_protein, output_directory, output_name
):
    """Run Amber relaxation on a prediction and write the relaxed PDB.

    Side effects: temporarily overrides CUDA_VISIBLE_DEVICES while relaxing
    on a CUDA device, records a "relaxation" entry in
    `<output_directory>/timings.json`, and writes
    `<output_directory>/<output_name>_relaxed.pdb`.
    """
    amber_relaxer = relax.AmberRelaxation(
        use_gpu=(model_device != "cpu"),
        **config.relax,
    )

    t = time.perf_counter()
    # The relaxer picks its GPU via CUDA_VISIBLE_DEVICES, so point it at the
    # device the model ran on, then restore the previous value afterwards.
    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
    if "cuda" in model_device:
        device_no = model_device.split(":")[-1]
        os.environ["CUDA_VISIBLE_DEVICES"] = device_no
    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
    relaxation_time = time.perf_counter() - t

    logger.info(f"Relaxation time: {relaxation_time}")
    update_timings(
        {"relaxation": relaxation_time},
        os.path.join(output_directory, "timings.json")
    )

    # Save the relaxed PDB.
    relaxed_output_path = os.path.join(
        output_directory, f'{output_name}_relaxed.pdb'
    )
    with open(relaxed_output_path, 'w') as fp:
        fp.write(relaxed_pdb_str)
    logger.info(f"Relaxed output written to {relaxed_output_path}...")
\ No newline at end of file
run_pretrained_openfold.py
View file @
71a274d8
...
...
@@ -13,26 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
from
copy
import
deepcopy
from
datetime
import
date
import
logging
import
math
import
numpy
as
np
import
os
from
openfold.utils.script_utils
import
load_models_from_command_line
,
parse_fasta
,
run_model
,
prep_output
,
\
update_timings
,
relax_protein
logging
.
basicConfig
()
logger
=
logging
.
getLogger
(
__file__
)
logger
.
setLevel
(
level
=
logging
.
INFO
)
import
pickle
from
pytorch_lightning.utilities.deepspeed
import
(
convert_zero_checkpoint_to_fp32_state_dict
)
import
random
import
sys
import
time
import
torch
import
re
torch_versions
=
torch
.
__version__
.
split
(
"."
)
torch_major_version
=
int
(
torch_versions
[
0
])
...
...
@@ -46,15 +43,11 @@ if(
torch
.
set_grad_enabled
(
False
)
from
openfold.config
import
model_config
,
NUM_RES
from
openfold.config
import
model_config
from
openfold.data
import
templates
,
feature_pipeline
,
data_pipeline
from
openfold.model.model
import
AlphaFold
from
openfold.model.torchscript
import
script_preset_
from
openfold.np
import
residue_constants
,
protein
import
openfold.np.relax.relax
as
relax
from
openfold.utils.import_weights
import
(
import_jax_weights_
,
)
from
openfold.utils.tensor_utils
import
(
tensor_tree_map
,
)
...
...
@@ -107,102 +100,6 @@ def round_up_seqlen(seqlen):
return
int
(
math
.
ceil
(
seqlen
/
TRACING_INTERVAL
))
*
TRACING_INTERVAL
def
run_model
(
model
,
batch
,
tag
,
args
):
with
torch
.
no_grad
():
# Temporarily disable templates if there aren't any in the batch
template_enabled
=
model
.
config
.
template
.
enabled
model
.
config
.
template
.
enabled
=
template_enabled
and
any
([
"template_"
in
k
for
k
in
batch
])
logger
.
info
(
f
"Running inference for
{
tag
}
..."
)
t
=
time
.
perf_counter
()
out
=
model
(
batch
)
inference_time
=
time
.
perf_counter
()
-
t
logger
.
info
(
f
"Inference time:
{
inference_time
}
"
)
update_timings
({
"inference"
:
inference_time
},
os
.
path
.
join
(
args
.
output_dir
,
"timings.json"
))
model
.
config
.
template
.
enabled
=
template_enabled
return
out
def
prep_output
(
out
,
batch
,
feature_dict
,
feature_processor
,
args
):
plddt
=
out
[
"plddt"
]
mean_plddt
=
np
.
mean
(
plddt
)
plddt_b_factors
=
np
.
repeat
(
plddt
[...,
None
],
residue_constants
.
atom_type_num
,
axis
=-
1
)
if
(
args
.
subtract_plddt
):
plddt_b_factors
=
100
-
plddt_b_factors
# Prep protein metadata
template_domain_names
=
[]
template_chain_index
=
None
if
(
feature_processor
.
config
.
common
.
use_templates
and
"template_domain_names"
in
feature_dict
):
template_domain_names
=
[
t
.
decode
(
"utf-8"
)
for
t
in
feature_dict
[
"template_domain_names"
]
]
# This works because templates are not shuffled during inference
template_domain_names
=
template_domain_names
[
:
feature_processor
.
config
.
predict
.
max_templates
]
if
(
"template_chain_index"
in
feature_dict
):
template_chain_index
=
feature_dict
[
"template_chain_index"
]
template_chain_index
=
template_chain_index
[
:
feature_processor
.
config
.
predict
.
max_templates
]
no_recycling
=
feature_processor
.
config
.
common
.
max_recycling_iters
remark
=
', '
.
join
([
f
"no_recycling=
{
no_recycling
}
"
,
f
"max_templates=
{
feature_processor
.
config
.
predict
.
max_templates
}
"
,
f
"config_preset=
{
args
.
config_preset
}
"
,
])
# For multi-chain FASTAs
ri
=
feature_dict
[
"residue_index"
]
chain_index
=
(
ri
-
np
.
arange
(
ri
.
shape
[
0
]))
/
args
.
multimer_ri_gap
chain_index
=
chain_index
.
astype
(
np
.
int64
)
cur_chain
=
0
prev_chain_max
=
0
for
i
,
c
in
enumerate
(
chain_index
):
if
(
c
!=
cur_chain
):
cur_chain
=
c
prev_chain_max
=
i
+
cur_chain
*
args
.
multimer_ri_gap
batch
[
"residue_index"
][
i
]
-=
prev_chain_max
unrelaxed_protein
=
protein
.
from_prediction
(
features
=
batch
,
result
=
out
,
b_factors
=
plddt_b_factors
,
chain_index
=
chain_index
,
remark
=
remark
,
parents
=
template_domain_names
,
parents_chain_index
=
template_chain_index
,
)
return
unrelaxed_protein
def
parse_fasta
(
data
):
data
=
re
.
sub
(
'>$'
,
''
,
data
,
flags
=
re
.
M
)
lines
=
[
l
.
replace
(
'
\n
'
,
''
)
for
prot
in
data
.
split
(
'>'
)
for
l
in
prot
.
strip
().
split
(
'
\n
'
,
1
)
][
1
:]
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
tags
=
[
t
.
split
()[
0
]
for
t
in
tags
]
return
tags
,
seqs
def
generate_feature_dict
(
tags
,
seqs
,
...
...
@@ -235,98 +132,6 @@ def generate_feature_dict(
return
feature_dict
def
get_model_basename
(
model_path
):
return
os
.
path
.
splitext
(
os
.
path
.
basename
(
os
.
path
.
normpath
(
model_path
)
)
)[
0
]
def
make_output_directory
(
output_dir
,
model_name
,
multiple_model_mode
):
if
multiple_model_mode
:
prediction_dir
=
os
.
path
.
join
(
output_dir
,
"predictions"
,
model_name
)
else
:
prediction_dir
=
os
.
path
.
join
(
output_dir
,
"predictions"
)
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
return
prediction_dir
def
count_models_to_evaluate
(
openfold_checkpoint_path
,
jax_param_path
):
model_count
=
0
if
openfold_checkpoint_path
:
model_count
+=
len
(
openfold_checkpoint_path
.
split
(
","
))
if
jax_param_path
:
model_count
+=
len
(
jax_param_path
.
split
(
","
))
return
model_count
def
load_models_from_command_line
(
args
,
config
):
# Create the output directory
multiple_model_mode
=
count_models_to_evaluate
(
args
.
openfold_checkpoint_path
,
args
.
jax_param_path
)
>
1
if
multiple_model_mode
:
logger
.
info
(
f
"evaluating multiple models"
)
if
args
.
jax_param_path
:
for
path
in
args
.
jax_param_path
.
split
(
","
):
model_basename
=
get_model_basename
(
path
)
model_version
=
"_"
.
join
(
model_basename
.
split
(
"_"
)[
1
:])
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
import_jax_weights_
(
model
,
path
,
version
=
model_version
)
model
=
model
.
to
(
args
.
model_device
)
logger
.
info
(
f
"Successfully loaded JAX parameters at
{
path
}
..."
)
output_directory
=
make_output_directory
(
args
.
output_dir
,
model_basename
,
multiple_model_mode
)
yield
model
,
output_directory
if
args
.
openfold_checkpoint_path
:
for
path
in
args
.
openfold_checkpoint_path
.
split
(
","
):
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
checkpoint_basename
=
get_model_basename
(
path
)
if
os
.
path
.
isdir
(
path
):
# A DeepSpeed checkpoint
ckpt_path
=
os
.
path
.
join
(
args
.
output_dir
,
checkpoint_basename
+
".pt"
,
)
if
not
os
.
path
.
isfile
(
ckpt_path
):
convert_zero_checkpoint_to_fp32_state_dict
(
path
,
ckpt_path
,
)
d
=
torch
.
load
(
ckpt_path
)
model
.
load_state_dict
(
d
[
"ema"
][
"params"
])
else
:
ckpt_path
=
path
d
=
torch
.
load
(
ckpt_path
)
if
"ema"
in
d
:
# The public weights have had this done to them already
d
=
d
[
"ema"
][
"params"
]
model
.
load_state_dict
(
d
)
model
=
model
.
to
(
args
.
model_device
)
logger
.
info
(
f
"Loaded OpenFold parameters at
{
path
}
..."
)
output_directory
=
make_output_directory
(
args
.
output_dir
,
checkpoint_basename
,
multiple_model_mode
)
yield
model
,
output_directory
if
not
args
.
jax_param_path
and
not
args
.
openfold_checkpoint_path
:
raise
ValueError
(
"At least one of jax_param_path or openfold_checkpoint_path must "
"be specified."
)
def
list_files_with_extensions
(
dir
,
extensions
):
return
[
f
for
f
in
os
.
listdir
(
dir
)
if
f
.
endswith
(
extensions
)]
...
...
@@ -389,7 +194,13 @@ def main(args):
seq_sort_fn
=
lambda
target
:
sum
([
len
(
s
)
for
s
in
target
[
1
]])
sorted_targets
=
sorted
(
zip
(
tag_list
,
seq_list
),
key
=
seq_sort_fn
)
feature_dicts
=
{}
for
model
,
output_directory
in
load_models_from_command_line
(
args
,
config
):
model_generator
=
load_models_from_command_line
(
config
,
args
.
model_device
,
args
.
openfold_checkpoint_path
,
args
.
jax_param_path
,
args
.
output_dir
)
for
model
,
output_directory
in
model_generator
:
cur_tracing_interval
=
0
for
(
tag
,
tags
),
seqs
in
sorted_targets
:
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
...
...
@@ -440,7 +251,7 @@ def main(args):
)
cur_tracing_interval
=
rounded_seqlen
out
=
run_model
(
model
,
processed_feature_dict
,
tag
,
args
)
out
=
run_model
(
model
,
processed_feature_dict
,
tag
,
args
.
output_dir
)
# Toss out the recycling dimensions --- we don't need them anymore
processed_feature_dict
=
tensor_tree_map
(
...
...
@@ -454,7 +265,8 @@ def main(args):
processed_feature_dict
,
feature_dict
,
feature_processor
,
args
args
.
config_preset
,
args
.
multimer_ri_gap
)
unrelaxed_output_path
=
os
.
path
.
join
(
...
...
@@ -467,33 +279,9 @@ def main(args):
logger
.
info
(
f
"Output written to
{
unrelaxed_output_path
}
..."
)
if
not
args
.
skip_relaxation
:
amber_relaxer
=
relax
.
AmberRelaxation
(
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
**
config
.
relax
,
)
# Relax the prediction.
logger
.
info
(
f
"Running relaxation on
{
unrelaxed_output_path
}
..."
)
t
=
time
.
perf_counter
()
visible_devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
,
default
=
""
)
if
"cuda"
in
args
.
model_device
:
device_no
=
args
.
model_device
.
split
(
":"
)[
-
1
]
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
device_no
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
visible_devices
relaxation_time
=
time
.
perf_counter
()
-
t
logger
.
info
(
f
"Relaxation time:
{
relaxation_time
}
"
)
update_timings
({
"relaxation"
:
relaxation_time
},
os
.
path
.
join
(
args
.
output_dir
,
"timings.json"
))
# Save the relaxed PDB.
relaxed_output_path
=
os
.
path
.
join
(
output_directory
,
f
'
{
output_name
}
_relaxed.pdb'
)
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
relaxed_pdb_str
)
logger
.
info
(
f
"Relaxed output written to
{
relaxed_output_path
}
..."
)
relax_protein
(
config
,
args
.
model_device
,
unrelaxed_protein
,
output_directory
,
output_name
)
if
args
.
save_outputs
:
output_dict_path
=
os
.
path
.
join
(
...
...
@@ -504,22 +292,6 @@ def main(args):
logger
.
info
(
f
"Model output written to
{
output_dict_path
}
..."
)
def
update_timings
(
dict
,
output_file
=
os
.
path
.
join
(
os
.
getcwd
(),
"timings.json"
)):
"""Write dictionary of one or more run step times to a file"""
import
json
if
os
.
path
.
exists
(
output_file
):
with
open
(
output_file
,
"r"
)
as
f
:
try
:
timings
=
json
.
load
(
f
)
except
json
.
JSONDecodeError
:
logger
.
info
(
f
"Overwriting non-standard JSON in
{
output_file
}
."
)
timings
=
{}
else
:
timings
=
{}
timings
.
update
(
dict
)
with
open
(
output_file
,
"w"
)
as
f
:
json
.
dump
(
timings
,
f
)
return
output_file
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
...
...
thread_sequence.py
0 → 100644
View file @
71a274d8
import
argparse
import
os
import
logging
import
random
import
numpy
import
torch
from
openfold.config
import
model_config
from
openfold.data
import
feature_pipeline
from
openfold.data.data_pipeline
import
make_sequence_features_with_custom_template
from
openfold.np
import
protein
from
openfold.utils.script_utils
import
load_models_from_command_line
,
parse_fasta
,
run_model
,
prep_output
,
\
relax_protein
from
openfold.utils.tensor_utils
import
(
tensor_tree_map
,
)
from
scripts.utils
import
add_data_args
logging
.
basicConfig
()
logger
=
logging
.
getLogger
(
__file__
)
logger
.
setLevel
(
level
=
logging
.
INFO
)
# Parse the installed torch version to gate version-dependent features.
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
# set_float32_matmul_precision was introduced in torch 1.12.
if (
    torch_major_version > 1 or
    (torch_major_version == 1 and torch_minor_version >= 12)
):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

# Inference-only script: disable autograd globally.
torch.set_grad_enabled(False)
def main(args):
    """Thread one FASTA sequence onto a template structure and predict.

    Reads exactly one sequence from args.input_fasta, builds features with
    args.input_mmcif as the sole template, runs every model named on the
    command line, and writes unrelaxed and relaxed PDBs into per-model
    output directories under args.output_dir.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    config = model_config(args.config_preset)

    # TODO: make the random seed configurable
    random_seed = random.randrange(2**32)
    numpy.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    with open(args.input_fasta) as fasta_file:
        tags, sequences = parse_fasta(fasta_file.read())
    if len(sequences) != 1:
        raise ValueError(
            "the threading script can only process a single sequence"
        )
    query_sequence = sequences[0]
    query_tag = tags[0]

    feature_dict = make_sequence_features_with_custom_template(
        query_sequence,
        args.input_mmcif,
        args.template_id,
        args.chain_id,
        args.kalign_binary_path
    )

    processed_feature_dict = feature_processor.process_features(
        feature_dict, mode='predict',
    )
    # Move every feature tensor onto the inference device.
    processed_feature_dict = {
        k: torch.as_tensor(v, device=args.model_device)
        for k, v in processed_feature_dict.items()
    }

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir
    )
    output_name = f'{query_tag}_{args.config_preset}'
    for model, output_directory in model_generator:
        out = run_model(model, processed_feature_dict, query_tag, args.output_dir)

        # Toss out the recycling dimensions --- we don't need them anymore
        # NOTE(review): processed_feature_dict is overwritten with numpy
        # arrays here, so if multiple models are supplied, iterations after
        # the first would receive numpy inputs instead of tensors — verify.
        processed_feature_dict = tensor_tree_map(
            lambda x: numpy.array(x[..., -1].cpu()), processed_feature_dict
        )
        out = tensor_tree_map(lambda x: numpy.array(x.cpu()), out)

        unrelaxed_protein = prep_output(
            out,
            processed_feature_dict,
            feature_dict,
            feature_processor,
            args.config_preset,
            # this is the multimer_ri_gap; there are no multimer sequences
            # here, so it doesn't matter what it's set to
            200,
            args.subtract_plddt
        )

        unrelaxed_output_path = os.path.join(
            output_directory, f'{output_name}_unrelaxed.pdb'
        )
        with open(unrelaxed_output_path, 'w') as fp:
            fp.write(protein.to_pdb(unrelaxed_protein))
        logger.info(f"Output written to {unrelaxed_output_path}...")

        logger.info(f"Running relaxation on {unrelaxed_output_path}...")
        relax_protein(
            config,
            args.model_device,
            unrelaxed_protein,
            output_directory,
            output_name
        )
if __name__ == "__main__":
    # Command-line entry point: positional inputs are the query FASTA and
    # the template mmCIF; everything else is optional.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_fasta", type=str,
        help="the path to a fasta file containing a single sequence to thread"
    )
    parser.add_argument(
        "input_mmcif", type=str,
        help="the path to an mmcif file to thread the sequence on to"
    )
    parser.add_argument(
        "--template_id", type=str,
        help="a PDB id or other identifier for the template"
    )
    parser.add_argument(
        "--chain_id", type=str,
        help="""The chain ID of the chain in the template to use"""
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
        device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
        is also None, parameters are selected automatically according to
        the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
        checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help=""""Whether to output (100 - pLDDT) in the B-factor column instead
        of the pLDDT itself"""
    )
    add_data_args(parser)
    args = parser.parse_args()

    # When neither weight source is given, fall back to the bundled JAX
    # parameters matching the chosen config preset.
    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logging.warning(
            """The model is being run on CPU. Consider specifying
            --model_device for better performance"""
        )

    main(args)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment