Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9dee6caa
Commit
9dee6caa
authored
Jun 21, 2022
by
Gustaf Ahdritz
Browse files
Update inference script
parent
7642bef9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
11 deletions
+28
-11
run_pretrained_openfold.py
run_pretrained_openfold.py
+28
-11
No files found.
run_pretrained_openfold.py
View file @
9dee6caa
...
@@ -57,7 +57,6 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
...
@@ -57,7 +57,6 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
if
not
os
.
path
.
exists
(
local_alignment_dir
):
if
not
os
.
path
.
exists
(
local_alignment_dir
):
os
.
makedirs
(
local_alignment_dir
)
os
.
makedirs
(
local_alignment_dir
)
use_small_bfd
=
(
args
.
bfd_database_path
is
None
)
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
...
@@ -71,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
...
@@ -71,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
no_cpus
=
args
.
cpus
,
no_cpus
=
args
.
cpus
,
)
)
alignment_runner
.
run
(
alignment_runner
.
run
(
tmp_
fasta_path
,
local_alignment_dir
fasta_path
,
local_alignment_dir
)
)
# Remove temporary FASTA file
# Remove temporary FASTA file
...
@@ -87,7 +86,7 @@ def run_model(model, batch, tag, args):
...
@@ -87,7 +86,7 @@ def run_model(model, batch, tag, args):
}
}
# Disable templates if there aren't any in the batch
# Disable templates if there aren't any in the batch
model
.
config
.
template
.
enabled
=
any
([
model
.
config
.
template
.
enabled
=
model
.
config
.
template
.
enabled
and
any
([
"template_"
in
k
for
k
in
batch
"template_"
in
k
for
k
in
batch
])
])
...
@@ -165,6 +164,7 @@ def main(args):
...
@@ -165,6 +164,7 @@ def main(args):
# Prep the model
# Prep the model
config
=
model_config
(
args
.
model_name
)
config
=
model_config
(
args
.
model_name
)
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
model
=
model
.
eval
()
...
@@ -174,6 +174,7 @@ def main(args):
...
@@ -174,6 +174,7 @@ def main(args):
)
)
elif
(
args
.
openfold_checkpoint_path
):
elif
(
args
.
openfold_checkpoint_path
):
if
(
os
.
path
.
isdir
(
args
.
openfold_checkpoint_path
)):
if
(
os
.
path
.
isdir
(
args
.
openfold_checkpoint_path
)):
# A DeepSpeed checkpoint
checkpoint_basename
=
os
.
path
.
splitext
(
checkpoint_basename
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
os
.
path
.
basename
(
os
.
path
.
normpath
(
args
.
openfold_checkpoint_path
)
os
.
path
.
normpath
(
args
.
openfold_checkpoint_path
)
...
@@ -193,6 +194,8 @@ def main(args):
...
@@ -193,6 +194,8 @@ def main(args):
d
=
torch
.
load
(
ckpt_path
)
d
=
torch
.
load
(
ckpt_path
)
model
.
load_state_dict
(
d
[
"ema"
][
"params"
])
model
.
load_state_dict
(
d
[
"ema"
][
"params"
])
else
:
else
:
# A checkpoint from the public release, which only contains EMA
# params
ckpt_path
=
args
.
openfold_checkpoint_path
ckpt_path
=
args
.
openfold_checkpoint_path
d
=
torch
.
load
(
ckpt_path
)
d
=
torch
.
load
(
ckpt_path
)
...
@@ -218,6 +221,8 @@ def main(args):
...
@@ -218,6 +221,8 @@ def main(args):
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
)
)
use_small_bfd
=
(
args
.
bfd_database_path
is
None
)
data_processor
=
data_pipeline
.
DataPipeline
(
data_processor
=
data_pipeline
.
DataPipeline
(
template_featurizer
=
template_featurizer
,
template_featurizer
=
template_featurizer
,
)
)
...
@@ -249,9 +254,21 @@ def main(args):
...
@@ -249,9 +254,21 @@ def main(args):
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
tags
=
[
t
.
split
()[
0
]
for
t
in
tags
]
tags
=
[
t
.
split
()[
0
]
for
t
in
tags
]
assert
len
(
tags
)
==
len
(
set
(
tags
)),
"All FASTA tags must be unique"
#
assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
tag
=
'-'
.
join
(
tags
)
output_name
=
f
'
{
tag
}
_
{
args
.
model_name
}
'
if
(
args
.
output_postfix
is
not
None
):
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
# Save the unrelaxed PDB.
unrelaxed_output_path
=
os
.
path
.
join
(
prediction_dir
,
f
'
{
output_name
}
_unrelaxed.pdb'
)
if
(
os
.
path
.
exists
(
unrelaxed_output_path
)):
continue
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
...
@@ -293,7 +310,7 @@ def main(args):
...
@@ -293,7 +310,7 @@ def main(args):
output_name
=
f
'
{
tag
}
_
{
args
.
model_name
}
'
output_name
=
f
'
{
tag
}
_
{
args
.
model_name
}
'
if
(
args
.
output_postfix
is
not
None
):
if
(
args
.
output_postfix
is
not
None
):
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
_
{
tag_postfix
}
'
# Save the unrelaxed PDB.
# Save the unrelaxed PDB.
unrelaxed_output_path
=
os
.
path
.
join
(
unrelaxed_output_path
=
os
.
path
.
join
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment