Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
99cdb062
Commit
99cdb062
authored
Jun 22, 2022
by
Gustaf Ahdritz
Browse files
Fix even more undefined references
parent
b1e4dc52
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
13 deletions
+10
-13
run_pretrained_openfold.py
run_pretrained_openfold.py
+10
-13
No files found.
run_pretrained_openfold.py
View file @
99cdb062
...
@@ -71,11 +71,10 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
...
@@ -71,11 +71,10 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
bfd_database_path
=
args
.
bfd_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
pdb70_database_path
=
args
.
pdb70_database_path
,
pdb70_database_path
=
args
.
pdb70_database_path
,
use_small_bfd
=
use_small_bfd
,
no_cpus
=
args
.
cpus
,
no_cpus
=
args
.
cpus
,
)
)
alignment_runner
.
run
(
alignment_runner
.
run
(
fasta_path
,
local_alignment_dir
tmp_
fasta_path
,
local_alignment_dir
)
)
# Remove temporary FASTA file
# Remove temporary FASTA file
...
@@ -133,7 +132,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
...
@@ -133,7 +132,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
remark
=
', '
.
join
([
remark
=
', '
.
join
([
f
"no_recycling=
{
no_recycling
}
"
,
f
"no_recycling=
{
no_recycling
}
"
,
f
"max_templates=
{
feature_processor
.
config
.
predict
.
max_templates
}
"
,
f
"max_templates=
{
feature_processor
.
config
.
predict
.
max_templates
}
"
,
f
"config_preset=
{
args
.
model_name
}
"
,
f
"config_preset=
{
args
.
config_preset
}
"
,
])
])
# For multi-chain FASTAs
# For multi-chain FASTAs
...
@@ -167,16 +166,16 @@ def main(args):
...
@@ -167,16 +166,16 @@ def main(args):
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
# Prep the model
# Prep the model
config
=
model_config
(
args
.
model_name
)
config
=
model_config
(
args
.
config_preset
)
logger
.
info
(
f
"Using config preset
{
args
.
model_name
}
..."
)
logger
.
info
(
f
"Using config preset
{
args
.
config_preset
}
..."
)
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
model
=
model
.
eval
()
if
(
args
.
jax_param_path
):
if
(
args
.
jax_param_path
):
import_jax_weights_
(
import_jax_weights_
(
model
,
args
.
jax_param_path
,
version
=
args
.
model_name
model
,
args
.
jax_param_path
,
version
=
args
.
config_preset
)
)
logger
.
info
(
logger
.
info
(
f
"Successfully loaded JAX parameters at
{
args
.
jax_param_path
}
..."
f
"Successfully loaded JAX parameters at
{
args
.
jax_param_path
}
..."
...
@@ -234,8 +233,6 @@ def main(args):
...
@@ -234,8 +233,6 @@ def main(args):
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
)
)
use_small_bfd
=
(
args
.
bfd_database_path
is
None
)
data_processor
=
data_pipeline
.
DataPipeline
(
data_processor
=
data_pipeline
.
DataPipeline
(
template_featurizer
=
template_featurizer
,
template_featurizer
=
template_featurizer
,
)
)
...
@@ -271,7 +268,7 @@ def main(args):
...
@@ -271,7 +268,7 @@ def main(args):
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
tag
=
'-'
.
join
(
tags
)
output_name
=
f
'
{
tag
}
_
{
args
.
model_name
}
'
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
if
(
args
.
output_postfix
is
not
None
):
if
(
args
.
output_postfix
is
not
None
):
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
...
@@ -322,9 +319,9 @@ def main(args):
...
@@ -322,9 +319,9 @@ def main(args):
out
,
batch
,
feature_dict
,
feature_processor
,
args
out
,
batch
,
feature_dict
,
feature_processor
,
args
)
)
output_name
=
f
'
{
tag
}
_
{
args
.
model_name
}
'
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
if
(
args
.
output_postfix
is
not
None
):
if
(
args
.
output_postfix
is
not
None
):
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
_
{
tag_postfix
}
'
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
# Save the unrelaxed PDB.
# Save the unrelaxed PDB.
unrelaxed_output_path
=
os
.
path
.
join
(
unrelaxed_output_path
=
os
.
path
.
join
(
...
@@ -394,7 +391,7 @@ if __name__ == "__main__":
...
@@ -394,7 +391,7 @@ if __name__ == "__main__":
device name is accepted (e.g. "cpu", "cuda:0")"""
device name is accepted (e.g. "cpu", "cuda:0")"""
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--
model_name
"
,
type
=
str
,
default
=
"model_1"
,
"--
config_preset
"
,
type
=
str
,
default
=
"model_1"
,
help
=
"""Name of a model config. Choose one of model_{1-5} or
help
=
"""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
)
)
...
@@ -441,7 +438,7 @@ if __name__ == "__main__":
...
@@ -441,7 +438,7 @@ if __name__ == "__main__":
if
(
args
.
jax_param_path
is
None
and
args
.
openfold_checkpoint_path
is
None
):
if
(
args
.
jax_param_path
is
None
and
args
.
openfold_checkpoint_path
is
None
):
args
.
jax_param_path
=
os
.
path
.
join
(
args
.
jax_param_path
=
os
.
path
.
join
(
"openfold"
,
"resources"
,
"params"
,
"openfold"
,
"resources"
,
"params"
,
"params_"
+
args
.
model_name
+
".npz"
"params_"
+
args
.
config_preset
+
".npz"
)
)
if
(
args
.
model_device
==
"cpu"
and
torch
.
cuda
.
is_available
()):
if
(
args
.
model_device
==
"cpu"
and
torch
.
cuda
.
is_available
()):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment