Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
ce27a6ca
Commit
ce27a6ca
authored
Jun 21, 2022
by
Gustaf Ahdritz
Browse files
Add logging to inference script
parent
f38f346d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
6 deletions
+26
-6
run_pretrained_openfold.py
run_pretrained_openfold.py
+26
-6
No files found.
run_pretrained_openfold.py
View file @
ce27a6ca
...
@@ -45,6 +45,11 @@ from openfold.utils.tensor_utils import (
...
@@ -45,6 +45,11 @@ from openfold.utils.tensor_utils import (
from
scripts.utils
import
add_data_args
from
scripts.utils
import
add_data_args
logging
.
basicConfig
()
logger
=
logging
.
getLogger
(
__file__
)
logger
.
setLevel
(
level
=
logging
.
INFO
)
def
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
):
def
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
):
for
tag
,
seq
in
zip
(
tags
,
seqs
):
for
tag
,
seq
in
zip
(
tags
,
seqs
):
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
...
@@ -53,7 +58,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
...
@@ -53,7 +58,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
if
(
args
.
use_precomputed_alignments
is
None
):
if
(
args
.
use_precomputed_alignments
is
None
):
logg
ing
.
info
(
f
"Generating alignments for
{
tag
}
..."
)
logg
er
.
info
(
f
"Generating alignments for
{
tag
}
..."
)
if
not
os
.
path
.
exists
(
local_alignment_dir
):
if
not
os
.
path
.
exists
(
local_alignment_dir
):
os
.
makedirs
(
local_alignment_dir
)
os
.
makedirs
(
local_alignment_dir
)
...
@@ -78,7 +83,6 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
...
@@ -78,7 +83,6 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
def
run_model
(
model
,
batch
,
tag
,
args
):
def
run_model
(
model
,
batch
,
tag
,
args
):
logging
.
info
(
"Executing model..."
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
batch
=
{
batch
=
{
k
:
torch
.
as_tensor
(
v
,
device
=
args
.
model_device
)
k
:
torch
.
as_tensor
(
v
,
device
=
args
.
model_device
)
...
@@ -90,10 +94,10 @@ def run_model(model, batch, tag, args):
...
@@ -90,10 +94,10 @@ def run_model(model, batch, tag, args):
"template_"
in
k
for
k
in
batch
"template_"
in
k
for
k
in
batch
])
])
logg
ing
.
info
(
f
"Running inference for
{
tag
}
..."
)
logg
er
.
info
(
f
"Running inference for
{
tag
}
..."
)
t
=
time
.
perf_counter
()
t
=
time
.
perf_counter
()
out
=
model
(
batch
)
out
=
model
(
batch
)
logg
ing
.
info
(
f
"Inference time:
{
time
.
perf_counter
()
-
t
}
"
)
logg
er
.
info
(
f
"Inference time:
{
time
.
perf_counter
()
-
t
}
"
)
return
out
return
out
...
@@ -165,6 +169,8 @@ def main(args):
...
@@ -165,6 +169,8 @@ def main(args):
# Prep the model
# Prep the model
config
=
model_config
(
args
.
model_name
)
config
=
model_config
(
args
.
model_name
)
logger
.
info
(
f
"Using config preset
{
args
.
model_name
}
..."
)
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
model
=
model
.
eval
()
...
@@ -172,6 +178,9 @@ def main(args):
...
@@ -172,6 +178,9 @@ def main(args):
import_jax_weights_
(
import_jax_weights_
(
model
,
args
.
jax_param_path
,
version
=
args
.
model_name
model
,
args
.
jax_param_path
,
version
=
args
.
model_name
)
)
logger
.
info
(
f
"Successfully loaded JAX parameters at
{
args
.
jax_param_path
}
..."
)
elif
(
args
.
openfold_checkpoint_path
):
elif
(
args
.
openfold_checkpoint_path
):
if
(
os
.
path
.
isdir
(
args
.
openfold_checkpoint_path
)):
if
(
os
.
path
.
isdir
(
args
.
openfold_checkpoint_path
)):
# A DeepSpeed checkpoint
# A DeepSpeed checkpoint
...
@@ -204,6 +213,10 @@ def main(args):
...
@@ -204,6 +213,10 @@ def main(args):
d
=
d
[
"ema"
][
"params"
]
d
=
d
[
"ema"
][
"params"
]
model
.
load_state_dict
(
d
)
model
.
load_state_dict
(
d
)
logger
.
info
(
f
"Loaded OpenFold parameters at
{
args
.
openfold_checkpoint_path
}
..."
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"At least one of jax_param_path or openfold_checkpoint_path must "
"At least one of jax_param_path or openfold_checkpoint_path must "
...
@@ -238,6 +251,7 @@ def main(args):
...
@@ -238,6 +251,7 @@ def main(args):
alignment_dir
=
os
.
path
.
join
(
output_dir_base
,
"alignments"
)
alignment_dir
=
os
.
path
.
join
(
output_dir_base
,
"alignments"
)
else
:
else
:
alignment_dir
=
args
.
use_precomputed_alignments
alignment_dir
=
args
.
use_precomputed_alignments
logger
.
info
(
f
"Using precomputed alignments at
{
alignment_dir
}
..."
)
prediction_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"predictions"
)
prediction_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"predictions"
)
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
...
@@ -319,6 +333,8 @@ def main(args):
...
@@ -319,6 +333,8 @@ def main(args):
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
logger
.
info
(
f
"Output written to
{
unrelaxed_output_path
}
..."
)
if
(
not
args
.
skip_relaxation
):
if
(
not
args
.
skip_relaxation
):
amber_relaxer
=
relax
.
AmberRelaxation
(
amber_relaxer
=
relax
.
AmberRelaxation
(
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
...
@@ -326,6 +342,7 @@ def main(args):
...
@@ -326,6 +342,7 @@ def main(args):
)
)
# Relax the prediction.
# Relax the prediction.
logger
.
info
(
f
"Running relaxation on
{
unrelaxed_output_path
}
..."
)
t
=
time
.
perf_counter
()
t
=
time
.
perf_counter
()
visible_devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
,
default
=
""
)
visible_devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
,
default
=
""
)
if
(
"cuda"
in
args
.
model_device
):
if
(
"cuda"
in
args
.
model_device
):
...
@@ -333,7 +350,7 @@ def main(args):
...
@@ -333,7 +350,7 @@ def main(args):
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
device_no
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
device_no
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
visible_devices
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
visible_devices
logg
ing
.
info
(
f
"Relaxation time:
{
time
.
perf_counter
()
-
t
}
"
)
logg
er
.
info
(
f
"Relaxation time:
{
time
.
perf_counter
()
-
t
}
"
)
# Save the relaxed PDB.
# Save the relaxed PDB.
relaxed_output_path
=
os
.
path
.
join
(
relaxed_output_path
=
os
.
path
.
join
(
...
@@ -342,6 +359,8 @@ def main(args):
...
@@ -342,6 +359,8 @@ def main(args):
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
with
open
(
relaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
relaxed_pdb_str
)
fp
.
write
(
relaxed_pdb_str
)
logger
.
info
(
f
"Relaxed output written to
{
relaxed_output_path
}
..."
)
if
(
args
.
save_outputs
):
if
(
args
.
save_outputs
):
output_dict_path
=
os
.
path
.
join
(
output_dict_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
'
{
output_name
}
_output_dict.pkl'
args
.
output_dir
,
f
'
{
output_name
}
_output_dict.pkl'
...
@@ -349,6 +368,7 @@ def main(args):
...
@@ -349,6 +368,7 @@ def main(args):
with
open
(
output_dict_path
,
"wb"
)
as
fp
:
with
open
(
output_dict_path
,
"wb"
)
as
fp
:
pickle
.
dump
(
out
,
fp
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
pickle
.
dump
(
out
,
fp
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
logger
.
info
(
f
"Model output written to
{
output_dict_path
}
..."
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment