Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
15a8c321
Commit
15a8c321
authored
Jun 16, 2022
by
Sam DeLuca
Browse files
refactor to allow multiple models to run
parent
4bfd5bf0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
134 additions
and
115 deletions
+134
-115
run_pretrained_openfold.py
run_pretrained_openfold.py
+134
-115
No files found.
run_pretrained_openfold.py
View file @
15a8c321
...
@@ -159,49 +159,103 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
...
@@ -159,49 +159,103 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
return
unrelaxed_protein
return
unrelaxed_protein
def
main
(
args
):
def
generate_batch
(
fasta_file
,
fasta_dir
,
alignment_dir
,
data_processor
,
feature_processor
):
# Create the output directory
with
open
(
os
.
path
.
join
(
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
data
=
fp
.
read
(
)
# Prep the model
lines
=
[
config
=
model_config
(
args
.
model_name
)
l
.
replace
(
'
\n
'
,
''
)
model
=
AlphaFold
(
config
)
for
prot
in
data
.
split
(
'>'
)
for
l
in
prot
.
strip
().
split
(
'
\n
'
,
1
)
model
=
model
.
eval
()
][
1
:]
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
if
(
args
.
jax_param_path
):
tags
=
[
t
.
split
()[
0
]
for
t
in
tags
]
import_jax_weights_
(
assert
len
(
tags
)
==
len
(
set
(
tags
)),
"All FASTA tags must be unique"
model
,
args
.
jax_param_path
,
version
=
args
.
model_name
tag
=
'-'
.
join
(
tags
)
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
if
len
(
seqs
)
==
1
:
seq
=
seqs
[
0
]
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
feature_dict
=
data_processor
.
process_fasta
(
fasta_path
=
tmp_fasta_path
,
alignment_dir
=
local_alignment_dir
)
)
elif
(
args
.
openfold_checkpoint_path
):
else
:
if
(
os
.
path
.
isdir
(
args
.
openfold_checkpoint_path
)):
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
checkpoint_basename
=
os
.
path
.
splitext
(
fp
.
write
(
os
.
path
.
basename
(
'
\n
'
.
join
([
f
">
{
tag
}
\n
{
seq
}
"
for
tag
,
seq
in
zip
(
tags
,
seqs
)])
os
.
path
.
normpath
(
args
.
openfold_checkpoint_path
)
)
)
feature_dict
=
data_processor
.
process_multiseq_fasta
(
)[
0
]
fasta_path
=
tmp_fasta_path
,
super_alignment_dir
=
alignment_dir
,
ckpt_path
=
os
.
path
.
join
(
)
args
.
output_dir
,
checkpoint_basename
+
".pt"
,
# Remove temporary FASTA file
)
os
.
remove
(
tmp_fasta_path
)
if
(
not
os
.
path
.
isfile
(
ckpt_path
)):
processed_feature_dict
=
feature_processor
.
process_features
(
convert_zero_checkpoint_to_fp32_state_dict
(
feature_dict
,
mode
=
'predict'
,
args
.
openfold_checkpoint_path
,
)
ckpt_path
,
return
processed_feature_dict
,
tag
,
feature_dict
def
load_models_from_command_line
(
args
,
config
):
# Create the output directory
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
if
args
.
jax_param_path
:
for
path
in
args
.
jax_param_path
.
split
(
","
):
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
import_jax_weights_
(
model
,
path
,
version
=
args
.
model_name
)
model
=
model
.
to
(
args
.
model_device
)
yield
model
,
None
if
args
.
openfold_checkpoint_path
:
for
path
in
args
.
openfold_checkpoint_path
:
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
checkpoint_basename
=
None
if
os
.
path
.
isdir
(
path
):
checkpoint_basename
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
os
.
path
.
normpath
(
path
)
)
)[
0
]
ckpt_path
=
os
.
path
.
join
(
args
.
output_dir
,
checkpoint_basename
+
".pt"
,
)
)
else
:
ckpt_path
=
args
.
openfold_checkpoint_path
d
=
torch
.
load
(
ckpt_path
)
if
not
os
.
path
.
isfile
(
ckpt_path
):
model
.
load_state_dict
(
d
[
"ema"
][
"params"
])
convert_zero_checkpoint_to_fp32_state_dict
(
args
.
openfold_checkpoint_path
,
ckpt_path
,
)
else
:
ckpt_path
=
path
d
=
torch
.
load
(
ckpt_path
)
model
.
load_state_dict
(
d
[
"ema"
][
"params"
])
model
=
model
.
to
(
args
.
model_device
)
yield
model
,
checkpoint_basename
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"At least one of jax_param_path or openfold_checkpoint_path must "
"At least one of jax_param_path or openfold_checkpoint_path must "
"be specified."
"be specified."
)
)
model
=
model
.
to
(
args
.
model_device
)
def
main
(
args
):
# Create the output directory
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
config
=
model_config
(
args
.
model_name
)
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
max_template_date
=
args
.
max_template_date
,
...
@@ -222,7 +276,7 @@ def main(args):
...
@@ -222,7 +276,7 @@ def main(args):
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
if
not
os
.
path
.
exists
(
output_dir_base
):
if
not
os
.
path
.
exists
(
output_dir_base
):
os
.
makedirs
(
output_dir_base
)
os
.
makedirs
(
output_dir_base
)
if
(
args
.
use_precomputed_alignments
is
None
)
:
if
args
.
use_precomputed_alignments
is
None
:
alignment_dir
=
os
.
path
.
join
(
output_dir_base
,
"alignments"
)
alignment_dir
=
os
.
path
.
join
(
output_dir_base
,
"alignments"
)
else
:
else
:
alignment_dir
=
args
.
use_precomputed_alignments
alignment_dir
=
args
.
use_precomputed_alignments
...
@@ -231,99 +285,64 @@ def main(args):
...
@@ -231,99 +285,64 @@ def main(args):
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
for
fasta_file
in
os
.
listdir
(
args
.
fasta_dir
):
for
fasta_file
in
os
.
listdir
(
args
.
fasta_dir
):
# Gather input sequences
with
open
(
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
data
=
fp
.
read
()
lines
=
[
batch
,
tag
,
feature_dict
=
generate_batch
(
fasta_file
,
args
.
fasta_dir
,
alignment_dir
,
data_processor
,
feature_processor
)
l
.
replace
(
'
\n
'
,
''
)
for
prot
in
data
.
split
(
'>'
)
for
l
in
prot
.
strip
().
split
(
'
\n
'
,
1
)
][
1
:]
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
tags
=
[
t
.
split
()[
0
]
for
t
in
tags
]
for
model
,
model_version
in
load_models_from_command_line
(
args
,
config
):
assert
len
(
tags
)
==
len
(
set
(
tags
)),
"All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
out
=
run_model
(
model
,
batch
,
tag
,
args
)
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
# Toss out the recycling dimensions --- we don't need them anymore
if
(
len
(
seqs
)
==
1
):
batch
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
batch
)
seq
=
seqs
[
0
]
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
unrelaxed_protein
=
prep_output
(
out
,
batch
,
feature_dict
,
feature_processor
,
args
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
feature_dict
=
data_processor
.
process_fasta
(
fasta_path
=
tmp_fasta_path
,
alignment_dir
=
local_alignment_dir
)
else
:
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
'
\n
'
.
join
([
f
">
{
tag
}
\n
{
seq
}
"
for
tag
,
seq
in
zip
(
tags
,
seqs
)])
)
feature_dict
=
data_processor
.
process_multiseq_fasta
(
fasta_path
=
tmp_fasta_path
,
super_alignment_dir
=
alignment_dir
,
)
)
# Remove temporary FASTA file
os
.
remove
(
tmp_fasta_path
)
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
mode
=
'predict'
,
)
batch
=
processed_feature_dict
output_name
=
f
'
{
tag
}
_
{
args
.
model_name
}
'
out
=
run_model
(
model
,
batch
,
tag
,
args
)
# Toss out the recycling dimensions --- we don't need them anymore
if
model_version
is
not
None
:
batch
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
batch
)
output_name
=
f
'
{
output_name
}
_
{
model_version
}
'
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
unrelaxed_protein
=
prep_output
(
out
,
batch
,
feature_dict
,
feature_processor
,
args
)
output_name
=
f
'
{
tag
}
_
{
args
.
model_name
}
'
# Save the unrelaxed PDB.
if
(
args
.
output_postfix
is
not
None
):
unrelaxed_output_path
=
os
.
path
.
join
(
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
# Save the unrelaxed PDB.
if
not
args
.
skip_relaxation
:
unrelaxed_output_path
=
os
.
path
.
join
(
amber_relaxer
=
relax
.
AmberRelaxation
(
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
)
**
config
.
relax
,
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
)
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
if
(
not
args
.
skip_relaxation
):
# Relax the prediction.
amber_relaxer
=
relax
.
AmberRelaxation
(
t
=
time
.
perf_counter
()
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
visible_devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
,
default
=
""
)
**
config
.
relax
,
if
"cuda"
in
args
.
model_device
:
)
device_no
=
args
.
model_device
.
split
(
":"
)[
-
1
]
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
device_no
# Relax the prediction.
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
t
=
time
.
perf_counter
()
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
visible_devices
visible_devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
,
default
=
""
)
logging
.
info
(
f
"Relaxation time:
{
time
.
perf_counter
()
-
t
}
"
)
if
(
"cuda"
in
args
.
model_device
):
device_no
=
args
.
model_device
.
split
(
":"
)[
-
1
]
# Save the relaxed PDB.
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
device_no
relaxed_output_path
=
os
.
path
.
join
(
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
prediction_dir
,
f
'
{
output_name
}
_relaxed.pdb'
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
visible_devices
)
logging
.
info
(
f
"Relaxation time:
{
time
.
perf_counter
()
-
t
}
"
)
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
relaxed_pdb_str
)
# Save the relaxed PDB.
relaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_relaxed.pdb'
)
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
relaxed_pdb_str
)
if
(
args
.
save_outputs
)
:
if
args
.
save_outputs
:
output_dict_path
=
os
.
path
.
join
(
output_dict_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
'
{
output_name
}
_output_dict.pkl'
args
.
output_dir
,
f
'
{
output_name
}
_output_dict.pkl'
)
)
with
open
(
output_dict_path
,
"wb"
)
as
fp
:
with
open
(
output_dict_path
,
"wb"
)
as
fp
:
pickle
.
dump
(
out
,
fp
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
pickle
.
dump
(
out
,
fp
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment