Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
1fa6ffab
Commit
1fa6ffab
authored
Jun 22, 2022
by
Gustaf Ahdritz
Browse files
Refactor run_pretrained_openfold.py a little
parent
b4b849af
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
51 deletions
+56
-51
run_pretrained_openfold.py
run_pretrained_openfold.py
+56
-51
No files found.
run_pretrained_openfold.py
View file @
1fa6ffab
...
@@ -162,36 +162,28 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
...
@@ -162,36 +162,28 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
return
unrelaxed_protein
return
unrelaxed_protein
def
generate_batch
(
fasta_file
,
fasta_dir
,
alignment_dir
,
data_processor
,
feature_processor
,
prediction_dir
):
def
parse_fasta
(
data
):
with
open
(
os
.
path
.
join
(
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
data
=
fp
.
read
()
lines
=
[
lines
=
[
l
.
replace
(
'
\n
'
,
''
)
l
.
replace
(
'
\n
'
,
''
)
for
prot
in
data
.
split
(
'>'
)
for
l
in
prot
.
strip
().
split
(
'
\n
'
,
1
)
for
prot
in
data
.
split
(
'>'
)
for
l
in
prot
.
strip
().
split
(
'
\n
'
,
1
)
][
1
:]
][
1
:]
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
tags
=
[
t
.
split
()[
0
]
for
t
in
tags
]
tags
=
[
t
.
split
()[
0
]
for
t
in
tags
]
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
# Save the unrelaxed PDB.
unrelaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
if
os
.
path
.
exists
(
unrelaxed_output_path
):
return
tags
,
seqs
return
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
def
generate_feature_dict
(
tags
,
seqs
,
alignment_dir
,
data_processor
,
args
,
):
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
if
len
(
seqs
)
==
1
:
if
len
(
seqs
)
==
1
:
tag
=
tags
[
0
]
seq
=
seqs
[
0
]
seq
=
seqs
[
0
]
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
...
@@ -212,10 +204,7 @@ def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature
...
@@ -212,10 +204,7 @@ def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature
# Remove temporary FASTA file
# Remove temporary FASTA file
os
.
remove
(
tmp_fasta_path
)
os
.
remove
(
tmp_fasta_path
)
processed_feature_dict
=
feature_processor
.
process_features
(
return
feature_dict
feature_dict
,
mode
=
'predict'
,
)
return
processed_feature_dict
,
tag
,
feature_dict
def
load_models_from_command_line
(
args
,
config
):
def
load_models_from_command_line
(
args
,
config
):
...
@@ -226,13 +215,18 @@ def load_models_from_command_line(args, config):
...
@@ -226,13 +215,18 @@ def load_models_from_command_line(args, config):
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
model
=
model
.
eval
()
import_jax_weights_
(
import_jax_weights_
(
model
,
path
,
version
=
args
.
model_name
model
,
path
,
version
=
args
.
config_preset
)
)
model
=
model
.
to
(
args
.
model_device
)
model
=
model
.
to
(
args
.
model_device
)
logger
.
info
(
logger
.
info
(
f
"Successfully loaded JAX parameters at
{
args
.
jax_param_path
}
..."
f
"Successfully loaded JAX parameters at
{
args
.
jax_param_path
}
..."
)
)
yield
model
,
None
model_version
=
os
.
path
.
basename
(
os
.
path
.
normpath
(
args
.
jax_param_path
),
)
model_version
=
os
.
path
.
splitext
(
model_version
)[
0
]
yield
model
,
model_version
if
args
.
openfold_checkpoint_path
:
if
args
.
openfold_checkpoint_path
:
for
path
in
args
.
openfold_checkpoint_path
.
split
(
","
):
for
path
in
args
.
openfold_checkpoint_path
.
split
(
","
):
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
...
@@ -264,11 +258,14 @@ def load_models_from_command_line(args, config):
...
@@ -264,11 +258,14 @@ def load_models_from_command_line(args, config):
# The public weights have had this done to them already
# The public weights have had this done to them already
d
=
d
[
"ema"
][
"params"
]
d
=
d
[
"ema"
][
"params"
]
model
.
load_state_dict
(
d
)
model
.
load_state_dict
(
d
)
model
=
model
.
to
(
args
.
model_device
)
model
=
model
.
to
(
args
.
model_device
)
logger
.
info
(
logger
.
info
(
f
"Loaded OpenFold parameters at
{
args
.
openfold_checkpoint_path
}
..."
f
"Loaded OpenFold parameters at
{
args
.
openfold_checkpoint_path
}
..."
)
)
yield
model
,
checkpoint_basename
yield
model
,
checkpoint_basename
if
not
args
.
jax_param_path
and
not
args
.
openfold_checkpoint_path
:
if
not
args
.
jax_param_path
and
not
args
.
openfold_checkpoint_path
:
raise
ValueError
(
raise
ValueError
(
"At least one of jax_param_path or openfold_checkpoint_path must "
"At least one of jax_param_path or openfold_checkpoint_path must "
...
@@ -311,23 +308,40 @@ def main(args):
...
@@ -311,23 +308,40 @@ def main(args):
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
for
fasta_file
in
os
.
listdir
(
args
.
fasta_dir
):
for
fasta_file
in
os
.
listdir
(
args
.
fasta_dir
):
with
open
(
os
.
path
.
join
(
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
data
=
fp
.
read
()
tags
,
seqs
=
parse_fasta
(
data
)
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
unrelaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
# Output already exists
if
os
.
path
.
exists
(
unrelaxed_output_path
):
continue
batch_data
=
generate_batch
(
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
fasta_file
,
args
.
fasta_dir
,
feature_dict
=
generate_feature_dict
(
tags
,
seqs
,
alignment_dir
,
alignment_dir
,
data_processor
,
data_processor
,
feature_processor
,
args
,
prediction_dir
)
)
if
batch_data
is
None
:
# this file has already been processed
continue
batch
,
tag
,
feature_dict
=
batch_data
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
mode
=
'predict'
,
)
for
model
,
model_version
in
load_models_from_command_line
(
args
,
config
):
for
model
,
model_version
in
load_models_from_command_line
(
args
,
config
):
working_batch
=
deepcopy
(
batch
)
working_batch
=
deepcopy
(
batch
)
out
=
run_model
(
model
,
working_batch
,
tag
,
args
)
out
=
run_model
(
model
,
working_batch
,
tag
,
args
)
...
@@ -339,21 +353,11 @@ def main(args):
...
@@ -339,21 +353,11 @@ def main(args):
out
,
working_batch
,
feature_dict
,
feature_processor
,
args
out
,
working_batch
,
feature_dict
,
feature_processor
,
args
)
)
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
if
model_version
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
model_version
}
'
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
# Save the unrelaxed PDB.
unrelaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
logger
.
info
(
f
"Output written to
{
unrelaxed_output_path
}
..."
)
logger
.
info
(
f
"Output written to
{
unrelaxed_output_path
}
..."
)
if
not
args
.
skip_relaxation
:
if
not
args
.
skip_relaxation
:
amber_relaxer
=
relax
.
AmberRelaxation
(
amber_relaxer
=
relax
.
AmberRelaxation
(
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
...
@@ -377,6 +381,7 @@ def main(args):
...
@@ -377,6 +381,7 @@ def main(args):
)
)
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
relaxed_pdb_str
)
fp
.
write
(
relaxed_pdb_str
)
logger
.
info
(
f
"Relaxed output written to
{
relaxed_output_path
}
..."
)
logger
.
info
(
f
"Relaxed output written to
{
relaxed_output_path
}
..."
)
if
args
.
save_outputs
:
if
args
.
save_outputs
:
...
@@ -388,6 +393,7 @@ def main(args):
...
@@ -388,6 +393,7 @@ def main(args):
logger
.
info
(
f
"Model output written to
{
output_dict_path
}
..."
)
logger
.
info
(
f
"Model output written to
{
output_dict_path
}
..."
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -413,8 +419,7 @@ if __name__ == "__main__":
...
@@ -413,8 +419,7 @@ if __name__ == "__main__":
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--config_preset"
,
type
=
str
,
default
=
"model_1"
,
"--config_preset"
,
type
=
str
,
default
=
"model_1"
,
help
=
"""Name of a model config. Choose one of model_{1-5} or
help
=
"""Name of a model config preset defined in openfold/config.py"""
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--jax_param_path"
,
type
=
str
,
default
=
None
,
"--jax_param_path"
,
type
=
str
,
default
=
None
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment