Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
7d23b654
Commit
7d23b654
authored
Jun 29, 2022
by
Sam DeLuca
Browse files
improving file naming in run_pretrained_openfold
parent
080e37bb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
49 additions
and
31 deletions
+49
-31
run_pretrained_openfold.py
run_pretrained_openfold.py
+49
-31
No files found.
run_pretrained_openfold.py
View file @
7d23b654
...
@@ -208,36 +208,57 @@ def generate_feature_dict(
...
@@ -208,36 +208,57 @@ def generate_feature_dict(
return
feature_dict
return
feature_dict
def
get_model_basename
(
model_path
):
return
os
.
path
.
splitext
(
os
.
path
.
basename
(
os
.
path
.
normpath
(
model_path
)
)
)[
0
]
def
make_output_directory
(
output_dir
,
model_name
,
multiple_model_mode
):
if
multiple_model_mode
:
prediction_dir
=
os
.
path
.
join
(
output_dir
,
"predictions"
,
model_name
)
else
:
prediction_dir
=
os
.
path
.
join
(
output_dir
,
"predictions"
)
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
return
prediction_dir
def
count_models_to_evaluate
(
openfold_checkpoint_path
,
jax_param_path
):
model_count
=
0
if
openfold_checkpoint_path
:
model_count
+=
len
(
openfold_checkpoint_path
.
split
(
","
))
if
jax_param_path
:
model_count
+=
len
(
jax_param_path
.
split
(
","
))
return
model_count
def
load_models_from_command_line
(
args
,
config
):
def
load_models_from_command_line
(
args
,
config
):
# Create the output directory
# Create the output directory
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
multiple_model_mode
=
count_models_to_evaluate
(
args
.
openfold_checkpoint_path
,
args
.
jax_param_path
)
>
1
if
multiple_model_mode
:
logger
.
info
(
f
"evaluating multiple models"
)
if
args
.
jax_param_path
:
if
args
.
jax_param_path
:
for
path
in
args
.
jax_param_path
.
split
(
","
):
for
path
in
args
.
jax_param_path
.
split
(
","
):
model_basename
=
get_model_basename
(
path
)
model_version
=
"_"
.
join
(
model_basename
.
split
(
"_"
)[
1
:])
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
model
=
model
.
eval
()
import_jax_weights_
(
import_jax_weights_
(
model
,
path
,
version
=
args
.
config_preset
model
,
path
,
version
=
model_version
)
)
model
=
model
.
to
(
args
.
model_device
)
model
=
model
.
to
(
args
.
model_device
)
logger
.
info
(
logger
.
info
(
f
"Successfully loaded JAX parameters at
{
args
.
jax_param_path
}
..."
f
"Successfully loaded JAX parameters at
{
path
}
..."
)
model_version
=
os
.
path
.
basename
(
os
.
path
.
normpath
(
args
.
jax_param_path
),
)
)
model_version
=
os
.
path
.
splitext
(
model_version
)[
0
]
output_directory
=
make_output_directory
(
args
.
output_dir
,
model_basename
,
multiple_model_mode
)
yield
model
,
model_version
yield
model
,
output_directory
if
args
.
openfold_checkpoint_path
:
if
args
.
openfold_checkpoint_path
:
for
path
in
args
.
openfold_checkpoint_path
.
split
(
","
):
for
path
in
args
.
openfold_checkpoint_path
.
split
(
","
):
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
model
=
model
.
eval
()
checkpoint_basename
=
os
.
path
.
splitext
(
checkpoint_basename
=
get_model_basename
(
path
)
os
.
path
.
basename
(
os
.
path
.
normpath
(
path
)
)
)[
0
]
if
os
.
path
.
isdir
(
path
):
if
os
.
path
.
isdir
(
path
):
# A DeepSpeed checkpoint
# A DeepSpeed checkpoint
ckpt_path
=
os
.
path
.
join
(
ckpt_path
=
os
.
path
.
join
(
...
@@ -256,17 +277,17 @@ def load_models_from_command_line(args, config):
...
@@ -256,17 +277,17 @@ def load_models_from_command_line(args, config):
ckpt_path
=
path
ckpt_path
=
path
d
=
torch
.
load
(
ckpt_path
)
d
=
torch
.
load
(
ckpt_path
)
if
(
"ema"
in
d
)
:
if
"ema"
in
d
:
# The public weights have had this done to them already
# The public weights have had this done to them already
d
=
d
[
"ema"
][
"params"
]
d
=
d
[
"ema"
][
"params"
]
model
.
load_state_dict
(
d
)
model
.
load_state_dict
(
d
)
model
=
model
.
to
(
args
.
model_device
)
model
=
model
.
to
(
args
.
model_device
)
logger
.
info
(
logger
.
info
(
f
"Loaded OpenFold parameters at
{
args
.
openfold_checkpoint_
path
}
..."
f
"Loaded OpenFold parameters at
{
path
}
..."
)
)
output_directory
=
make_output_directory
(
args
.
output_dir
,
checkpoint_basename
,
multiple_model_mode
)
yield
model
,
checkpoint_basename
yield
model
,
output_directory
if
not
args
.
jax_param_path
and
not
args
.
openfold_checkpoint_path
:
if
not
args
.
jax_param_path
and
not
args
.
openfold_checkpoint_path
:
raise
ValueError
(
raise
ValueError
(
...
@@ -308,9 +329,6 @@ def main(args):
...
@@ -308,9 +329,6 @@ def main(args):
alignment_dir
=
args
.
use_precomputed_alignments
alignment_dir
=
args
.
use_precomputed_alignments
logger
.
info
(
f
"Using precomputed alignments at
{
alignment_dir
}
..."
)
logger
.
info
(
f
"Using precomputed alignments at
{
alignment_dir
}
..."
)
prediction_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"predictions"
)
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
for
fasta_file
in
list_files_with_extensions
(
args
.
fasta_dir
,
(
".fasta"
,
".fa"
)):
for
fasta_file
in
list_files_with_extensions
(
args
.
fasta_dir
,
(
".fasta"
,
".fa"
)):
# Gather input sequences
# Gather input sequences
with
open
(
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
with
open
(
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
...
@@ -323,14 +341,6 @@ def main(args):
...
@@ -323,14 +341,6 @@ def main(args):
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
if
args
.
output_postfix
is
not
None
:
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
unrelaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
# Output already exists
if
os
.
path
.
exists
(
unrelaxed_output_path
):
continue
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
...
@@ -346,7 +356,7 @@ def main(args):
...
@@ -346,7 +356,7 @@ def main(args):
feature_dict
,
mode
=
'predict'
,
feature_dict
,
mode
=
'predict'
,
)
)
for
model
,
model_version
in
load_models_from_command_line
(
args
,
config
):
for
model
,
output_directory
in
load_models_from_command_line
(
args
,
config
):
working_batch
=
deepcopy
(
processed_feature_dict
)
working_batch
=
deepcopy
(
processed_feature_dict
)
out
=
run_model
(
model
,
working_batch
,
tag
,
args
)
out
=
run_model
(
model
,
working_batch
,
tag
,
args
)
...
@@ -358,6 +368,14 @@ def main(args):
...
@@ -358,6 +368,14 @@ def main(args):
out
,
working_batch
,
feature_dict
,
feature_processor
,
args
out
,
working_batch
,
feature_dict
,
feature_processor
,
args
)
)
unrelaxed_output_path
=
os
.
path
.
join
(
output_directory
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
# Output already exists
if
os
.
path
.
exists
(
unrelaxed_output_path
):
continue
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
...
@@ -382,7 +400,7 @@ def main(args):
...
@@ -382,7 +400,7 @@ def main(args):
# Save the relaxed PDB.
# Save the relaxed PDB.
relaxed_output_path
=
os
.
path
.
join
(
relaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_relaxed.pdb'
output_directory
,
f
'
{
output_name
}
_relaxed.pdb'
)
)
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
relaxed_pdb_str
)
fp
.
write
(
relaxed_pdb_str
)
...
@@ -391,7 +409,7 @@ def main(args):
...
@@ -391,7 +409,7 @@ def main(args):
if
args
.
save_outputs
:
if
args
.
save_outputs
:
output_dict_path
=
os
.
path
.
join
(
output_dict_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
'
{
output_name
}
_output_dict.pkl'
output_dir
ectory
,
f
'
{
output_name
}
_output_dict.pkl'
)
)
with
open
(
output_dict_path
,
"wb"
)
as
fp
:
with
open
(
output_dict_path
,
"wb"
)
as
fp
:
pickle
.
dump
(
out
,
fp
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
pickle
.
dump
(
out
,
fp
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment