Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
89dee905
Unverified
Commit
89dee905
authored
Jun 22, 2022
by
Gustaf Ahdritz
Committed by
GitHub
Jun 22, 2022
Browse files
Merge pull request #117 from CyrusBiotechnology/run-multiple-models
Use multiple models for inference
parents
a48860cb
a2ab7ab7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
169 additions
and
148 deletions
+169
-148
run_pretrained_openfold.py
run_pretrained_openfold.py
+169
-148
tests/test_data/sample_feats.pickle.gz
tests/test_data/sample_feats.pickle.gz
+0
-0
No files found.
run_pretrained_openfold.py
View file @
89dee905
...
@@ -19,6 +19,7 @@ import gc
...
@@ -19,6 +19,7 @@ import gc
import
logging
import
logging
import
numpy
as
np
import
numpy
as
np
import
os
import
os
from
copy
import
deepcopy
import
pickle
import
pickle
from
pytorch_lightning.utilities.deepspeed
import
(
from
pytorch_lightning.utilities.deepspeed
import
(
...
@@ -58,10 +59,10 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
...
@@ -58,10 +59,10 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
if
(
args
.
use_precomputed_alignments
is
None
):
if
(
args
.
use_precomputed_alignments
is
None
):
logger
.
info
(
f
"Generating alignments for
{
tag
}
..."
)
logger
.
info
(
f
"Generating alignments for
{
tag
}
..."
)
if
not
os
.
path
.
exists
(
local_alignment_dir
):
if
not
os
.
path
.
exists
(
local_alignment_dir
):
os
.
makedirs
(
local_alignment_dir
)
os
.
makedirs
(
local_alignment_dir
)
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
...
@@ -161,69 +162,125 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
...
@@ -161,69 +162,125 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
return
unrelaxed_protein
return
unrelaxed_protein
def
main
(
args
):
def
generate_batch
(
fasta_file
,
fasta_dir
,
alignment_dir
,
data_processor
,
feature_processor
,
prediction_dir
):
# Create the output directory
with
open
(
os
.
path
.
join
(
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
data
=
fp
.
read
(
)
# Prep the model
lines
=
[
config
=
model_config
(
args
.
config_preset
)
l
.
replace
(
'
\n
'
,
''
)
for
prot
in
data
.
split
(
'>'
)
for
l
in
prot
.
strip
().
split
(
'
\n
'
,
1
)
logger
.
info
(
f
"Using config preset
{
args
.
config_preset
}
..."
)
][
1
:]
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
tags
=
[
t
.
split
()[
0
]
for
t
in
tags
]
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
model
=
AlphaFold
(
config
)
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
model
=
model
.
eval
()
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
# Save the unrelaxed PDB.
unrelaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
if
os
.
path
.
exists
(
unrelaxed_output_path
):
return
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
if
len
(
seqs
)
==
1
:
seq
=
seqs
[
0
]
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
if
(
args
.
jax_param_path
):
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
import_jax_weights_
(
feature_dict
=
data_processor
.
process_fasta
(
model
,
args
.
jax_param_path
,
version
=
args
.
config_preset
fasta_path
=
tmp_fasta_path
,
alignment_dir
=
local_alignment_dir
)
)
logger
.
info
(
else
:
f
"Successfully loaded JAX parameters at
{
args
.
jax_param_path
}
..."
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
'
\n
'
.
join
([
f
">
{
tag
}
\n
{
seq
}
"
for
tag
,
seq
in
zip
(
tags
,
seqs
)])
)
feature_dict
=
data_processor
.
process_multiseq_fasta
(
fasta_path
=
tmp_fasta_path
,
super_alignment_dir
=
alignment_dir
,
)
)
elif
(
args
.
openfold_checkpoint_path
):
if
(
os
.
path
.
isdir
(
args
.
openfold_checkpoint_path
)):
# Remove temporary FASTA file
# A DeepSpeed checkpoint
os
.
remove
(
tmp_fasta_path
)
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
mode
=
'predict'
,
)
return
processed_feature_dict
,
tag
,
feature_dict
def
load_models_from_command_line
(
args
,
config
):
# Create the output directory
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
if
args
.
jax_param_path
:
for
path
in
args
.
jax_param_path
.
split
(
","
):
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
import_jax_weights_
(
model
,
path
,
version
=
args
.
model_name
)
model
=
model
.
to
(
args
.
model_device
)
logger
.
info
(
f
"Successfully loaded JAX parameters at
{
args
.
jax_param_path
}
..."
)
yield
model
,
None
if
args
.
openfold_checkpoint_path
:
for
path
in
args
.
openfold_checkpoint_path
.
split
(
","
):
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
checkpoint_basename
=
os
.
path
.
splitext
(
checkpoint_basename
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
os
.
path
.
basename
(
os
.
path
.
normpath
(
args
.
openfold_checkpoint_
path
)
os
.
path
.
normpath
(
path
)
)
)
)[
0
]
)[
0
]
ckpt_path
=
os
.
path
.
join
(
if
os
.
path
.
isdir
(
path
):
args
.
output_dir
,
# A DeepSpeed checkpoint
checkpoint_basename
+
".pt"
,
ckpt_path
=
os
.
path
.
join
(
)
args
.
output_dir
,
checkpoint_basename
+
".pt"
,
if
(
not
os
.
path
.
isfile
(
ckpt_path
)):
convert_zero_checkpoint_to_fp32_state_dict
(
args
.
openfold_checkpoint_path
,
ckpt_path
,
)
)
d
=
torch
.
load
(
ckpt_path
)
if
not
os
.
path
.
isfile
(
ckpt_path
):
model
.
load_state_dict
(
d
[
"ema"
][
"params"
])
convert_zero_checkpoint_to_fp32_state_dict
(
else
:
path
,
# A checkpoint from the public release, which only contains EMA
ckpt_path
,
# params
)
ckpt_path
=
args
.
openfold_checkpoint_path
d
=
torch
.
load
(
ckpt_path
)
d
=
torch
.
load
(
ckpt_path
)
model
.
load_state_dict
(
d
[
"ema"
][
"params"
])
else
:
if
(
"ema"
in
d
):
ckpt_path
=
path
# The public weights have had this done to them already
d
=
torch
.
load
(
ckpt_path
)
d
=
d
[
"ema"
][
"params"
]
if
(
"ema"
in
d
):
model
.
load_state_dict
(
d
)
# The public weights have had this done to them already
d
=
d
[
"ema"
][
"params"
]
logger
.
info
(
model
.
load_state_dict
(
d
)
f
"Loaded OpenFold parameters at
{
args
.
openfold_checkpoint_path
}
..."
model
=
model
.
to
(
args
.
model_device
)
)
logger
.
info
(
else
:
f
"Loaded OpenFold parameters at
{
args
.
openfold_checkpoint_path
}
..."
)
yield
model
,
checkpoint_basename
if
not
args
.
jax_param_path
and
not
args
.
openfold_checkpoint_path
:
raise
ValueError
(
raise
ValueError
(
"At least one of jax_param_path or openfold_checkpoint_path must "
"At least one of jax_param_path or openfold_checkpoint_path must "
"be specified."
"be specified."
)
)
model
=
model
.
to
(
args
.
model_device
)
def
main
(
args
):
# Create the output directory
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
config
=
model_config
(
args
.
config_preset
)
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
max_template_date
=
args
.
max_template_date
,
...
@@ -244,128 +301,92 @@ def main(args):
...
@@ -244,128 +301,92 @@ def main(args):
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
if
not
os
.
path
.
exists
(
output_dir_base
):
if
not
os
.
path
.
exists
(
output_dir_base
):
os
.
makedirs
(
output_dir_base
)
os
.
makedirs
(
output_dir_base
)
if
(
args
.
use_precomputed_alignments
is
None
)
:
if
args
.
use_precomputed_alignments
is
None
:
alignment_dir
=
os
.
path
.
join
(
output_dir_base
,
"alignments"
)
alignment_dir
=
os
.
path
.
join
(
output_dir_base
,
"alignments"
)
else
:
else
:
alignment_dir
=
args
.
use_precomputed_alignments
alignment_dir
=
args
.
use_precomputed_alignments
logger
.
info
(
f
"Using precomputed alignments at
{
alignment_dir
}
..."
)
logger
.
info
(
f
"Using precomputed alignments at
{
alignment_dir
}
..."
)
prediction_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"predictions"
)
prediction_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"predictions"
)
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
for
fasta_file
in
os
.
listdir
(
args
.
fasta_dir
):
for
fasta_file
in
os
.
listdir
(
args
.
fasta_dir
):
# Gather input sequences
with
open
(
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
data
=
fp
.
read
()
lines
=
[
l
.
replace
(
'
\n
'
,
''
)
for
prot
in
data
.
split
(
'>'
)
for
l
in
prot
.
strip
().
split
(
'
\n
'
,
1
)
][
1
:]
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
tags
=
[
t
.
split
()[
0
]
for
t
in
tags
]
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
if
(
args
.
output_postfix
is
not
None
):
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
# Save the unrelaxed PDB.
unrelaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
if
(
os
.
path
.
exists
(
unrelaxed_output_path
)):
continue
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
batch_data
=
generate_batch
(
fasta_file
,
args
.
fasta_dir
,
alignment_dir
,
data_processor
,
feature_processor
,
prediction_dir
)
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
if
batch_data
is
None
:
if
(
len
(
seqs
)
==
1
):
# this file has already been processed
seq
=
seqs
[
0
]
continue
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
feature_dict
=
data_processor
.
process_fasta
(
fasta_path
=
tmp_fasta_path
,
alignment_dir
=
local_alignment_dir
)
else
:
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
'
\n
'
.
join
([
f
">
{
tag
}
\n
{
seq
}
"
for
tag
,
seq
in
zip
(
tags
,
seqs
)])
)
feature_dict
=
data_processor
.
process_multiseq_fasta
(
fasta_path
=
tmp_fasta_path
,
super_alignment_dir
=
alignment_dir
,
)
# Remove temporary FASTA file
batch
,
tag
,
feature_dict
=
batch_data
os
.
remove
(
tmp_fasta_path
)
processed_feature_dict
=
feature_processor
.
process_features
(
for
model
,
model_version
in
load_models_from_command_line
(
args
,
config
):
feature_dict
,
mode
=
'predict'
,
)
batch
=
processed_feature_dict
working_batch
=
deepcopy
(
batch
)
out
=
run_model
(
model
,
batch
,
tag
,
args
)
out
=
run_model
(
model
,
working_
batch
,
tag
,
args
)
# Toss out the recycling dimensions --- we don't need them anymore
# Toss out the recycling dimensions --- we don't need them anymore
batch
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
batch
)
working_batch
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
working_batch
)
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
unrelaxed_protein
=
prep_output
(
out
,
batch
,
feature_dict
,
feature_processor
,
args
)
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
unrelaxed_protein
=
prep_output
(
if
(
args
.
output_postfix
is
not
None
):
out
,
working_batch
,
feature_dict
,
feature_processor
,
args
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
)
# Save the unrelaxed PDB.
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
unrelaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
logger
.
info
(
f
"Output written to
{
unrelaxed_output_path
}
..."
)
if
model_version
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
model_version
}
'
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
if
(
not
args
.
skip_relaxation
):
# Save the unrelaxed PDB.
amber_relaxer
=
relax
.
AmberRelaxation
(
unrelaxed_output_path
=
os
.
path
.
join
(
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
**
config
.
relax
,
)
# Relax the prediction.
logger
.
info
(
f
"Running relaxation on
{
unrelaxed_output_path
}
..."
)
t
=
time
.
perf_counter
()
visible_devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
,
default
=
""
)
if
(
"cuda"
in
args
.
model_device
):
device_no
=
args
.
model_device
.
split
(
":"
)[
-
1
]
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
device_no
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
visible_devices
logger
.
info
(
f
"Relaxation time:
{
time
.
perf_counter
()
-
t
}
"
)
# Save the relaxed PDB.
relaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_relaxed.pdb'
)
)
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
relaxed_pdb_str
)
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
logger
.
info
(
f
"Output written to
{
unrelaxed_output_path
}
..."
)
if
not
args
.
skip_relaxation
:
amber_relaxer
=
relax
.
AmberRelaxation
(
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
**
config
.
relax
,
)
logger
.
info
(
f
"Relaxed output written to
{
relaxed_output_path
}
..."
)
# Relax the prediction.
logger
.
info
(
f
"Running relaxation on
{
unrelaxed_output_path
}
..."
)
t
=
time
.
perf_counter
()
visible_devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
,
default
=
""
)
if
"cuda"
in
args
.
model_device
:
device_no
=
args
.
model_device
.
split
(
":"
)[
-
1
]
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
device_no
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
visible_devices
logger
.
info
(
f
"Relaxation time:
{
time
.
perf_counter
()
-
t
}
"
)
# Save the relaxed PDB.
relaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_relaxed.pdb'
)
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
relaxed_pdb_str
)
logger
.
info
(
f
"Relaxed output written to
{
relaxed_output_path
}
..."
)
if
(
args
.
save_outputs
)
:
if
args
.
save_outputs
:
output_dict_path
=
os
.
path
.
join
(
output_dict_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
'
{
output_name
}
_output_dict.pkl'
args
.
output_dir
,
f
'
{
output_name
}
_output_dict.pkl'
)
)
with
open
(
output_dict_path
,
"wb"
)
as
fp
:
with
open
(
output_dict_path
,
"wb"
)
as
fp
:
pickle
.
dump
(
out
,
fp
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
pickle
.
dump
(
out
,
fp
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
logger
.
info
(
f
"Model output written to
{
output_dict_path
}
..."
)
logger
.
info
(
f
"Model output written to
{
output_dict_path
}
..."
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
...
...
tests/test_data/sample_feats.pickle.gz
deleted
100755 → 0
View file @
a48860cb
File deleted
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment