Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
15a8c321
"deploy/vscode:/vscode.git/clone" did not exist on "59d916e11b7fb6e593c22370bd54b40645b49b94"
Commit
15a8c321
authored
Jun 16, 2022
by
Sam DeLuca
Browse files
refactor to allow multiple models to run
parent
4bfd5bf0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
134 additions
and
115 deletions
+134
-115
run_pretrained_openfold.py
run_pretrained_openfold.py
+134
-115
No files found.
run_pretrained_openfold.py
View file @
15a8c321
...
@@ -159,24 +159,72 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
...
@@ -159,24 +159,72 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
return
unrelaxed_protein
return
unrelaxed_protein
def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature_processor):
    """Parse one FASTA file and produce a processed feature batch for prediction.

    Reads ``fasta_file`` from ``fasta_dir``, splits it into (tag, sequence)
    pairs, precomputes alignments, writes a temporary FASTA keyed by PID,
    and runs the data and feature pipelines.

    Returns a 3-tuple of (processed_feature_dict, tag, feature_dict), where
    ``tag`` is the '-'-joined concatenation of all FASTA tags.

    NOTE(review): this function reads the free variable ``args`` (for
    ``args.output_dir`` and as an argument to ``precompute_alignments``)
    but does not take it as a parameter — presumably ``args`` is a
    module-level global at call time; confirm, otherwise this raises
    NameError.
    """
    with open(os.path.join(fasta_dir, fasta_file), "r") as fp:
        data = fp.read()

    # Flatten the FASTA into alternating [tag, seq, tag, seq, ...] lines.
    # Each '>'-delimited record is split once on the first newline (header
    # vs. body); remaining newlines inside the body are stripped so each
    # sequence becomes a single string. The leading [1:] drops the empty
    # element produced by splitting on the initial '>'.
    lines = [
        l.replace('\n', '')
        for prot in data.split('>')
        for l in prot.strip().split('\n', 1)
    ][1:]
    tags, seqs = lines[::2], lines[1::2]
    # Keep only the first whitespace-delimited token of each header line.
    tags = [t.split()[0] for t in tags]

    assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
    tag = '-'.join(tags)

    precompute_alignments(tags, seqs, alignment_dir, args)

    # PID-keyed temp file so concurrent processes don't clobber each other.
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
    if len(seqs) == 1:
        # Single-sequence (monomer) path: alignments live in a per-tag
        # subdirectory of alignment_dir.
        seq = seqs[0]
        with open(tmp_fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")

        local_alignment_dir = os.path.join(alignment_dir, tag)
        feature_dict = data_processor.process_fasta(
            fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
        )
    else:
        # Multi-sequence path: rewrite all records into one temp FASTA and
        # hand the whole alignment_dir to the multiseq pipeline.
        with open(tmp_fasta_path, "w") as fp:
            fp.write(
                '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
            )
        feature_dict = data_processor.process_multiseq_fasta(
            fasta_path=tmp_fasta_path,
            super_alignment_dir=alignment_dir,
        )

    # Remove temporary FASTA file
    os.remove(tmp_fasta_path)

    processed_feature_dict = feature_processor.process_features(
        feature_dict, mode='predict',
    )

    return processed_feature_dict, tag, feature_dict
def
load_models_from_command_line
(
args
,
config
):
# Create the output directory
# Create the output directory
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
if
args
.
jax_param_path
:
# Prep the model
for
path
in
args
.
jax_param_path
.
split
(
","
):
config
=
model_config
(
args
.
model_name
)
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
model
=
model
.
eval
()
if
(
args
.
jax_param_path
):
import_jax_weights_
(
import_jax_weights_
(
model
,
args
.
jax_param_
path
,
version
=
args
.
model_name
model
,
path
,
version
=
args
.
model_name
)
)
elif
(
args
.
openfold_checkpoint_path
):
model
=
model
.
to
(
args
.
model_device
)
if
(
os
.
path
.
isdir
(
args
.
openfold_checkpoint_path
)):
yield
model
,
None
if
args
.
openfold_checkpoint_path
:
for
path
in
args
.
openfold_checkpoint_path
:
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
checkpoint_basename
=
None
if
os
.
path
.
isdir
(
path
):
checkpoint_basename
=
os
.
path
.
splitext
(
checkpoint_basename
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
os
.
path
.
basename
(
os
.
path
.
normpath
(
args
.
openfold_checkpoint_
path
)
os
.
path
.
normpath
(
path
)
)
)
)[
0
]
)[
0
]
ckpt_path
=
os
.
path
.
join
(
ckpt_path
=
os
.
path
.
join
(
...
@@ -184,24 +232,30 @@ def main(args):
...
@@ -184,24 +232,30 @@ def main(args):
checkpoint_basename
+
".pt"
,
checkpoint_basename
+
".pt"
,
)
)
if
(
not
os
.
path
.
isfile
(
ckpt_path
)
)
:
if
not
os
.
path
.
isfile
(
ckpt_path
):
convert_zero_checkpoint_to_fp32_state_dict
(
convert_zero_checkpoint_to_fp32_state_dict
(
args
.
openfold_checkpoint_path
,
args
.
openfold_checkpoint_path
,
ckpt_path
,
ckpt_path
,
)
)
else
:
else
:
ckpt_path
=
args
.
openfold_checkpoint_
path
ckpt_path
=
path
d
=
torch
.
load
(
ckpt_path
)
d
=
torch
.
load
(
ckpt_path
)
model
.
load_state_dict
(
d
[
"ema"
][
"params"
])
model
.
load_state_dict
(
d
[
"ema"
][
"params"
])
model
=
model
.
to
(
args
.
model_device
)
yield
model
,
checkpoint_basename
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"At least one of jax_param_path or openfold_checkpoint_path must "
"At least one of jax_param_path or openfold_checkpoint_path must "
"be specified."
"be specified."
)
)
model
=
model
.
to
(
args
.
model_device
)
def
main
(
args
):
# Create the output directory
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
config
=
model_config
(
args
.
model_name
)
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
max_template_date
=
args
.
max_template_date
,
...
@@ -222,7 +276,7 @@ def main(args):
...
@@ -222,7 +276,7 @@ def main(args):
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
if
not
os
.
path
.
exists
(
output_dir_base
):
if
not
os
.
path
.
exists
(
output_dir_base
):
os
.
makedirs
(
output_dir_base
)
os
.
makedirs
(
output_dir_base
)
if
(
args
.
use_precomputed_alignments
is
None
)
:
if
args
.
use_precomputed_alignments
is
None
:
alignment_dir
=
os
.
path
.
join
(
output_dir_base
,
"alignments"
)
alignment_dir
=
os
.
path
.
join
(
output_dir_base
,
"alignments"
)
else
:
else
:
alignment_dir
=
args
.
use_precomputed_alignments
alignment_dir
=
args
.
use_precomputed_alignments
...
@@ -231,49 +285,11 @@ def main(args):
...
@@ -231,49 +285,11 @@ def main(args):
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
for
fasta_file
in
os
.
listdir
(
args
.
fasta_dir
):
for
fasta_file
in
os
.
listdir
(
args
.
fasta_dir
):
# Gather input sequences
with
open
(
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
data
=
fp
.
read
()
lines
=
[
batch
,
tag
,
feature_dict
=
generate_batch
(
fasta_file
,
args
.
fasta_dir
,
alignment_dir
,
data_processor
,
feature_processor
)
l
.
replace
(
'
\n
'
,
''
)
for
prot
in
data
.
split
(
'>'
)
for
l
in
prot
.
strip
().
split
(
'
\n
'
,
1
)
][
1
:]
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
tags
=
[
t
.
split
()[
0
]
for
t
in
tags
]
for
model
,
model_version
in
load_models_from_command_line
(
args
,
config
):
assert
len
(
tags
)
==
len
(
set
(
tags
)),
"All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
if
(
len
(
seqs
)
==
1
):
seq
=
seqs
[
0
]
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
feature_dict
=
data_processor
.
process_fasta
(
fasta_path
=
tmp_fasta_path
,
alignment_dir
=
local_alignment_dir
)
else
:
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
'
\n
'
.
join
([
f
">
{
tag
}
\n
{
seq
}
"
for
tag
,
seq
in
zip
(
tags
,
seqs
)])
)
feature_dict
=
data_processor
.
process_multiseq_fasta
(
fasta_path
=
tmp_fasta_path
,
super_alignment_dir
=
alignment_dir
,
)
# Remove temporary FASTA file
os
.
remove
(
tmp_fasta_path
)
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
mode
=
'predict'
,
)
batch
=
processed_feature_dict
out
=
run_model
(
model
,
batch
,
tag
,
args
)
out
=
run_model
(
model
,
batch
,
tag
,
args
)
# Toss out the recycling dimensions --- we don't need them anymore
# Toss out the recycling dimensions --- we don't need them anymore
...
@@ -285,7 +301,10 @@ def main(args):
...
@@ -285,7 +301,10 @@ def main(args):
)
)
output_name
=
f
'
{
tag
}
_
{
args
.
model_name
}
'
output_name
=
f
'
{
tag
}
_
{
args
.
model_name
}
'
if
(
args
.
output_postfix
is
not
None
):
if
model_version
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
model_version
}
'
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
# Save the unrelaxed PDB.
# Save the unrelaxed PDB.
...
@@ -295,7 +314,7 @@ def main(args):
...
@@ -295,7 +314,7 @@ def main(args):
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
if
(
not
args
.
skip_relaxation
)
:
if
not
args
.
skip_relaxation
:
amber_relaxer
=
relax
.
AmberRelaxation
(
amber_relaxer
=
relax
.
AmberRelaxation
(
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
**
config
.
relax
,
**
config
.
relax
,
...
@@ -304,7 +323,7 @@ def main(args):
...
@@ -304,7 +323,7 @@ def main(args):
# Relax the prediction.
# Relax the prediction.
t
=
time
.
perf_counter
()
t
=
time
.
perf_counter
()
visible_devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
,
default
=
""
)
visible_devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
,
default
=
""
)
if
(
"cuda"
in
args
.
model_device
)
:
if
"cuda"
in
args
.
model_device
:
device_no
=
args
.
model_device
.
split
(
":"
)[
-
1
]
device_no
=
args
.
model_device
.
split
(
":"
)[
-
1
]
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
device_no
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
device_no
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
...
@@ -318,7 +337,7 @@ def main(args):
...
@@ -318,7 +337,7 @@ def main(args):
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
relaxed_pdb_str
)
fp
.
write
(
relaxed_pdb_str
)
if
(
args
.
save_outputs
)
:
if
args
.
save_outputs
:
output_dict_path
=
os
.
path
.
join
(
output_dict_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
'
{
output_name
}
_output_dict.pkl'
args
.
output_dir
,
f
'
{
output_name
}
_output_dict.pkl'
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment