Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
f4043e1c
Commit
f4043e1c
authored
May 25, 2022
by
Gustaf Ahdritz
Browse files
Add AlphaFold-Gap inference
parent
89f05497
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
471 additions
and
99 deletions
+471
-99
openfold/config.py
openfold/config.py
+18
-9
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+145
-7
openfold/data/input_pipeline.py
openfold/data/input_pipeline.py
+2
-2
openfold/np/protein.py
openfold/np/protein.py
+84
-14
openfold/np/relax/amber_minimize.py
openfold/np/relax/amber_minimize.py
+5
-0
openfold/np/relax/relax.py
openfold/np/relax/relax.py
+5
-0
run_pretrained_openfold.py
run_pretrained_openfold.py
+212
-67
No files found.
openfold/config.py
View file @
f4043e1c
...
@@ -17,13 +17,15 @@ def model_config(name, train=False, low_prec=False):
...
@@ -17,13 +17,15 @@ def model_config(name, train=False, low_prec=False):
pass
pass
elif
name
==
"finetuning"
:
elif
name
==
"finetuning"
:
# AF2 Suppl. Table 4, "finetuning" setting
# AF2 Suppl. Table 4, "finetuning" setting
c
.
data
.
commo
n
.
max_extra_msa
=
5120
c
.
data
.
trai
n
.
max_extra_msa
=
5120
c
.
data
.
train
.
crop_size
=
384
c
.
data
.
train
.
crop_size
=
384
c
.
data
.
train
.
max_msa_clusters
=
512
c
.
data
.
train
.
max_msa_clusters
=
512
c
.
loss
.
violation
.
weight
=
1.
c
.
loss
.
violation
.
weight
=
1.
c
.
loss
.
experimentally_resolved
.
weight
=
0.01
elif
name
==
"model_1"
:
elif
name
==
"model_1"
:
# AF2 Suppl. Table 5, Model 1.1.1
# AF2 Suppl. Table 5, Model 1.1.1
c
.
data
.
common
.
max_extra_msa
=
5120
c
.
data
.
train
.
max_extra_msa
=
5120
c
.
data
.
predict
.
max_extra_msa
=
5120
c
.
data
.
common
.
reduce_max_clusters_by_max_templates
=
True
c
.
data
.
common
.
reduce_max_clusters_by_max_templates
=
True
c
.
data
.
common
.
use_templates
=
True
c
.
data
.
common
.
use_templates
=
True
c
.
data
.
common
.
use_template_torsion_angles
=
True
c
.
data
.
common
.
use_template_torsion_angles
=
True
...
@@ -36,17 +38,20 @@ def model_config(name, train=False, low_prec=False):
...
@@ -36,17 +38,20 @@ def model_config(name, train=False, low_prec=False):
c
.
model
.
template
.
enabled
=
True
c
.
model
.
template
.
enabled
=
True
elif
name
==
"model_3"
:
elif
name
==
"model_3"
:
# AF2 Suppl. Table 5, Model 1.2.1
# AF2 Suppl. Table 5, Model 1.2.1
c
.
data
.
common
.
max_extra_msa
=
5120
c
.
data
.
train
.
max_extra_msa
=
5120
c
.
data
.
predict
.
max_extra_msa
=
5120
c
.
model
.
template
.
enabled
=
False
c
.
model
.
template
.
enabled
=
False
elif
name
==
"model_4"
:
elif
name
==
"model_4"
:
# AF2 Suppl. Table 5, Model 1.2.2
# AF2 Suppl. Table 5, Model 1.2.2
c
.
data
.
common
.
max_extra_msa
=
5120
c
.
data
.
train
.
max_extra_msa
=
5120
c
.
data
.
predict
.
max_extra_msa
=
5120
c
.
model
.
template
.
enabled
=
False
c
.
model
.
template
.
enabled
=
False
elif
name
==
"model_5"
:
elif
name
==
"model_5"
:
# AF2 Suppl. Table 5, Model 1.2.3
# AF2 Suppl. Table 5, Model 1.2.3
c
.
model
.
template
.
enabled
=
False
c
.
model
.
template
.
enabled
=
False
elif
name
==
"model_1_ptm"
:
elif
name
==
"model_1_ptm"
:
c
.
data
.
common
.
max_extra_msa
=
5120
c
.
data
.
train
.
max_extra_msa
=
5120
c
.
data
.
predict
.
max_extra_msa
=
5120
c
.
data
.
common
.
reduce_max_clusters_by_max_templates
=
True
c
.
data
.
common
.
reduce_max_clusters_by_max_templates
=
True
c
.
data
.
common
.
use_templates
=
True
c
.
data
.
common
.
use_templates
=
True
c
.
data
.
common
.
use_template_torsion_angles
=
True
c
.
data
.
common
.
use_template_torsion_angles
=
True
...
@@ -61,12 +66,14 @@ def model_config(name, train=False, low_prec=False):
...
@@ -61,12 +66,14 @@ def model_config(name, train=False, low_prec=False):
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
elif
name
==
"model_3_ptm"
:
elif
name
==
"model_3_ptm"
:
c
.
data
.
common
.
max_extra_msa
=
5120
c
.
data
.
train
.
max_extra_msa
=
5120
c
.
data
.
predict
.
max_extra_msa
=
5120
c
.
model
.
template
.
enabled
=
False
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
elif
name
==
"model_4_ptm"
:
elif
name
==
"model_4_ptm"
:
c
.
data
.
common
.
max_extra_msa
=
5120
c
.
data
.
train
.
max_extra_msa
=
5120
c
.
data
.
predict
.
max_extra_msa
=
5120
c
.
model
.
template
.
enabled
=
False
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
...
@@ -184,7 +191,6 @@ config = mlc.ConfigDict(
...
@@ -184,7 +191,6 @@ config = mlc.ConfigDict(
"same_prob"
:
0.1
,
"same_prob"
:
0.1
,
"uniform_prob"
:
0.1
,
"uniform_prob"
:
0.1
,
},
},
"max_extra_msa"
:
1024
,
"max_recycling_iters"
:
3
,
"max_recycling_iters"
:
3
,
"msa_cluster_features"
:
True
,
"msa_cluster_features"
:
True
,
"reduce_msa_clusters_by_max_templates"
:
False
,
"reduce_msa_clusters_by_max_templates"
:
False
,
...
@@ -223,6 +229,7 @@ config = mlc.ConfigDict(
...
@@ -223,6 +229,7 @@ config = mlc.ConfigDict(
"subsample_templates"
:
False
,
# We want top templates.
"subsample_templates"
:
False
,
# We want top templates.
"masked_msa_replace_fraction"
:
0.15
,
"masked_msa_replace_fraction"
:
0.15
,
"max_msa_clusters"
:
128
,
"max_msa_clusters"
:
128
,
"max_extra_msa"
:
1024
,
"max_template_hits"
:
4
,
"max_template_hits"
:
4
,
"max_templates"
:
4
,
"max_templates"
:
4
,
"crop"
:
False
,
"crop"
:
False
,
...
@@ -235,6 +242,7 @@ config = mlc.ConfigDict(
...
@@ -235,6 +242,7 @@ config = mlc.ConfigDict(
"subsample_templates"
:
False
,
# We want top templates.
"subsample_templates"
:
False
,
# We want top templates.
"masked_msa_replace_fraction"
:
0.15
,
"masked_msa_replace_fraction"
:
0.15
,
"max_msa_clusters"
:
128
,
"max_msa_clusters"
:
128
,
"max_extra_msa"
:
1024
,
"max_template_hits"
:
4
,
"max_template_hits"
:
4
,
"max_templates"
:
4
,
"max_templates"
:
4
,
"crop"
:
False
,
"crop"
:
False
,
...
@@ -247,6 +255,7 @@ config = mlc.ConfigDict(
...
@@ -247,6 +255,7 @@ config = mlc.ConfigDict(
"subsample_templates"
:
True
,
"subsample_templates"
:
True
,
"masked_msa_replace_fraction"
:
0.15
,
"masked_msa_replace_fraction"
:
0.15
,
"max_msa_clusters"
:
128
,
"max_msa_clusters"
:
128
,
"max_extra_msa"
:
1024
,
"max_template_hits"
:
4
,
"max_template_hits"
:
4
,
"max_templates"
:
4
,
"max_templates"
:
4
,
"shuffle_top_k_prefiltered"
:
20
,
"shuffle_top_k_prefiltered"
:
20
,
...
@@ -262,7 +271,7 @@ config = mlc.ConfigDict(
...
@@ -262,7 +271,7 @@ config = mlc.ConfigDict(
"use_small_bfd"
:
False
,
"use_small_bfd"
:
False
,
"data_loaders"
:
{
"data_loaders"
:
{
"batch_size"
:
1
,
"batch_size"
:
1
,
"num_workers"
:
16
,
"num_workers"
:
8
,
},
},
},
},
},
},
...
...
openfold/data/data_pipeline.py
View file @
f4043e1c
...
@@ -65,6 +65,47 @@ def make_template_features(
...
@@ -65,6 +65,47 @@ def make_template_features(
return
template_features
return
template_features
def unify_template_features(
    template_feature_list: Sequence[FeatureDict]
) -> FeatureDict:
    """Merges per-chain template features into one stitched feature dict.

    Each entry of ``template_feature_list`` holds the template features of
    one chain. The residue axis (axis 1) of the per-residue template arrays
    is widened to the summed length of all chains, the chain's own data is
    written at that chain's offset, and the per-chain results are then
    concatenated along axis 0.

    Args:
        template_feature_list:
            One template feature dict per chain, in chain order. Every
            dict must contain "template_aatype", whose first two
            dimensions are read as (n_templates, n_res).
    Returns:
        A feature dict with the same keys as the first input dict plus
        "template_chain_index", which records, for every template row, the
        position of the chain it came from.
    """
    # Only these features carry a residue axis that needs to be widened.
    residue_keys = (
        "template_aatype",
        "template_all_atom_positions",
        "template_all_atom_mask",
    )

    chain_lengths = [
        feats["template_aatype"].shape[1] for feats in template_feature_list
    ]
    total_length = sum(chain_lengths)

    per_chain = []
    start = 0
    chains = zip(template_feature_list, chain_lengths)
    for chain_idx, (feats, chain_len) in enumerate(chains):
        n_templates, n_res = feats["template_aatype"].shape[:2]

        unified = {}
        for name, value in feats.items():
            if name not in residue_keys:
                # Features without a residue axis are passed through as-is
                unified[name] = value
                continue

            shape = list(value.shape)
            assert(shape[1] == n_res)
            shape[1] = total_length
            padded = np.zeros(shape, dtype=value.dtype)

            if name == "template_aatype":
                # Mark every position as a gap before the real residues are
                # written (presumably a one-hot gap encoding on the last
                # axis -- confirm against the template featurizer)
                padded[..., residue_constants.HHBLITS_AA_TO_ID['-']] = 1

            padded[:, start:start + chain_len] = value
            unified[name] = padded

        unified["template_chain_index"] = np.array(n_templates * [chain_idx])
        per_chain.append(unified)

        start += chain_len

    return {
        name: np.concatenate([feats[name] for feats in per_chain])
        for name in per_chain[0]
    }
def
make_sequence_features
(
def
make_sequence_features
(
sequence
:
str
,
description
:
str
,
num_res
:
int
sequence
:
str
,
description
:
str
,
num_res
:
int
)
->
FeatureDict
:
)
->
FeatureDict
:
...
@@ -423,7 +464,6 @@ class DataPipeline:
...
@@ -423,7 +464,6 @@ class DataPipeline:
_alignment_index
:
Optional
[
Any
]
=
None
,
_alignment_index
:
Optional
[
Any
]
=
None
,
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
msa_data
=
{}
msa_data
=
{}
if
(
_alignment_index
is
not
None
):
if
(
_alignment_index
is
not
None
):
fp
=
open
(
os
.
path
.
join
(
alignment_dir
,
_alignment_index
[
"db"
]),
"rb"
)
fp
=
open
(
os
.
path
.
join
(
alignment_dir
,
_alignment_index
[
"db"
]),
"rb"
)
...
@@ -506,14 +546,12 @@ class DataPipeline:
...
@@ -506,14 +546,12 @@ class DataPipeline:
return
all_hits
return
all_hits
def
_process_msa_feats
(
def
_get_msas
(
self
,
self
,
alignment_dir
:
str
,
alignment_dir
:
str
,
input_sequence
:
Optional
[
str
]
=
None
,
input_sequence
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
Mapping
[
str
,
Any
]
:
):
msa_data
=
self
.
_parse_msa_data
(
alignment_dir
,
_alignment_index
)
msa_data
=
self
.
_parse_msa_data
(
alignment_dir
,
_alignment_index
)
if
(
len
(
msa_data
)
==
0
):
if
(
len
(
msa_data
)
==
0
):
if
(
input_sequence
is
None
):
if
(
input_sequence
is
None
):
raise
ValueError
(
raise
ValueError
(
...
@@ -531,6 +569,17 @@ class DataPipeline:
...
@@ -531,6 +569,17 @@ class DataPipeline:
(
v
[
"msa"
],
v
[
"deletion_matrix"
])
for
v
in
msa_data
.
values
()
(
v
[
"msa"
],
v
[
"deletion_matrix"
])
for
v
in
msa_data
.
values
()
])
])
return
msas
,
deletion_matrices
def
_process_msa_feats
(
self
,
alignment_dir
:
str
,
input_sequence
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
)
->
Mapping
[
str
,
Any
]:
msas
,
deletion_matrices
=
self
.
_get_msas
(
alignment_dir
,
input_sequence
,
_alignment_index
)
msa_features
=
make_msa_features
(
msa_features
=
make_msa_features
(
msas
=
msas
,
msas
=
msas
,
deletion_matrices
=
deletion_matrices
,
deletion_matrices
=
deletion_matrices
,
...
@@ -685,3 +734,92 @@ class DataPipeline:
...
@@ -685,3 +734,92 @@ class DataPipeline:
return
{
**
core_feats
,
**
template_features
,
**
msa_features
}
return
{
**
core_feats
,
**
template_features
,
**
msa_features
}
def process_multiseq_fasta(self,
    fasta_path: str,
    super_alignment_dir: str,
    ri_gap: int = 200,
) -> FeatureDict:
    """
        Assembles features for a multi-sequence FASTA by stitching all of
        its sequences into a single pseudo-chain (Minkyung Baek's residue
        index gap hack from Twitter).

        The residue index of every chain after the first is shifted by a
        further ri_gap, each chain's MSA rows are padded with gaps over
        the other chains' columns, and template features are computed per
        chain and then merged with unify_template_features.

        Args:
            fasta_path:
                Path to the multi-sequence FASTA file
            super_alignment_dir:
                Directory with one alignment subdirectory per chain, named
                after the first whitespace-delimited token of the chain's
                FASTA description
            ri_gap:
                Residue index offset inserted at each chain boundary
        Returns:
            The merged sequence, MSA and template features
    """
    with open(fasta_path, 'r') as f:
        fasta_str = f.read()

    input_seqs, input_descs = parsers.parse_fasta(fasta_str)

    # Only the first token of each description is kept: no whitespace
    # allowed in the tags
    input_descs = [desc.split()[0] for desc in input_descs]

    # Stitch all of the sequences together
    input_sequence = ''.join(input_seqs)
    input_description = '-'.join(input_descs)

    sequence_features = make_sequence_features(
        sequence=input_sequence,
        description=input_description,
        num_res=len(input_sequence),
    )

    # Everything downstream of a chain boundary is pushed ri_gap further
    # out. The slice after the final chain is empty, so that step is a
    # no-op.
    seq_lens = [len(s) for s in input_seqs]
    boundary = 0
    for chain_len in seq_lens:
        boundary += chain_len
        sequence_features["residue_index"][boundary:] += ri_gap

    # Load the alignments of every chain first
    chain_alignments = []
    for seq, desc in zip(input_seqs, input_descs):
        alignment_dir = os.path.join(super_alignment_dir, desc)
        chain_alignments.append(self._get_msas(alignment_dir, seq, None))

    # Then pad each chain's rows out to the full stitched length
    final_msa = []
    final_deletion_mat = []
    for i, (msas, deletion_mats) in enumerate(chain_alignments):
        n_before = sum(seq_lens[:i])
        n_after = sum(seq_lens[i + 1:])

        padded_msas = []
        for msa in msas:
            padded_msas.append(
                [n_before * '-' + row + n_after * '-' for row in msa]
            )

        padded_deletion_mats = []
        for deletion_mat in deletion_mats:
            padded_deletion_mats.append(
                [n_before * [0] + row + n_after * [0] for row in deletion_mat]
            )

        assert(len(padded_msas[0][-1]) == len(input_sequence))

        final_msa.extend(padded_msas)
        final_deletion_mat.extend(padded_deletion_mats)

    msa_features = make_msa_features(
        msas=final_msa,
        deletion_matrices=final_deletion_mat,
    )

    # Templates are searched and featurized per chain
    template_feature_list = []
    for seq, desc in zip(input_seqs, input_descs):
        alignment_dir = os.path.join(super_alignment_dir, desc)
        hits = self._parse_template_hits(alignment_dir, _alignment_index=None)
        template_feature_list.append(
            make_template_features(seq, hits, self.template_featurizer)
        )

    template_features = unify_template_features(template_feature_list)

    return {
        **sequence_features,
        **msa_features,
        **template_features,
    }
openfold/data/input_pipeline.py
View file @
f4043e1c
...
@@ -84,7 +84,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
...
@@ -84,7 +84,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
pad_msa_clusters
=
mode_cfg
.
max_msa_clusters
pad_msa_clusters
=
mode_cfg
.
max_msa_clusters
max_msa_clusters
=
pad_msa_clusters
max_msa_clusters
=
pad_msa_clusters
max_extra_msa
=
common
_cfg
.
max_extra_msa
max_extra_msa
=
mode
_cfg
.
max_extra_msa
msa_seed
=
None
msa_seed
=
None
if
(
not
common_cfg
.
resample_msa_in_recycling
):
if
(
not
common_cfg
.
resample_msa_in_recycling
):
...
@@ -137,7 +137,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
...
@@ -137,7 +137,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
data_transforms
.
make_fixed_size
(
data_transforms
.
make_fixed_size
(
crop_feats
,
crop_feats
,
pad_msa_clusters
,
pad_msa_clusters
,
common
_cfg
.
max_extra_msa
,
mode
_cfg
.
max_extra_msa
,
mode_cfg
.
crop_size
,
mode_cfg
.
crop_size
,
mode_cfg
.
max_templates
,
mode_cfg
.
max_templates
,
)
)
...
...
openfold/np/protein.py
View file @
f4043e1c
...
@@ -16,8 +16,9 @@
...
@@ -16,8 +16,9 @@
"""Protein data type."""
"""Protein data type."""
import
dataclasses
import
dataclasses
import
io
import
io
from
typing
import
Any
,
Mapping
,
Optional
from
typing
import
Any
,
Sequence
,
Mapping
,
Optional
import
re
import
re
import
string
from
openfold.np
import
residue_constants
from
openfold.np
import
residue_constants
from
Bio.PDB
import
PDBParser
from
Bio.PDB
import
PDBParser
...
@@ -52,6 +53,19 @@ class Protein:
...
@@ -52,6 +53,19 @@ class Protein:
# value.
# value.
b_factors
:
np
.
ndarray
# [num_res, num_atom_type]
b_factors
:
np
.
ndarray
# [num_res, num_atom_type]
# Chain indices for multi-chain predictions
chain_index
:
Optional
[
np
.
ndarray
]
=
None
# Optional remark about the protein. Included as a comment in output PDB
# files
remark
:
Optional
[
str
]
=
None
# Templates used to generate this protein (prediction-only)
parents
:
Optional
[
Sequence
[
str
]]
=
None
# Chain corresponding to each parent
parents_chain_index
:
Optional
[
Sequence
[
int
]]
=
None
def
from_pdb_string
(
pdb_str
:
str
,
chain_id
:
Optional
[
str
]
=
None
)
->
Protein
:
def
from_pdb_string
(
pdb_str
:
str
,
chain_id
:
Optional
[
str
]
=
None
)
->
Protein
:
"""Takes a PDB string and constructs a Protein object.
"""Takes a PDB string and constructs a Protein object.
...
@@ -188,6 +202,28 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
...
@@ -188,6 +202,28 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
)
)
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
    """Builds the REMARK and PARENT header lines for one chain.

    Args:
        prot: The protein whose `remark`, `parents` and
            `parents_chain_index` attributes are read.
        chain_id: Index of the chain to list parents for. Only used when
            `prot.parents_chain_index` is set.
    Returns:
        A list of header lines: an optional "REMARK ..." line followed by
        exactly one "PARENT ..." line. The PARENT line reads "PARENT N/A"
        when no parent is available for the chain.
    """
    pdb_headers = []

    remark = prot.remark
    if(remark is not None):
        pdb_headers.append(f"REMARK {remark}")

    parents = prot.parents
    parents_chain_index = prot.parents_chain_index
    # Filter by chain only when there is something to filter: zipping
    # against a missing parent list would raise a TypeError.
    if(parents is not None and parents_chain_index is not None):
        parents = [
            p for i, p in zip(parents_chain_index, parents) if i == chain_id
        ]

    if(parents is None or len(parents) == 0):
        parents = ["N/A"]

    pdb_headers.append(f"PARENT {' '.join(parents)}")

    return pdb_headers
def
to_pdb
(
prot
:
Protein
)
->
str
:
def
to_pdb
(
prot
:
Protein
)
->
str
:
"""Converts a `Protein` instance to a PDB string.
"""Converts a `Protein` instance to a PDB string.
...
@@ -208,15 +244,21 @@ def to_pdb(prot: Protein) -> str:
...
@@ -208,15 +244,21 @@ def to_pdb(prot: Protein) -> str:
atom_positions
=
prot
.
atom_positions
atom_positions
=
prot
.
atom_positions
residue_index
=
prot
.
residue_index
.
astype
(
np
.
int32
)
residue_index
=
prot
.
residue_index
.
astype
(
np
.
int32
)
b_factors
=
prot
.
b_factors
b_factors
=
prot
.
b_factors
chain_index
=
prot
.
chain_index
if
np
.
any
(
aatype
>
residue_constants
.
restype_num
):
if
np
.
any
(
aatype
>
residue_constants
.
restype_num
):
raise
ValueError
(
"Invalid aatypes."
)
raise
ValueError
(
"Invalid aatypes."
)
pdb_lines
.
append
(
"MODEL 1"
)
headers
=
get_pdb_headers
(
prot
)
if
(
len
(
headers
)
>
0
):
pdb_lines
.
extend
(
headers
)
n
=
aatype
.
shape
[
0
]
atom_index
=
1
atom_index
=
1
chain_id
=
"A"
prev_chain_index
=
0
chain_tags
=
string
.
ascii_uppercase
# Add all atom sites.
# Add all atom sites.
for
i
in
range
(
aatype
.
shape
[
0
]
):
for
i
in
range
(
n
):
res_name_3
=
res_1to3
(
aatype
[
i
])
res_name_3
=
res_1to3
(
aatype
[
i
])
for
atom_name
,
pos
,
mask
,
b_factor
in
zip
(
for
atom_name
,
pos
,
mask
,
b_factor
in
zip
(
atom_types
,
atom_positions
[
i
],
atom_mask
[
i
],
b_factors
[
i
]
atom_types
,
atom_positions
[
i
],
atom_mask
[
i
],
b_factors
[
i
]
...
@@ -233,10 +275,15 @@ def to_pdb(prot: Protein) -> str:
...
@@ -233,10 +275,15 @@ def to_pdb(prot: Protein) -> str:
0
0
]
# Protein supports only C, N, O, S, this works.
]
# Protein supports only C, N, O, S, this works.
charge
=
""
charge
=
""
chain_tag
=
"A"
if
(
chain_index
is
not
None
):
chain_tag
=
chain_tags
[
chain_index
[
i
]]
# PDB is a columnar format, every space matters here!
# PDB is a columnar format, every space matters here!
atom_line
=
(
atom_line
=
(
f
"
{
record_type
:
<
6
}{
atom_index
:
>
5
}
{
name
:
<
4
}{
alt_loc
:
>
1
}
"
f
"
{
record_type
:
<
6
}{
atom_index
:
>
5
}
{
name
:
<
4
}{
alt_loc
:
>
1
}
"
f
"
{
res_name_3
:
>
3
}
{
chain_
id
:
>
1
}
"
f
"
{
res_name_3
:
>
3
}
{
chain_
tag
:
>
1
}
"
f
"
{
residue_index
[
i
]:
>
4
}{
insertion_code
:
>
1
}
"
f
"
{
residue_index
[
i
]:
>
4
}{
insertion_code
:
>
1
}
"
f
"
{
pos
[
0
]:
>
8.3
f
}{
pos
[
1
]:
>
8.3
f
}{
pos
[
2
]:
>
8.3
f
}
"
f
"
{
pos
[
0
]:
>
8.3
f
}{
pos
[
1
]:
>
8.3
f
}{
pos
[
2
]:
>
8.3
f
}
"
f
"
{
occupancy
:
>
6.2
f
}{
b_factor
:
>
6.2
f
}
"
f
"
{
occupancy
:
>
6.2
f
}{
b_factor
:
>
6.2
f
}
"
...
@@ -245,14 +292,27 @@ def to_pdb(prot: Protein) -> str:
...
@@ -245,14 +292,27 @@ def to_pdb(prot: Protein) -> str:
pdb_lines
.
append
(
atom_line
)
pdb_lines
.
append
(
atom_line
)
atom_index
+=
1
atom_index
+=
1
should_terminate
=
(
i
==
n
-
1
)
if
(
chain_index
is
not
None
):
if
(
i
!=
n
-
1
and
chain_index
[
i
+
1
]
!=
prev_chain_index
):
should_terminate
=
True
prev_chain_index
=
chain_index
[
i
+
1
]
if
(
should_terminate
):
# Close the chain.
# Close the chain.
chain_end
=
"TER"
chain_end
=
"TER"
chain_termination_line
=
(
chain_termination_line
=
(
f
"
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
{
res_1to3
(
aatype
[
-
1
]):
>
3
}
"
f
"
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
"
f
"
{
chain_id
:
>
1
}{
residue_index
[
-
1
]:
>
4
}
"
f
"
{
res_1to3
(
aatype
[
i
]):
>
3
}
"
f
"
{
chain_tag
:
>
1
}{
residue_index
[
i
]:
>
4
}
"
)
)
pdb_lines
.
append
(
chain_termination_line
)
pdb_lines
.
append
(
chain_termination_line
)
pdb_lines
.
append
(
"ENDMDL"
)
atom_index
+=
1
if
(
i
!=
n
-
1
):
# "prev" is a misnomer here. This happens at the beginning of
# each new chain.
pdb_lines
.
extend
(
get_pdb_headers
(
prot
,
prev_chain_index
))
pdb_lines
.
append
(
"END"
)
pdb_lines
.
append
(
"END"
)
pdb_lines
.
append
(
""
)
pdb_lines
.
append
(
""
)
...
@@ -279,6 +339,10 @@ def from_prediction(
...
@@ -279,6 +339,10 @@ def from_prediction(
features
:
FeatureDict
,
features
:
FeatureDict
,
result
:
ModelOutput
,
result
:
ModelOutput
,
b_factors
:
Optional
[
np
.
ndarray
]
=
None
,
b_factors
:
Optional
[
np
.
ndarray
]
=
None
,
chain_index
:
Optional
[
np
.
ndarray
]
=
None
,
remark
:
Optional
[
str
]
=
None
,
parents
:
Optional
[
Sequence
[
str
]]
=
None
,
parents_chain_index
:
Optional
[
Sequence
[
int
]]
=
None
)
->
Protein
:
)
->
Protein
:
"""Assembles a protein from a prediction.
"""Assembles a protein from a prediction.
...
@@ -286,7 +350,9 @@ def from_prediction(
...
@@ -286,7 +350,9 @@ def from_prediction(
features: Dictionary holding model inputs.
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
b_factors: (Optional) B-factors to use for the protein.
chain_index: (Optional) Chain indices for multi-chain predictions
remark: (Optional) Remark about the prediction
parents: (Optional) List of template names
Returns:
Returns:
A protein instance.
A protein instance.
"""
"""
...
@@ -299,4 +365,8 @@ def from_prediction(
...
@@ -299,4 +365,8 @@ def from_prediction(
atom_mask
=
result
[
"final_atom_mask"
],
atom_mask
=
result
[
"final_atom_mask"
],
residue_index
=
features
[
"residue_index"
]
+
1
,
residue_index
=
features
[
"residue_index"
]
+
1
,
b_factors
=
b_factors
,
b_factors
=
b_factors
,
chain_index
=
chain_index
,
remark
=
remark
,
parents
=
parents
,
parents_chain_index
=
parents_chain_index
,
)
)
openfold/np/relax/amber_minimize.py
View file @
f4043e1c
...
@@ -192,6 +192,11 @@ def clean_protein(prot: protein.Protein, checks: bool = True):
...
@@ -192,6 +192,11 @@ def clean_protein(prot: protein.Protein, checks: bool = True):
pdb_string
=
_get_pdb_string
(
as_file
.
getTopology
(),
as_file
.
getPositions
())
pdb_string
=
_get_pdb_string
(
as_file
.
getTopology
(),
as_file
.
getPositions
())
if
checks
:
if
checks
:
_check_cleaned_atoms
(
pdb_string
,
prot_pdb_string
)
_check_cleaned_atoms
(
pdb_string
,
prot_pdb_string
)
headers
=
protein
.
get_pdb_headers
(
prot
)
if
(
len
(
headers
)
>
0
):
pdb_string
=
'
\n
'
.
join
([
'
\n
'
.
join
(
headers
),
pdb_string
])
return
pdb_string
return
pdb_string
...
...
openfold/np/relax/relax.py
View file @
f4043e1c
...
@@ -87,4 +87,9 @@ class AmberRelaxation(object):
...
@@ -87,4 +87,9 @@ class AmberRelaxation(object):
violations
=
out
[
"structural_violations"
][
violations
=
out
[
"structural_violations"
][
"total_per_residue_violations_mask"
"total_per_residue_violations_mask"
]
]
headers
=
protein
.
get_pdb_headers
(
prot
)
if
(
len
(
headers
)
>
0
):
min_pdb
=
'
\n
'
.
join
([
'
\n
'
.
join
(
headers
),
min_pdb
])
return
min_pdb
,
debug_data
,
violations
return
min_pdb
,
debug_data
,
violations
run_pretrained_openfold.py
View file @
f4043e1c
...
@@ -21,6 +21,9 @@ import numpy as np
...
@@ -21,6 +21,9 @@ import numpy as np
import
os
import
os
import
pickle
import
pickle
from
pytorch_lightning.utilities.deepspeed
import
(
convert_zero_checkpoint_to_fp32_state_dict
)
import
random
import
random
import
sys
import
sys
import
time
import
time
...
@@ -42,12 +45,160 @@ from openfold.utils.tensor_utils import (
...
@@ -42,12 +45,160 @@ from openfold.utils.tensor_utils import (
from
scripts.utils
import
add_data_args
from
scripts.utils
import
add_data_args
def precompute_alignments(tags, seqs, alignment_dir, args):
    """Generates alignments for each sequence unless they are precomputed.

    For every (tag, seq) pair a temporary single-sequence FASTA file is
    written to args.output_dir. If args.use_precomputed_alignments is
    None, the alignment tools are run on that file and their output is
    stored in alignment_dir/<tag>. The temporary file is always removed
    afterwards, even if alignment generation fails.

    Args:
        tags: Sequence tags; each one names an alignment subdirectory.
        seqs: The sequences, in the same order as tags.
        alignment_dir: Parent directory of the per-tag alignment
            directories.
        args: Parsed command-line arguments (output_dir,
            use_precomputed_alignments, tool binary and database paths,
            cpus).
    """
    for tag, seq in zip(tags, seqs):
        tmp_fasta_path = os.path.join(
            args.output_dir, f"tmp_{os.getpid()}.fasta"
        )
        with open(tmp_fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")

        try:
            local_alignment_dir = os.path.join(alignment_dir, tag)
            if(args.use_precomputed_alignments is None):
                logging.info(f"Generating alignments for {tag}...")
                os.makedirs(local_alignment_dir, exist_ok=True)

                alignment_runner = data_pipeline.AlignmentRunner(
                    jackhmmer_binary_path=args.jackhmmer_binary_path,
                    hhblits_binary_path=args.hhblits_binary_path,
                    hhsearch_binary_path=args.hhsearch_binary_path,
                    uniref90_database_path=args.uniref90_database_path,
                    mgnify_database_path=args.mgnify_database_path,
                    bfd_database_path=args.bfd_database_path,
                    uniclust30_database_path=args.uniclust30_database_path,
                    pdb70_database_path=args.pdb70_database_path,
                    # No `use_small_bfd` name exists in this function.
                    # NOTE(review): derived from the BFD path here on the
                    # assumption that the small BFD is the fallback when
                    # no full BFD is given -- confirm against main().
                    use_small_bfd=(args.bfd_database_path is None),
                    no_cpus=args.cpus,
                )
                # The runner must read the temporary FASTA written above;
                # there is no `fasta_path` in this scope.
                alignment_runner.run(
                    tmp_fasta_path, local_alignment_dir
                )
        finally:
            # Remove temporary FASTA file
            os.remove(tmp_fasta_path)
def run_model(model, batch, tag, args):
    """Runs one gradient-free forward pass of the model on a feature batch.

    Every value of `batch` is converted to a tensor on args.model_device.
    The model's template stack is switched on only if at least one feature
    name contains "template_"; note that this writes to
    model.config.template.enabled.

    Args:
        model: Callable model exposing `config.template.enabled`.
        batch: Mapping from feature names to array-like values.
        tag: Name of the target, used for logging only.
        args: Parsed command-line arguments; `model_device` is read.
    Returns:
        Whatever the model returns for the converted batch.
    """
    logging.info("Executing model...")
    with torch.no_grad():
        device = args.model_device
        tensors = {}
        for name, value in batch.items():
            tensors[name] = torch.as_tensor(value, device=device)

        # Disable templates if there aren't any in the batch
        has_templates = any("template_" in name for name in tensors)
        model.config.template.enabled = has_templates

        logging.info(f"Running inference for {tag}...")
        start = time.perf_counter()
        out = model(tensors)
        logging.info(f"Inference time: {time.perf_counter() - start}")

    return out
def prep_output(out, batch, feature_dict, feature_processor, args):
    """Assembles the unrelaxed Protein for a prediction.

    Besides the coordinates, the protein carries per-atom pLDDT values as
    B-factors, a remark describing the inference settings, the template
    names used as parents and, for stitched multi-chain inputs, a chain
    index per residue.

    NOTE: batch["residue_index"] is modified in place so that residue
    numbering restarts at every chain boundary.

    Args:
        out: Model outputs as numpy arrays; "plddt" is read here.
        batch: Processed features without the recycling dimension.
        feature_dict: Raw features; "residue_index" and, when templates
            are enabled, "template_domain_names" (and optionally
            "template_chain_index") are read.
        feature_processor: Its config supplies use_templates,
            max_templates and max_recycling_iters.
        args: Parsed command-line arguments; `model_name` and
            `multimer_ri_gap` are read.
    Returns:
        The Protein built by protein.from_prediction.
    """
    plddt = out["plddt"]

    # Broadcast the per-residue pLDDT to every atom of the residue
    plddt_b_factors = np.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    max_templates = feature_processor.config.predict.max_templates

    # Prep protein metadata
    template_domain_names = []
    template_chain_index = None
    if(feature_processor.config.common.use_templates):
        template_domain_names = [
            t.decode("utf-8") for t in feature_dict["template_domain_names"]
        ]

        # This works because templates are not shuffled during inference
        template_domain_names = template_domain_names[:max_templates]

        if("template_chain_index" in feature_dict):
            template_chain_index = feature_dict["template_chain_index"]
            template_chain_index = template_chain_index[:max_templates]

    no_recycling = feature_processor.config.common.max_recycling_iters
    remark = ', '.join([
        f"no_recycling={no_recycling}",
        f"max_templates={max_templates}",
        f"config_preset={args.model_name}",
    ])

    # For multi-chain FASTAs: the number of whole gaps by which a residue
    # index exceeds its position is taken as the chain index (assumes the
    # raw residue indices are position + chain * gap -- TODO confirm
    # against the data pipeline).
    ri = feature_dict["residue_index"]
    chain_index = (ri - np.arange(ri.shape[0])) / args.multimer_ri_gap
    chain_index = chain_index.astype(np.int64)

    # Subtract, from every residue, the raw index at which its chain
    # starts
    cur_chain = 0
    prev_chain_max = 0
    for i, c in enumerate(chain_index):
        if(c != cur_chain):
            cur_chain = c
            prev_chain_max = i + cur_chain * args.multimer_ri_gap

        batch["residue_index"][i] -= prev_chain_max

    unrelaxed_protein = protein.from_prediction(
        features=batch,
        result=out,
        b_factors=plddt_b_factors,
        chain_index=chain_index,
        remark=remark,
        parents=template_domain_names,
        parents_chain_index=template_chain_index,
    )

    return unrelaxed_protein
def
main
(
args
):
def
main
(
args
):
# Create the output directory
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
# Prep the model
config
=
model_config
(
args
.
model_name
)
config
=
model_config
(
args
.
model_name
)
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
model
=
model
.
eval
()
import_jax_weights_
(
model
,
args
.
param_path
,
version
=
args
.
model_name
)
#script_preset_(model)
if
(
args
.
jax_param_path
):
import_jax_weights_
(
model
,
args
.
jax_param_path
,
version
=
args
.
model_name
)
elif
(
args
.
openfold_checkpoint_path
):
if
(
os
.
path
.
isdir
(
args
.
openfold_checkpoint_path
)):
checkpoint_basename
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
os
.
path
.
normpath
(
args
.
openfold_checkpoint_path
)
)
)[
0
]
ckpt_path
=
os
.
path
.
join
(
args
.
output_dir
,
checkpoint_basename
+
".pt"
,
)
if
(
not
os
.
path
.
isfile
(
ckpt_path
)):
convert_zero_checkpoint_to_fp32_state_dict
(
args
.
openfold_checkpoint_path
,
ckpt_path
,
)
else
:
ckpt_path
=
args
.
openfold_checkpoint_path
d
=
torch
.
load
(
ckpt_path
)
model
.
load_state_dict
(
d
[
"ema"
][
"params"
])
else
:
raise
ValueError
(
"At least one of jax_param_path or openfold_checkpoint_path must "
"be specified."
)
model
=
model
.
to
(
args
.
model_device
)
model
=
model
.
to
(
args
.
model_device
)
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
...
@@ -77,6 +228,9 @@ def main(args):
...
@@ -77,6 +228,9 @@ def main(args):
else
:
else
:
alignment_dir
=
args
.
use_precomputed_alignments
alignment_dir
=
args
.
use_precomputed_alignments
prediction_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"predictions"
)
os
.
makedirs
(
prediction_dir
,
exist_ok
=
True
)
for
fasta_file
in
os
.
listdir
(
args
.
fasta_dir
):
for
fasta_file
in
os
.
listdir
(
args
.
fasta_dir
):
# Gather input sequences
# Gather input sequences
with
open
(
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
with
open
(
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
...
@@ -88,81 +242,59 @@ def main(args):
...
@@ -88,81 +242,59 @@ def main(args):
][
1
:]
][
1
:]
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
assert
len
(
seqs
)
==
1
,
"Input FASTAs may only contain one sequence"
tags
=
[
t
.
split
()[
0
]
for
t
in
tags
]
tag
,
seq
=
tags
[
0
],
seqs
[
0
]
assert
len
(
tags
)
==
len
(
set
(
tags
)),
"All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
with
open
(
fasta_path
,
"w"
)
as
fp
:
if
(
len
(
seqs
)
==
1
):
seq
=
seqs
[
0
]
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
logging
.
info
(
"Generating features..."
)
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
if
(
args
.
use_precomputed_alignments
is
None
):
feature_dict
=
data_processor
.
process_fasta
(
if
not
os
.
path
.
exists
(
local_alignment_dir
):
fasta_path
=
tmp_fasta_path
,
alignment_dir
=
local_alignment_dir
os
.
makedirs
(
local_alignment_dir
)
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hhsearch_binary_path
=
args
.
hhsearch_binary_path
,
uniref90_database_path
=
args
.
uniref90_database_path
,
mgnify_database_path
=
args
.
mgnify_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
pdb70_database_path
=
args
.
pdb70_database_path
,
use_small_bfd
=
use_small_bfd
,
no_cpus
=
args
.
cpus
,
)
)
alignment_runner
.
run
(
else
:
fasta_path
,
local_alignment_dir
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
'
\n
'
.
join
([
f
">
{
tag
}
\n
{
seq
}
"
for
tag
,
seq
in
zip
(
tags
,
seqs
)])
)
)
feature_dict
=
data_processor
.
process_multiseq_fasta
(
feature_dict
=
data_processor
.
process_fasta
(
fasta_path
=
tmp_fasta_path
,
super_alignment_dir
=
alignment_dir
,
fasta_path
=
fasta_path
,
alignment_dir
=
local_alignment_dir
)
)
# Remove temporary FASTA file
# Remove temporary FASTA file
os
.
remove
(
fasta_path
)
os
.
remove
(
tmp_
fasta_path
)
processed_feature_dict
=
feature_processor
.
process_features
(
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
mode
=
'predict'
,
feature_dict
,
mode
=
'predict'
,
)
)
logging
.
info
(
"Executing model..."
)
batch
=
processed_feature_dict
batch
=
processed_feature_dict
with
torch
.
no_grad
():
out
=
run_model
(
model
,
batch
,
tag
,
args
)
batch
=
{
k
:
torch
.
as_tensor
(
v
,
device
=
args
.
model_device
)
for
k
,
v
in
batch
.
items
()
}
t
=
time
.
perf_counter
()
out
=
model
(
batch
)
logging
.
info
(
f
"Inference time:
{
time
.
perf_counter
()
-
t
}
"
)
# Toss out the recycling dimensions --- we don't need them anymore
# Toss out the recycling dimensions --- we don't need them anymore
batch
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
batch
)
batch
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
batch
)
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
plddt
=
out
[
"plddt"
]
unrelaxed_protein
=
prep_output
(
mean_plddt
=
np
.
mean
(
plddt
)
out
,
batch
,
feature_dict
,
feature_processor
,
args
plddt_b_factors
=
np
.
repeat
(
plddt
[...,
None
],
residue_constants
.
atom_type_num
,
axis
=-
1
)
)
unrelaxed_protein
=
protein
.
from_prediction
(
output_name
=
f
'
{
tag
}
_
{
args
.
model_name
}
'
features
=
batch
,
if
(
args
.
output_postfix
is
not
None
):
result
=
out
,
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
b_factors
=
plddt_b_factors
)
# Save the unrelaxed PDB.
# Save the unrelaxed PDB.
unrelaxed_output_path
=
os
.
path
.
join
(
unrelaxed_output_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
'
{
tag
}
_
{
args
.
model
_name
}
_unrelaxed.pdb'
prediction_dir
,
f
'
{
output
_name
}
_unrelaxed.pdb'
)
)
with
open
(
unrelaxed_output_path
,
'w'
)
as
f
:
with
open
(
unrelaxed_output_path
,
'w'
)
as
f
p
:
f
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
f
p
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
if
(
not
args
.
skip_relaxation
):
if
(
not
args
.
skip_relaxation
):
amber_relaxer
=
relax
.
AmberRelaxation
(
amber_relaxer
=
relax
.
AmberRelaxation
(
...
@@ -182,14 +314,14 @@ def main(args):
...
@@ -182,14 +314,14 @@ def main(args):
# Save the relaxed PDB.
# Save the relaxed PDB.
relaxed_output_path
=
os
.
path
.
join
(
relaxed_output_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
'
{
tag
}
_
{
args
.
model
_name
}
_relaxed.pdb'
prediction_dir
,
f
'
{
output
_name
}
_relaxed.pdb'
)
)
with
open
(
relaxed_output_path
,
'w'
)
as
f
:
with
open
(
relaxed_output_path
,
'w'
)
as
f
p
:
f
.
write
(
relaxed_pdb_str
)
f
p
.
write
(
relaxed_pdb_str
)
if
(
args
.
save_outputs
):
if
(
args
.
save_outputs
):
output_dict_path
=
os
.
path
.
join
(
output_dict_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
'
{
tag
}
_
{
args
.
model
_name
}
_output_dict.pkl'
args
.
output_dir
,
f
'
{
output
_name
}
_output_dict.pkl'
)
)
with
open
(
output_dict_path
,
"wb"
)
as
fp
:
with
open
(
output_dict_path
,
"wb"
)
as
fp
:
pickle
.
dump
(
out
,
fp
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
pickle
.
dump
(
out
,
fp
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
...
@@ -224,10 +356,15 @@ if __name__ == "__main__":
...
@@ -224,10 +356,15 @@ if __name__ == "__main__":
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--param_path"
,
type
=
str
,
default
=
None
,
"--jax_param_path"
,
type
=
str
,
default
=
None
,
help
=
"""Path to model parameters. If None, parameters are selected
help
=
"""Path to JAX model parameters. If None, and openfold_checkpoint_path
automatically according to the model name from
is also None, parameters are selected automatically according to
openfold/resources/params"""
the model name from openfold/resources/params"""
)
parser
.
add_argument
(
"--openfold_checkpoint_path"
,
type
=
str
,
default
=
None
,
help
=
"""Path to OpenFold checkpoint. Can be either a DeepSpeed
checkpoint directory or a .pt file"""
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--save_outputs"
,
action
=
"store_true"
,
default
=
False
,
"--save_outputs"
,
action
=
"store_true"
,
default
=
False
,
...
@@ -241,17 +378,25 @@ if __name__ == "__main__":
...
@@ -241,17 +378,25 @@ if __name__ == "__main__":
"--preset"
,
type
=
str
,
default
=
'full_dbs'
,
"--preset"
,
type
=
str
,
default
=
'full_dbs'
,
choices
=
(
'reduced_dbs'
,
'full_dbs'
)
choices
=
(
'reduced_dbs'
,
'full_dbs'
)
)
)
parser
.
add_argument
(
"--output_postfix"
,
type
=
str
,
default
=
None
,
help
=
"""Postfix for output prediction filenames"""
)
parser
.
add_argument
(
parser
.
add_argument
(
"--data_random_seed"
,
type
=
str
,
default
=
None
"--data_random_seed"
,
type
=
str
,
default
=
None
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--skip_relaxation"
,
action
=
"store_true"
,
default
=
False
,
"--skip_relaxation"
,
action
=
"store_true"
,
default
=
False
,
)
)
parser
.
add_argument
(
"--multimer_ri_gap"
,
type
=
int
,
default
=
200
,
help
=
"""Residue index offset between multiple sequences, if provided"""
)
add_data_args
(
parser
)
add_data_args
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
(
args
.
param_path
is
None
):
if
(
args
.
jax_
param_path
is
None
and
args
.
openfold_checkpoint_path
is
None
):
args
.
param_path
=
os
.
path
.
join
(
args
.
jax_
param_path
=
os
.
path
.
join
(
"openfold"
,
"resources"
,
"params"
,
"openfold"
,
"resources"
,
"params"
,
"params_"
+
args
.
model_name
+
".npz"
"params_"
+
args
.
model_name
+
".npz"
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment