Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
1fa6ffab
"googlemock/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "2c19680bf99dbab41f08b4d39f401486d3a74d06"
Commit
1fa6ffab
authored
Jun 22, 2022
by
Gustaf Ahdritz
Browse files
Refactor run_pretrained_openfold.py a little
parent
b4b849af
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
51 deletions
+56
-51
run_pretrained_openfold.py
run_pretrained_openfold.py
+56
-51
No files found.
run_pretrained_openfold.py
View file @
1fa6ffab
...
...
@@ -162,10 +162,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
return
unrelaxed_protein
def
generate_batch
(
fasta_file
,
fasta_dir
,
alignment_dir
,
data_processor
,
feature_processor
,
prediction_dir
):
with
open
(
os
.
path
.
join
(
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
data
=
fp
.
read
()
def
parse_fasta
(
data
):
lines
=
[
l
.
replace
(
'
\n
'
,
''
)
for
prot
in
data
.
split
(
'>'
)
for
l
in
prot
.
strip
().
split
(
'
\n
'
,
1
)
...
...
@@ -173,25 +170,20 @@ def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
tags
=
[
t
.
split
()[
0
]
for
t
in
tags
]
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
# Save the unrelaxed PDB.
unrelaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
if
os
.
path
.
exists
(
unrelaxed_output_path
):
return
return
tags
,
seqs
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
def
generate_feature_dict
(
tags
,
seqs
,
alignment_dir
,
data_processor
,
args
,
):
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
if
len
(
seqs
)
==
1
:
tag
=
tags
[
0
]
seq
=
seqs
[
0
]
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
...
...
@@ -212,10 +204,7 @@ def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature
# Remove temporary FASTA file
os
.
remove
(
tmp_fasta_path
)
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
mode
=
'predict'
,
)
return
processed_feature_dict
,
tag
,
feature_dict
return
feature_dict
def
load_models_from_command_line
(
args
,
config
):
...
...
@@ -226,13 +215,18 @@ def load_models_from_command_line(args, config):
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
import_jax_weights_
(
model
,
path
,
version
=
args
.
model_name
model
,
path
,
version
=
args
.
config_preset
)
model
=
model
.
to
(
args
.
model_device
)
logger
.
info
(
f
"Successfully loaded JAX parameters at
{
args
.
jax_param_path
}
..."
)
yield
model
,
None
model_version
=
os
.
path
.
basename
(
os
.
path
.
normpath
(
args
.
jax_param_path
),
)
model_version
=
os
.
path
.
splitext
(
model_version
)[
0
]
yield
model
,
model_version
if
args
.
openfold_checkpoint_path
:
for
path
in
args
.
openfold_checkpoint_path
.
split
(
","
):
model
=
AlphaFold
(
config
)
...
...
@@ -264,11 +258,14 @@ def load_models_from_command_line(args, config):
# The public weights have had this done to them already
d
=
d
[
"ema"
][
"params"
]
model
.
load_state_dict
(
d
)
model
=
model
.
to
(
args
.
model_device
)
logger
.
info
(
f
"Loaded OpenFold parameters at
{
args
.
openfold_checkpoint_path
}
..."
)
yield
model
,
checkpoint_basename
if
not
args
.
jax_param_path
and
not
args
.
openfold_checkpoint_path
:
raise
ValueError
(
"At least one of jax_param_path or openfold_checkpoint_path must "
...
...
@@ -311,23 +308,40 @@ def main(args):
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
for
fasta_file
in
os
.
listdir
(
args
.
fasta_dir
):
with
open
(
os
.
path
.
join
(
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
data
=
fp
.
read
()
batch_data
=
generate_batch
(
fasta_file
,
args
.
fasta_dir
,
alignment_dir
,
data_processor
,
feature_processor
,
prediction_dir
)
tags
,
seqs
=
parse_fasta
(
data
)
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
if
batch_data
is
None
:
# this file has already been processed
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
unrelaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
# Output already exists
if
os
.
path
.
exists
(
unrelaxed_output_path
):
continue
batch
,
tag
,
feature_dict
=
batch_data
precompute_alignments
(
tag
s
,
seqs
,
alignment_dir
,
args
)
for
model
,
model_version
in
load_models_from_command_line
(
args
,
config
):
feature_dict
=
generate_feature_dict
(
tags
,
seqs
,
alignment_dir
,
data_processor
,
args
,
)
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
mode
=
'predict'
,
)
for
model
,
model_version
in
load_models_from_command_line
(
args
,
config
):
working_batch
=
deepcopy
(
batch
)
out
=
run_model
(
model
,
working_batch
,
tag
,
args
)
...
...
@@ -339,21 +353,11 @@ def main(args):
out
,
working_batch
,
feature_dict
,
feature_processor
,
args
)
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
if
model_version
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
model_version
}
'
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
# Save the unrelaxed PDB.
unrelaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
logger
.
info
(
f
"Output written to
{
unrelaxed_output_path
}
..."
)
if
not
args
.
skip_relaxation
:
amber_relaxer
=
relax
.
AmberRelaxation
(
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
...
...
@@ -377,6 +381,7 @@ def main(args):
)
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
relaxed_pdb_str
)
logger
.
info
(
f
"Relaxed output written to
{
relaxed_output_path
}
..."
)
if
args
.
save_outputs
:
...
...
@@ -388,6 +393,7 @@ def main(args):
logger
.
info
(
f
"Model output written to
{
output_dict_path
}
..."
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
...
...
@@ -413,8 +419,7 @@ if __name__ == "__main__":
)
parser
.
add_argument
(
"--config_preset"
,
type
=
str
,
default
=
"model_1"
,
help
=
"""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
help
=
"""Name of a model config preset defined in openfold/config.py"""
)
parser
.
add_argument
(
"--jax_param_path"
,
type
=
str
,
default
=
None
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment