Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
d48c052c
Commit
d48c052c
authored
Oct 15, 2021
by
Gustaf Ahdritz
Browse files
Add training parsers
parent
eeda001c
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1676 additions
and
919 deletions
+1676
-919
openfold/config.py
openfold/config.py
+281
-233
openfold/features/data_pipeline.py
openfold/features/data_pipeline.py
+336
-0
openfold/features/data_transforms.py
openfold/features/data_transforms.py
+590
-47
openfold/features/feature_pipeline.py
openfold/features/feature_pipeline.py
+48
-17
openfold/features/input_pipeline.py
openfold/features/input_pipeline.py
+53
-24
openfold/features/mmcif_parsing.py
openfold/features/mmcif_parsing.py
+77
-2
openfold/features/np/hhsearch.py
openfold/features/np/hhsearch.py
+5
-1
openfold/features/np/jackhmmer.py
openfold/features/np/jackhmmer.py
+1
-3
openfold/features/np/utils.py
openfold/features/np/utils.py
+7
-0
openfold/features/templates.py
openfold/features/templates.py
+43
-53
openfold/model/model.py
openfold/model/model.py
+28
-29
openfold/utils/exponential_moving_average.py
openfold/utils/exponential_moving_average.py
+9
-8
openfold/utils/feats.py
openfold/utils/feats.py
+6
-386
openfold/utils/loss.py
openfold/utils/loss.py
+8
-35
run_pretrained_openfold.py
run_pretrained_openfold.py
+23
-80
scripts/build_deepspeed_config.py
scripts/build_deepspeed_config.py
+1
-0
scripts/precompute_alignments.py
scripts/precompute_alignments.py
+107
-0
scripts/utils.py
scripts/utils.py
+48
-0
tests/compare_utils.py
tests/compare_utils.py
+3
-0
tests/test_feats.py
tests/test_feats.py
+2
-1
No files found.
openfold/config.py
View file @
d48c052c
...
...
@@ -6,42 +6,42 @@ def set_inf(c, inf):
for
k
,
v
in
c
.
items
():
if
(
isinstance
(
v
,
mlc
.
ConfigDict
)):
set_inf
(
v
,
inf
)
elif
(
k
==
"
inf
"
):
elif
(
k
==
'
inf
'
):
c
[
k
]
=
inf
def
model_config
(
name
,
train
=
False
,
low_prec
=
False
):
c
=
copy
.
deepcopy
(
config
)
if
(
name
==
"
model_1
"
):
if
(
name
==
'
model_1
'
):
pass
elif
(
name
==
"
model_2
"
):
elif
(
name
==
'
model_2
'
):
pass
elif
(
name
==
"
model_3
"
):
elif
(
name
==
'
model_3
'
):
c
.
model
.
template
.
enabled
=
False
elif
(
name
==
"
model_4
"
):
elif
(
name
==
'
model_4
'
):
c
.
model
.
template
.
enabled
=
False
elif
(
name
==
"
model_5
"
):
elif
(
name
==
'
model_5
'
):
c
.
model
.
template
.
enabled
=
False
elif
(
name
==
"
model_1_ptm
"
):
elif
(
name
==
'
model_1_ptm
'
):
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
(
name
==
"
model_2_ptm
"
):
elif
(
name
==
'
model_2_ptm
'
):
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
(
name
==
"
model_3_ptm
"
):
elif
(
name
==
'
model_3_ptm
'
):
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
(
name
==
"
model_4_ptm
"
):
elif
(
name
==
'
model_4_ptm
'
):
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
(
name
==
"
model_5_ptm
"
):
elif
(
name
==
'
model_5_ptm
'
):
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
else
:
raise
ValueError
(
"
Invalid model name
"
)
raise
ValueError
(
'
Invalid model name
'
)
if
(
train
):
c
.
globals
.
blocks_per_ckpt
=
1
...
...
@@ -65,6 +65,9 @@ blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size
=
mlc
.
FieldReference
(
4
,
field_type
=
int
)
aux_distogram_bins
=
mlc
.
FieldReference
(
64
,
field_type
=
int
)
eps
=
mlc
.
FieldReference
(
1e-8
,
field_type
=
float
)
num_recycle
=
mlc
.
FieldReference
(
3
,
field_type
=
int
)
templates_enabled
=
mlc
.
FieldReference
(
True
,
field_type
=
bool
)
embed_template_torsion_angles
=
mlc
.
FieldReference
(
True
,
field_type
=
bool
)
NUM_RES
=
'num residues placeholder'
NUM_MSA_SEQ
=
'msa placeholder'
...
...
@@ -74,29 +77,7 @@ NUM_TEMPLATES = 'num templates placeholder'
config
=
mlc
.
ConfigDict
({
'data'
:
{
'common'
:
{
'masked_msa'
:
{
'profile_prob'
:
0.1
,
'same_prob'
:
0.1
,
'uniform_prob'
:
0.1
},
'max_extra_msa'
:
1024
,
'msa_cluster_features'
:
True
,
'num_recycle'
:
3
,
'reduce_msa_clusters_by_max_templates'
:
False
,
'resample_msa_in_recycling'
:
True
,
'template_features'
:
[
'template_all_atom_positions'
,
'template_sum_probs'
,
'template_aatype'
,
'template_all_atom_masks'
,
# 'template_domain_names'
],
'unsupervised_features'
:
[
'aatype'
,
'residue_index'
,
'msa'
,
# 'sequence', #'domain_name',
'num_alignments'
,
'seq_length'
,
'between_segment_residues'
,
'deletion_matrix'
],
'use_templates'
:
True
,
},
'eval'
:
{
'batch_modes'
:
[(
'clamped'
,
0.9
),
(
'unclamped'
,
0.1
)],
'feat'
:
{
'aatype'
:
[
NUM_RES
],
'all_atom_mask'
:
[
NUM_RES
,
None
],
...
...
@@ -110,7 +91,7 @@ config = mlc.ConfigDict({
'atom14_gt_positions'
:
[
NUM_RES
,
None
,
None
],
'atom37_atom_exists'
:
[
NUM_RES
,
None
],
'backbone_affine_mask'
:
[
NUM_RES
],
'backbone_affine_tensor'
:
[
NUM_RES
,
None
],
'backbone_affine_tensor'
:
[
NUM_RES
,
None
,
None
],
'bert_mask'
:
[
NUM_MSA_SEQ
,
NUM_RES
],
'chi_angles'
:
[
NUM_RES
,
None
],
'chi_mask'
:
[
NUM_RES
,
None
],
...
...
@@ -125,266 +106,333 @@ config = mlc.ConfigDict({
'msa_row_mask'
:
[
NUM_MSA_SEQ
],
'pseudo_beta'
:
[
NUM_RES
,
None
],
'pseudo_beta_mask'
:
[
NUM_RES
],
'random_crop_to_size_seed'
:
[
None
],
'residue_index'
:
[
NUM_RES
],
'residx_atom14_to_atom37'
:
[
NUM_RES
,
None
],
'residx_atom37_to_atom14'
:
[
NUM_RES
,
None
],
'resolution'
:
[],
'rigidgroups_alt_gt_frames'
:
[
NUM_RES
,
None
,
None
],
'rigidgroups_alt_gt_frames'
:
[
NUM_RES
,
None
,
None
,
None
],
'rigidgroups_group_exists'
:
[
NUM_RES
,
None
],
'rigidgroups_group_is_ambiguous'
:
[
NUM_RES
,
None
],
'rigidgroups_gt_exists'
:
[
NUM_RES
,
None
],
'rigidgroups_gt_frames'
:
[
NUM_RES
,
None
,
None
],
'rigidgroups_gt_frames'
:
[
NUM_RES
,
None
,
None
,
None
],
'seq_length'
:
[],
'seq_mask'
:
[
NUM_RES
],
'target_feat'
:
[
NUM_RES
,
None
],
'template_aatype'
:
[
NUM_TEMPLATES
,
NUM_RES
],
'template_all_atom_masks'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
'template_all_atom_positions'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
'template_all_atom_mask'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
'template_all_atom_positions'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
'template_alt_torsion_angles_sin_cos'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
'template_backbone_affine_mask'
:
[
NUM_TEMPLATES
,
NUM_RES
],
'template_backbone_affine_tensor'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
'template_mask'
:
[
NUM_TEMPLATES
],
'template_pseudo_beta'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
'template_pseudo_beta_mask'
:
[
NUM_TEMPLATES
,
NUM_RES
],
'template_sum_probs'
:
[
NUM_TEMPLATES
,
None
],
'true_msa'
:
[
NUM_MSA_SEQ
,
NUM_RES
]
'template_torsion_angles_mask'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
'template_torsion_angles_sin_cos'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
'true_msa'
:
[
NUM_MSA_SEQ
,
NUM_RES
],
'use_clamped_fape'
:
[],
},
'masked_msa'
:
{
'profile_prob'
:
0.1
,
'same_prob'
:
0.1
,
'uniform_prob'
:
0.1
},
'max_extra_msa'
:
1024
,
'msa_cluster_features'
:
True
,
'num_recycle'
:
num_recycle
,
'reduce_msa_clusters_by_max_templates'
:
False
,
'resample_msa_in_recycling'
:
True
,
'template_features'
:
[
'template_all_atom_positions'
,
'template_sum_probs'
,
'template_aatype'
,
'template_all_atom_mask'
,
],
'unsupervised_features'
:
[
'aatype'
,
'residue_index'
,
'msa'
,
'num_alignments'
,
'seq_length'
,
'between_segment_residues'
,
'deletion_matrix'
],
'use_templates'
:
templates_enabled
,
'use_template_torsion_angles'
:
embed_template_torsion_angles
,
'supervised_features'
:
[
'all_atom_mask'
,
'all_atom_positions'
,
'resolution'
,
'use_clamped_fape'
,
],
},
'predict'
:
{
'fixed_size'
:
True
,
'subsample_templates'
:
False
,
# We want top templates.
'masked_msa_replace_fraction'
:
0.15
,
'max_msa_clusters'
:
512
,
'max_templates'
:
4
,
'num_ensemble'
:
1
,
'crop'
:
False
,
'crop_size'
:
None
,
'supervised'
:
False
,
},
'eval'
:
{
'fixed_size'
:
True
,
'subsample_templates'
:
False
,
# We want top templates.
'masked_msa_replace_fraction'
:
0.15
,
'max_msa_clusters'
:
512
,
'max_templates'
:
4
,
'num_ensemble'
:
1
,
'crop'
:
False
,
'crop_size'
:
None
,
'supervised'
:
True
,
},
'train'
:
{
'fixed_size'
:
True
,
'subsample_templates'
:
True
,
'masked_msa_replace_fraction'
:
0.15
,
'max_msa_clusters'
:
512
,
'max_templates'
:
4
,
'num_ensemble'
:
1
,
'crop'
:
True
,
'crop_size'
:
256
,
'supervised'
:
True
,
},
'data_module'
:
{
'use_small_bfd'
:
False
,
'data_loaders'
:
{
'batch_size'
:
1
,
'num_workers'
:
1
,
},
}
},
# Recurring FieldReferences that can be changed globally here
"
globals
"
:
{
"
blocks_per_ckpt
"
:
blocks_per_ckpt
,
"
chunk_size
"
:
chunk_size
,
"
c_z
"
:
c_z
,
"
c_m
"
:
c_m
,
"
c_t
"
:
c_t
,
"
c_e
"
:
c_e
,
"
c_s
"
:
c_s
,
"
eps
"
:
eps
,
'
globals
'
:
{
'
blocks_per_ckpt
'
:
blocks_per_ckpt
,
'
chunk_size
'
:
chunk_size
,
'
c_z
'
:
c_z
,
'
c_m
'
:
c_m
,
'
c_t
'
:
c_t
,
'
c_e
'
:
c_e
,
'
c_s
'
:
c_s
,
'
eps
'
:
eps
,
},
"
model
"
:
{
"no_cycles"
:
4
,
"
_mask_trans
"
:
False
,
"
input_embedder
"
:
{
"
tf_dim
"
:
22
,
"
msa_dim
"
:
49
,
"
c_z
"
:
c_z
,
"
c_m
"
:
c_m
,
"
relpos_k
"
:
32
,
'
model
'
:
{
'num_recycle'
:
num_recycle
,
'
_mask_trans
'
:
False
,
'
input_embedder
'
:
{
'
tf_dim
'
:
22
,
'
msa_dim
'
:
49
,
'
c_z
'
:
c_z
,
'
c_m
'
:
c_m
,
'
relpos_k
'
:
32
,
},
"
recycling_embedder
"
:
{
"
c_z
"
:
c_z
,
"
c_m
"
:
c_m
,
"
min_bin
"
:
3.25
,
"
max_bin
"
:
20.75
,
"
no_bins
"
:
15
,
"
inf
"
:
1e8
,
'
recycling_embedder
'
:
{
'
c_z
'
:
c_z
,
'
c_m
'
:
c_m
,
'
min_bin
'
:
3.25
,
'
max_bin
'
:
20.75
,
'
no_bins
'
:
15
,
'
inf
'
:
1e8
,
},
"
template
"
:
{
"
distogram
"
:
{
"
min_bin
"
:
3.25
,
"
max_bin
"
:
50.75
,
"
no_bins
"
:
39
,
'
template
'
:
{
'
distogram
'
:
{
'
min_bin
'
:
3.25
,
'
max_bin
'
:
50.75
,
'
no_bins
'
:
39
,
},
"
template_angle_embedder
"
:
{
'
template_angle_embedder
'
:
{
# DISCREPANCY: c_in is supposed to be 51.
"
c_in
"
:
57
,
"
c_out
"
:
c_m
,
'
c_in
'
:
57
,
'
c_out
'
:
c_m
,
},
"
template_pair_embedder
"
:
{
"
c_in
"
:
88
,
"
c_out
"
:
c_t
,
'
template_pair_embedder
'
:
{
'
c_in
'
:
88
,
'
c_out
'
:
c_t
,
},
"
template_pair_stack
"
:
{
"
c_t
"
:
c_t
,
'
template_pair_stack
'
:
{
'
c_t
'
:
c_t
,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"
c_hidden_tri_att
"
:
16
,
"
c_hidden_tri_mul
"
:
64
,
"
no_blocks
"
:
2
,
"
no_heads
"
:
4
,
"
pair_transition_n
"
:
2
,
"
dropout_rate
"
:
0.25
,
"
blocks_per_ckpt
"
:
blocks_per_ckpt
,
"
chunk_size
"
:
chunk_size
,
"
inf
"
:
1e5
,
#1e9,
'
c_hidden_tri_att
'
:
16
,
'
c_hidden_tri_mul
'
:
64
,
'
no_blocks
'
:
2
,
'
no_heads
'
:
4
,
'
pair_transition_n
'
:
2
,
'
dropout_rate
'
:
0.25
,
'
blocks_per_ckpt
'
:
blocks_per_ckpt
,
'
chunk_size
'
:
chunk_size
,
'
inf
'
:
1e5
,
#1e9,
},
"
template_pointwise_attention
"
:
{
"
c_t
"
:
c_t
,
"
c_z
"
:
c_z
,
'
template_pointwise_attention
'
:
{
'
c_t
'
:
c_t
,
'
c_z
'
:
c_z
,
# DISCREPANCY: c_hidden here is given in the supplement as 64.
# It's actually 16.
"
c_hidden
"
:
16
,
"
no_heads
"
:
4
,
"
chunk_size
"
:
chunk_size
,
"
inf
"
:
1e5
,
#1e9,
'
c_hidden
'
:
16
,
'
no_heads
'
:
4
,
'
chunk_size
'
:
chunk_size
,
'
inf
'
:
1e5
,
#1e9,
},
"
inf
"
:
1e5
,
#1e9,
"
eps
"
:
eps
,
#1e-6,
"
enabled
"
:
True
,
"
embed_angles
"
:
True
,
'
inf
'
:
1e5
,
#1e9,
'
eps
'
:
eps
,
#1e-6,
'
enabled
'
:
templates_enabled
,
'
embed_angles
'
:
embed_template_torsion_angles
,
},
"
extra_msa
"
:
{
"
extra_msa_embedder
"
:
{
"
c_in
"
:
25
,
"
c_out
"
:
c_e
,
'
extra_msa
'
:
{
'
extra_msa_embedder
'
:
{
'
c_in
'
:
25
,
'
c_out
'
:
c_e
,
},
"
extra_msa_stack
"
:
{
"
c_m
"
:
c_e
,
"
c_z
"
:
c_z
,
"
c_hidden_msa_att
"
:
8
,
"
c_hidden_opm
"
:
32
,
"
c_hidden_mul
"
:
128
,
"
c_hidden_pair_att
"
:
32
,
"
no_heads_msa
"
:
8
,
"
no_heads_pair
"
:
4
,
"
no_blocks
"
:
4
,
"
transition_n
"
:
4
,
"
msa_dropout
"
:
0.15
,
"
pair_dropout
"
:
0.25
,
"
blocks_per_ckpt
"
:
blocks_per_ckpt
,
"
chunk_size
"
:
chunk_size
,
"
inf
"
:
1e5
,
#1e9,
"
eps
"
:
eps
,
#1e-10,
'
extra_msa_stack
'
:
{
'
c_m
'
:
c_e
,
'
c_z
'
:
c_z
,
'
c_hidden_msa_att
'
:
8
,
'
c_hidden_opm
'
:
32
,
'
c_hidden_mul
'
:
128
,
'
c_hidden_pair_att
'
:
32
,
'
no_heads_msa
'
:
8
,
'
no_heads_pair
'
:
4
,
'
no_blocks
'
:
4
,
'
transition_n
'
:
4
,
'
msa_dropout
'
:
0.15
,
'
pair_dropout
'
:
0.25
,
'
blocks_per_ckpt
'
:
blocks_per_ckpt
,
'
chunk_size
'
:
chunk_size
,
'
inf
'
:
1e5
,
#1e9,
'
eps
'
:
eps
,
#1e-10,
},
"
enabled
"
:
True
,
'
enabled
'
:
True
,
},
"
evoformer_stack
"
:
{
"
c_m
"
:
c_m
,
"
c_z
"
:
c_z
,
"
c_hidden_msa_att
"
:
32
,
"
c_hidden_opm
"
:
32
,
"
c_hidden_mul
"
:
128
,
"
c_hidden_pair_att
"
:
32
,
"
c_s
"
:
c_s
,
"
no_heads_msa
"
:
8
,
"
no_heads_pair
"
:
4
,
"
no_blocks
"
:
48
,
"
transition_n
"
:
4
,
"
msa_dropout
"
:
0.15
,
"
pair_dropout
"
:
0.25
,
"
blocks_per_ckpt
"
:
blocks_per_ckpt
,
"
chunk_size
"
:
chunk_size
,
"
inf
"
:
1e5
,
#1e9,
"
eps
"
:
eps
,
#1e-10,
'
evoformer_stack
'
:
{
'
c_m
'
:
c_m
,
'
c_z
'
:
c_z
,
'
c_hidden_msa_att
'
:
32
,
'
c_hidden_opm
'
:
32
,
'
c_hidden_mul
'
:
128
,
'
c_hidden_pair_att
'
:
32
,
'
c_s
'
:
c_s
,
'
no_heads_msa
'
:
8
,
'
no_heads_pair
'
:
4
,
'
no_blocks
'
:
48
,
'
transition_n
'
:
4
,
'
msa_dropout
'
:
0.15
,
'
pair_dropout
'
:
0.25
,
'
blocks_per_ckpt
'
:
blocks_per_ckpt
,
'
chunk_size
'
:
chunk_size
,
'
inf
'
:
1e5
,
#1e9,
'
eps
'
:
eps
,
#1e-10,
},
"
structure_module
"
:
{
"
c_s
"
:
c_s
,
"
c_z
"
:
c_z
,
"
c_ipa
"
:
16
,
"
c_resnet
"
:
128
,
"
no_heads_ipa
"
:
12
,
"
no_qk_points
"
:
4
,
"
no_v_points
"
:
8
,
"
dropout_rate
"
:
0.1
,
"
no_blocks
"
:
8
,
"
no_transition_layers
"
:
1
,
"
no_resnet_blocks
"
:
2
,
"
no_angles
"
:
7
,
"
trans_scale_factor
"
:
10
,
"
epsilon
"
:
eps
,
#1e-12,
"
inf
"
:
1e5
,
'
structure_module
'
:
{
'
c_s
'
:
c_s
,
'
c_z
'
:
c_z
,
'
c_ipa
'
:
16
,
'
c_resnet
'
:
128
,
'
no_heads_ipa
'
:
12
,
'
no_qk_points
'
:
4
,
'
no_v_points
'
:
8
,
'
dropout_rate
'
:
0.1
,
'
no_blocks
'
:
8
,
'
no_transition_layers
'
:
1
,
'
no_resnet_blocks
'
:
2
,
'
no_angles
'
:
7
,
'
trans_scale_factor
'
:
10
,
'
epsilon
'
:
eps
,
#1e-12,
'
inf
'
:
1e5
,
},
"
heads
"
:
{
"
lddt
"
:
{
"
no_bins
"
:
50
,
"
c_in
"
:
c_s
,
"
c_hidden
"
:
128
,
'
heads
'
:
{
'
lddt
'
:
{
'
no_bins
'
:
50
,
'
c_in
'
:
c_s
,
'
c_hidden
'
:
128
,
},
"
distogram
"
:
{
"
c_z
"
:
c_z
,
"
no_bins
"
:
aux_distogram_bins
,
'
distogram
'
:
{
'
c_z
'
:
c_z
,
'
no_bins
'
:
aux_distogram_bins
,
},
"
tm
"
:
{
"
c_z
"
:
c_z
,
"
no_bins
"
:
aux_distogram_bins
,
"
enabled
"
:
False
,
'
tm
'
:
{
'
c_z
'
:
c_z
,
'
no_bins
'
:
aux_distogram_bins
,
'
enabled
'
:
False
,
},
"
masked_msa
"
:
{
"
c_m
"
:
c_m
,
"
c_out
"
:
23
,
'
masked_msa
'
:
{
'
c_m
'
:
c_m
,
'
c_out
'
:
23
,
},
"
experimentally_resolved
"
:
{
"
c_s
"
:
c_s
,
"
c_out
"
:
37
,
'
experimentally_resolved
'
:
{
'
c_s
'
:
c_s
,
'
c_out
'
:
37
,
},
},
},
"
relax
"
:
{
"
max_iterations
"
:
0
,
# no max
"
tolerance
"
:
2.39
,
"
stiffness
"
:
10.0
,
"
max_outer_iterations
"
:
20
,
"
exclude_residues
"
:
[],
'
relax
'
:
{
'
max_iterations
'
:
0
,
# no max
'
tolerance
'
:
2.39
,
'
stiffness
'
:
10.0
,
'
max_outer_iterations
'
:
20
,
'
exclude_residues
'
:
[],
},
"
loss
"
:
{
"
distogram
"
:
{
"
min_bin
"
:
2.3125
,
"
max_bin
"
:
21.6875
,
"
no_bins
"
:
64
,
"
eps
"
:
eps
,
#1e-6,
"
weight
"
:
0.3
,
'
loss
'
:
{
'
distogram
'
:
{
'
min_bin
'
:
2.3125
,
'
max_bin
'
:
21.6875
,
'
no_bins
'
:
64
,
'
eps
'
:
eps
,
#1e-6,
'
weight
'
:
0.3
,
},
"
experimentally_resolved
"
:
{
"
eps
"
:
eps
,
#1e-8,
"
min_resolution
"
:
0.1
,
"
max_resolution
"
:
3.0
,
"
weight
"
:
0.
,
'
experimentally_resolved
'
:
{
'
eps
'
:
eps
,
#1e-8,
'
min_resolution
'
:
0.1
,
'
max_resolution
'
:
3.0
,
'
weight
'
:
0.
,
},
"
fape
"
:
{
"
backbone
"
:
{
"
clamp_distance
"
:
10.
,
"
loss_unit_distance
"
:
10.
,
"
weight
"
:
0.5
,
'
fape
'
:
{
'
backbone
'
:
{
'
clamp_distance
'
:
10.
,
'
loss_unit_distance
'
:
10.
,
'
weight
'
:
0.5
,
},
"
sidechain
"
:
{
"
clamp_distance
"
:
10.
,
"
length_scale
"
:
10.
,
"
weight
"
:
0.5
,
'
sidechain
'
:
{
'
clamp_distance
'
:
10.
,
'
length_scale
'
:
10.
,
'
weight
'
:
0.5
,
},
"
eps
"
:
1e-4
,
"
weight
"
:
1.0
,
'
eps
'
:
1e-4
,
'
weight
'
:
1.0
,
},
"
lddt
"
:
{
"
min_resolution
"
:
0.1
,
"
max_resolution
"
:
3.0
,
"
cutoff
"
:
15.
,
"
no_bins
"
:
50
,
"
eps
"
:
eps
,
#1e-10,
"
weight
"
:
0.01
,
'
lddt
'
:
{
'
min_resolution
'
:
0.1
,
'
max_resolution
'
:
3.0
,
'
cutoff
'
:
15.
,
'
no_bins
'
:
50
,
'
eps
'
:
eps
,
#1e-10,
'
weight
'
:
0.01
,
},
"
masked_msa
"
:
{
"
eps
"
:
eps
,
#1e-8,
"
weight
"
:
2.0
,
'
masked_msa
'
:
{
'
eps
'
:
eps
,
#1e-8,
'
weight
'
:
2.0
,
},
"
supervised_chi
"
:
{
"
chi_weight
"
:
0.5
,
"
angle_norm_weight
"
:
0.01
,
"
eps
"
:
eps
,
#1e-6,
"
weight
"
:
1.0
,
'
supervised_chi
'
:
{
'
chi_weight
'
:
0.5
,
'
angle_norm_weight
'
:
0.01
,
'
eps
'
:
eps
,
#1e-6,
'
weight
'
:
1.0
,
},
"
violation
"
:
{
"
violation_tolerance_factor
"
:
12.0
,
"
clash_overlap_tolerance
"
:
1.5
,
"
eps
"
:
eps
,
#1e-6,
"
weight
"
:
0.
,
'
violation
'
:
{
'
violation_tolerance_factor
'
:
12.0
,
'
clash_overlap_tolerance
'
:
1.5
,
'
eps
'
:
eps
,
#1e-6,
'
weight
'
:
0.
,
},
"
tm
"
:
{
"
max_bin
"
:
31
,
"
no_bins
"
:
64
,
"
min_resolution
"
:
0.1
,
"
max_resolution
"
:
3.0
,
"
eps
"
:
eps
,
#1e-8,
"
weight
"
:
0.
,
'
tm
'
:
{
'
max_bin
'
:
31
,
'
no_bins
'
:
64
,
'
min_resolution
'
:
0.1
,
'
max_resolution
'
:
3.0
,
'
eps
'
:
eps
,
#1e-8,
'
weight
'
:
0.
,
},
"eps"
:
eps
,
'eps'
:
eps
,
},
'ema'
:
{
'decay'
:
0.999
},
})
openfold/features/
np/
data_pipeline.py
→
openfold/features/data_pipeline.py
View file @
d48c052c
import
os
import
datetime
import
numpy
as
np
from
typing
import
Mapping
,
Optional
,
Sequence
from
typing
import
Mapping
,
Optional
,
Sequence
,
Any
from
openfold.features
import
templates
,
parsers
from
openfold.features
import
templates
,
parsers
,
mmcif_parsing
from
openfold.features.np
import
jackhmmer
,
hhblits
,
hhsearch
from
openfold.features.np.utils
import
to_date
from
openfold.np
import
residue_constants
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
def
make_sequence_features
(
sequence
:
str
,
description
:
str
,
num_res
:
int
)
->
FeatureDict
:
def
make_sequence_features
(
sequence
:
str
,
description
:
str
,
num_res
:
int
)
->
FeatureDict
:
"""Construct a feature dict of sequence features."""
features
=
{}
features
[
'aatype'
]
=
residue_constants
.
sequence_to_onehot
(
...
...
@@ -19,13 +25,50 @@ def make_sequence_features(sequence: str, description: str, num_res: int) -> Fea
map_unknown_to_x
=
True
)
features
[
'between_segment_residues'
]
=
np
.
zeros
((
num_res
,),
dtype
=
np
.
int32
)
features
[
'domain_name'
]
=
np
.
array
([
description
.
encode
(
'utf-8'
)],
dtype
=
np
.
object_
)
features
[
'domain_name'
]
=
np
.
array
(
[
description
.
encode
(
'utf-8'
)],
dtype
=
np
.
object_
)
features
[
'residue_index'
]
=
np
.
array
(
range
(
num_res
),
dtype
=
np
.
int32
)
features
[
'seq_length'
]
=
np
.
array
([
num_res
]
*
num_res
,
dtype
=
np
.
int32
)
features
[
'sequence'
]
=
np
.
array
([
sequence
.
encode
(
'utf-8'
)],
dtype
=
np
.
object_
)
features
[
'sequence'
]
=
np
.
array
(
[
sequence
.
encode
(
'utf-8'
)],
dtype
=
np
.
object_
)
return
features
def
make_mmcif_features
(
mmcif_object
:
mmcif_parsing
.
MmcifObject
,
chain_id
:
str
)
->
FeatureDict
:
input_sequence
=
mmcif_object
.
chain_to_seqres
[
chain_id
]
description
=
'_'
.
join
([
mmcif_object
.
file_id
,
chain_id
])
num_res
=
len
(
input_sequence
)
mmcif_feats
=
{}
mmcif_feats
.
update
(
make_sequence_features
(
sequence
=
input_sequence
,
description
=
description
,
num_res
=
num_res
,
))
all_atom_positions
,
all_atom_mask
=
mmcif_parsing
.
get_atom_coords
(
mmcif_object
=
mmcif_object
,
chain_id
=
chain_id
)
mmcif_feats
[
"all_atom_positions"
]
=
all_atom_positions
mmcif_feats
[
"all_atom_mask"
]
=
all_atom_mask
mmcif_feats
[
"resolution"
]
=
np
.
array
(
[
mmcif_object
.
header
[
"resolution"
]],
dtype
=
np
.
float32
)
mmcif_feats
[
"release_date"
]
=
np
.
array
(
[
mmcif_object
.
header
[
"release_date"
].
encode
(
'utf-8'
)],
dtype
=
np
.
object_
)
return
mmcif_feats
def
make_msa_features
(
msas
:
Sequence
[
Sequence
[
str
]],
deletion_matrices
:
Sequence
[
parsers
.
DeletionMatrix
])
->
FeatureDict
:
...
...
@@ -58,9 +101,9 @@ def make_msa_features(
)
return
features
class
DataPipeline
:
"""Runs the alignment tools and assembles the input features."""
class
AlignmentRunner
:
""" Runs alignment tools and saves the results """
def
__init__
(
self
,
jackhmmer_binary_path
:
str
,
hhblits_binary_path
:
str
,
...
...
@@ -71,106 +114,158 @@ class DataPipeline:
uniclust30_database_path
:
Optional
[
str
],
small_bfd_database_path
:
Optional
[
str
],
pdb70_database_path
:
str
,
template_featurizer
:
templates
.
TemplateHitFeaturizer
,
use_small_bfd
:
bool
,
mgnify_max_hits
:
int
=
501
,
uniref_max_hits
:
int
=
10000
no_cpus
:
int
,
uniref_max_hits
:
int
=
10000
,
mgnify_max_hits
:
int
=
5000
,
):
"""Constructs a feature dict for a given FASTA file."""
self
.
_use_small_bfd
=
use_small_bfd
self
.
jackhmmer_uniref90_runner
=
jackhmmer
.
Jackhmmer
(
binary_path
=
jackhmmer_binary_path
,
database_path
=
uniref90_database_path
database_path
=
uniref90_database_path
,
n_cpu
=
no_cpus
,
)
if
use_small_bfd
:
self
.
jackhmmer_small_bfd_runner
=
jackhmmer
.
Jackhmmer
(
binary_path
=
jackhmmer_binary_path
,
database_path
=
small_bfd_database_path
database_path
=
small_bfd_database_path
,
n_cpu
=
no_cpus
,
)
else
:
self
.
hhblits_bfd_uniclust_runner
=
hhblits
.
HHBlits
(
binary_path
=
hhblits_binary_path
,
databases
=
[
bfd_database_path
,
uniclust30_database_path
]
databases
=
[
bfd_database_path
,
uniclust30_database_path
],
n_cpu
=
no_cpus
,
)
self
.
jackhmmer_mgnify_runner
=
jackhmmer
.
Jackhmmer
(
binary_path
=
jackhmmer_binary_path
,
database_path
=
mgnify_database_path
database_path
=
mgnify_database_path
,
n_cpu
=
no_cpus
,
)
self
.
hhsearch_pdb70_runner
=
hhsearch
.
HHSearch
(
binary_path
=
hhsearch_binary_path
,
databases
=
[
pdb70_database_path
]
)
self
.
template_featurizer
=
template_featurizer
self
.
mgnify_max_hits
=
mgnify_max_hits
self
.
uniref_max_hits
=
uniref_max_hits
self
.
mgnify_max_hits
=
mgnify_max_hits
def
process
(
self
,
input_fasta_path
:
str
,
msa_output_dir
:
str
)
->
FeatureDict
:
"""Runs alignment tools on the input sequence and creates features."""
with
open
(
input_fasta_path
)
as
f
:
input_fasta_str
=
f
.
read
()
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
input_fasta_str
)
if
len
(
input_seqs
)
!=
1
:
raise
ValueError
(
f
'More than one input sequence found in
{
input_fasta_path
}
.'
)
input_sequence
=
input_seqs
[
0
]
input_description
=
input_descs
[
0
]
num_res
=
len
(
input_sequence
)
jackhmmer_uniref90_result
=
self
.
jackhmmer_uniref90_runner
.
query
(
input_fasta_path
)[
0
]
jackhmmer_mgnify_result
=
self
.
jackhmmer_mgnify_runner
.
query
(
input_fasta_path
)[
0
]
def
run
(
self
,
fasta_path
:
str
,
output_dir
:
str
,
):
"""Runs alignment tools on a sequence"""
jackhmmer_uniref90_result
=
self
.
jackhmmer_uniref90_runner
.
query
(
fasta_path
)[
0
]
uniref90_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
jackhmmer_uniref90_result
[
'sto'
],
max_sequences
=
self
.
uniref_max_hits
)
hhsearch_result
=
self
.
hhsearch_pdb70_runner
.
query
(
uniref90_msa_as_a3m
)
uniref90_out_path
=
os
.
path
.
join
(
msa_output_dir
,
'uniref90_hits.sto'
)
uniref90_out_path
=
os
.
path
.
join
(
output_dir
,
'uniref90_hits.a3m'
)
with
open
(
uniref90_out_path
,
'w'
)
as
f
:
f
.
write
(
jackhmmer_uniref90_result
[
'sto'
]
)
f
.
write
(
uniref90_msa_as_a3m
)
mgnify_out_path
=
os
.
path
.
join
(
msa_output_dir
,
'mgnify_hits.so'
)
jackhmmer_mgnify_result
=
self
.
jackhmmer_mgnify_runner
.
query
(
fasta_path
)[
0
]
mgnify_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
jackhmmer_mgnify_result
[
'sto'
],
max_sequences
=
self
.
mgnify_max_hits
)
mgnify_out_path
=
os
.
path
.
join
(
output_dir
,
'mgnify_hits.a3m'
)
with
open
(
mgnify_out_path
,
'w'
)
as
f
:
f
.
write
(
jackhmmer_mgnify_result
[
'sto'
]
)
f
.
write
(
mgnify_msa_as_a3m
)
pdb70_out_path
=
os
.
path
.
join
(
msa_output_dir
,
'pdb70_hits.hhr'
)
hhsearch_result
=
self
.
hhsearch_pdb70_runner
.
query
(
uniref90_msa_as_a3m
)
pdb70_out_path
=
os
.
path
.
join
(
output_dir
,
'pdb70_hits.hhr'
)
with
open
(
pdb70_out_path
,
'w'
)
as
f
:
f
.
write
(
hhsearch_result
)
uniref90_msa
,
uniref90_deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
jackhmmer_uniref90_result
[
'sto'
]
)
mgnify_msa
,
mgnify_deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
jackhmmer_mgnify_result
[
'sto'
]
)
hhsearch_hits
=
parsers
.
parse_hhr
(
hhsearch_result
)
mgnify_msa
=
mgnify_msa
[:
self
.
mgnify_max_hits
]
mgnify_deletion_matrix
=
mgnify_deletion_matrix
[:
self
.
mgnify_max_hits
]
if
self
.
_use_small_bfd
:
jackhmmer_small_bfd_result
=
self
.
jackhmmer_small_bfd_runner
.
query
(
input_
fasta_path
)[
0
]
bfd_out_path
=
os
.
path
.
join
(
msa_
output_dir
,
'small_bfd_hits.
a3m
'
)
jackhmmer_small_bfd_result
=
self
.
jackhmmer_small_bfd_runner
.
query
(
fasta_path
)[
0
]
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
'small_bfd_hits.
sto
'
)
with
open
(
bfd_out_path
,
'w'
)
as
f
:
f
.
write
(
jackhmmer_small_bfd_result
[
'sto'
])
else
:
hhblits_bfd_uniclust_result
=
self
.
hhblits_bfd_uniclust_runner
.
query
(
fasta_path
)
if
(
output_dir
is
not
None
):
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
'bfd_uniclust_hits.a3m'
)
with
open
(
bfd_out_path
,
'w'
)
as
f
:
f
.
write
(
hhblits_bfd_uniclust_result
[
'a3m'
])
bfd_msa
,
bfd_deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
jackhmmer_small_bfd_result
[
'sto'
]
class
DataPipeline
:
"""Assembles input features."""
def
__init__
(
self
,
template_featurizer
:
templates
.
TemplateHitFeaturizer
,
use_small_bfd
:
bool
,
):
self
.
template_featurizer
=
template_featurizer
self
.
use_small_bfd
=
use_small_bfd
def
_parse_alignment_output
(
self
,
alignment_dir
:
str
,
)
->
Mapping
[
str
,
Any
]:
uniref90_out_path
=
os
.
path
.
join
(
alignment_dir
,
'uniref90_hits.a3m'
)
with
open
(
uniref90_out_path
,
'r'
)
as
f
:
uniref90_msa
,
uniref90_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
()
)
mgnify_out_path
=
os
.
path
.
join
(
alignment_dir
,
'mgnify_hits.a3m'
)
with
open
(
mgnify_out_path
,
'r'
)
as
f
:
mgnify_msa
,
mgnify_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
()
)
else
:
hhblits_bfd_uniclust_result
=
self
.
hhblits_bfd_uniclust_runner
.
query
(
input_fasta_path
)
bfd_out_path
=
os
.
path
.
join
(
msa_output_dir
,
'bfd_uniclust_hits.a3m'
)
with
open
(
bfd_out_path
,
'w'
)
as
f
:
f
.
write
(
hhblits_bfd_uniclust_result
[
'a3m'
])
bfd_msa
,
bfd_deletion_matrix
=
parsers
.
parse_a3m
(
hhblits_bfd_uniclust_result
[
'a3m'
]
pdb70_out_path
=
os
.
path
.
join
(
alignment_dir
,
'pdb70_hits.hhr'
)
with
open
(
pdb70_out_path
,
'r'
)
as
f
:
hhsearch_hits
=
parsers
.
parse_hhr
(
f
.
read
()
)
if
(
self
.
use_small_bfd
):
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
'small_bfd_hits.sto'
)
with
open
(
bfd_out_path
,
'r'
)
as
f
:
bfd_msa
,
bfd_deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
f
.
read
()
)
else
:
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
'bfd_uniclust_hits.a3m'
)
with
open
(
bfd_out_path
,
'r'
)
as
f
:
bfd_msa
,
bfd_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
()
)
return
{
'uniref90_msa'
:
uniref90_msa
,
'uniref90_deletion_matrix'
:
uniref90_deletion_matrix
,
'mgnify_msa'
:
mgnify_msa
,
'mgnify_deletion_matrix'
:
mgnify_deletion_matrix
,
'hhsearch_hits'
:
hhsearch_hits
,
'bfd_msa'
:
bfd_msa
,
'bfd_deletion_matrix'
:
bfd_deletion_matrix
,
}
def
process_fasta
(
self
,
fasta_path
:
str
,
alignment_dir
:
str
,
)
->
FeatureDict
:
"""Assembles features for a single sequence in a FASTA file"""
with
open
(
fasta_path
)
as
f
:
fasta_str
=
f
.
read
()
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
fasta_str
)
if
len
(
input_seqs
)
!=
1
:
raise
ValueError
(
f
'More than one input sequence found in
{
fasta_path
}
.'
)
input_sequence
=
input_seqs
[
0
]
input_description
=
input_descs
[
0
]
num_res
=
len
(
input_sequence
)
alignments
=
self
.
_parse_alignment_output
(
alignment_dir
)
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
query_pdb_code
=
None
,
query_release_date
=
None
,
hits
=
hhsearch_hits
hits
=
alignments
[
'
hhsearch_hits
'
]
)
sequence_features
=
make_sequence_features
(
...
...
@@ -180,9 +275,62 @@ class DataPipeline:
)
msa_features
=
make_msa_features
(
msas
=
(
uniref90_msa
,
bfd_msa
,
mgnify_msa
),
deletion_matrices
=
(
uniref90_deletion_matrix
,
bfd_deletion_matrix
,
mgnify_deletion_matrix
)
msas
=
(
alignments
[
'uniref90_msa'
],
alignments
[
'bfd_msa'
],
alignments
[
'mgnify_msa'
]
),
deletion_matrices
=
(
alignments
[
'uniref90_deletion_matrix'
],
alignments
[
'bfd_deletion_matrix'
],
alignments
[
'mgnify_deletion_matrix'
]
)
)
return
{
**
sequence_features
,
**
msa_features
,
**
templates_result
.
features
}
def
process_mmcif
(
self
,
mmcif
:
mmcif_parsing
.
MmcifObject
,
# parsing is expensive, so no path
alignment_dir
:
str
,
chain_id
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
"""
Assembles features for a specific chain in an mmCIF object.
If chain_id is None, it is assumed that there is only one chain
in the object. Otherwise, a ValueError is thrown.
"""
if
(
chain_id
is
None
):
chains
=
mmcif
.
structure
.
get_chains
()
chain
=
next
(
chains
,
None
)
if
(
chain
is
None
):
raise
ValueError
(
'No chains in mmCIF file'
)
chain_id
=
chain
.
id
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
alignments
=
self
.
_parse_alignment_output
(
alignment_dir
)
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
query_pdb_code
=
None
,
query_release_date
=
to_date
(
mmcif
.
header
[
"release_date"
]),
hits
=
alignments
[
'hhsearch_hits'
]
)
msa_features
=
make_msa_features
(
msas
=
(
alignments
[
'uniref90_msa'
],
alignments
[
'bfd_msa'
],
alignments
[
'mgnify_msa'
]
),
deletion_matrices
=
(
alignments
[
'uniref90_deletion_matrix'
],
alignments
[
'bfd_deletion_matrix'
],
alignments
[
'mgnify_deletion_matrix'
]
)
)
return
{
**
mmcif_feats
,
**
templates_result
.
features
,
**
msa_features
}
openfold/features/data_transforms.py
View file @
d48c052c
...
...
@@ -6,8 +6,10 @@ import torch
from
operator
import
add
from
openfold.config
import
NUM_RES
,
NUM_EXTRA_SEQ
,
NUM_TEMPLATES
,
NUM_MSA_SEQ
from
openfold.np
import
residue_constants
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
from
openfold.np
import
residue_constants
as
rc
from
openfold.utils.affine_utils
import
T
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
,
batched_gather
MSA_FEATURE_NAMES
=
[
'msa'
,
'deletion_matrix'
,
'msa_mask'
,
'msa_row_mask'
,
'bert_mask'
,
'true_msa'
...
...
@@ -59,7 +61,7 @@ def fix_templates_aatype(protein):
num_templates
=
protein
[
'template_aatype'
].
shape
[
0
]
protein
[
'template_aatype'
]
=
torch
.
argmax
(
protein
[
'template_aatype'
],
dim
=-
1
)
# Map hhsearch-aatype to our aatype.
new_order_list
=
r
esidue_constants
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order_list
=
r
c
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order
=
torch
.
tensor
(
new_order_list
,
dtype
=
torch
.
int64
).
expand
(
num_templates
,
-
1
)
...
...
@@ -69,8 +71,8 @@ def fix_templates_aatype(protein):
return
protein
def
correct_msa_restypes
(
protein
):
"""Correct MSA restype to have the same order as r
esidue_constants
."""
new_order_list
=
r
esidue_constants
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
"""Correct MSA restype to have the same order as r
c
."""
new_order_list
=
r
c
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order
=
torch
.
tensor
(
[
new_order_list
]
*
protein
[
'msa'
].
shape
[
1
],
dtype
=
protein
[
'msa'
].
dtype
).
transpose
(
0
,
1
)
...
...
@@ -93,7 +95,7 @@ def squeeze_features(protein):
for
k
in
[
'domain_name'
,
'msa'
,
'num_alignments'
,
'seq_length'
,
'sequence'
,
'superfamily'
,
'deletion_matrix'
,
'resolution'
,
'between_segment_residues'
,
'residue_index'
,
'template_all_atom_mask
s
'
]:
'between_segment_residues'
,
'residue_index'
,
'template_all_atom_mask'
]:
if
k
in
protein
:
final_dim
=
protein
[
k
].
shape
[
-
1
]
if
isinstance
(
final_dim
,
int
)
and
final_dim
==
1
:
...
...
@@ -104,12 +106,6 @@ def squeeze_features(protein):
protein
[
k
]
=
protein
[
k
][
0
]
return
protein
def
make_protein_crop_to_size_seed
(
protein
):
protein
[
'random_crop_to_size_seed'
]
=
torch
.
distributions
.
Uniform
(
low
=
torch
.
int32
,
high
=
torch
.
int32
).
sample
((
2
)
)
return
protein
@
curry1
def
randomly_replace_msa_with_unknown
(
protein
,
replace_proportion
):
"""Replace a portion of the MSA with 'X'."""
...
...
@@ -284,19 +280,19 @@ def make_msa_mask(protein):
protein
[
'msa_row_mask'
]
=
torch
.
ones
(
protein
[
'msa'
].
shape
[
0
],
dtype
=
torch
.
float32
)
return
protein
def
pseudo_beta_fn
(
aatype
,
all_atom_positions
,
all_atom_mask
s
):
def
pseudo_beta_fn
(
aatype
,
all_atom_positions
,
all_atom_mask
):
"""Create pseudo beta features."""
is_gly
=
torch
.
eq
(
aatype
,
r
esidue_constants
.
restype_order
[
'G'
])
ca_idx
=
r
esidue_constants
.
atom_order
[
'CA'
]
cb_idx
=
r
esidue_constants
.
atom_order
[
'CB'
]
is_gly
=
torch
.
eq
(
aatype
,
r
c
.
restype_order
[
'G'
])
ca_idx
=
r
c
.
atom_order
[
'CA'
]
cb_idx
=
r
c
.
atom_order
[
'CB'
]
pseudo_beta
=
torch
.
where
(
torch
.
tile
(
is_gly
[...,
None
],
[
1
]
*
len
(
is_gly
.
shape
)
+
[
3
]),
all_atom_positions
[...,
ca_idx
,
:],
all_atom_positions
[...,
cb_idx
,
:])
if
all_atom_mask
s
is
not
None
:
if
all_atom_mask
is
not
None
:
pseudo_beta_mask
=
torch
.
where
(
is_gly
,
all_atom_mask
s
[...,
ca_idx
],
all_atom_mask
s
[...,
cb_idx
])
is_gly
,
all_atom_mask
[...,
ca_idx
],
all_atom_mask
[...,
cb_idx
])
return
pseudo_beta
,
pseudo_beta_mask
else
:
return
pseudo_beta
...
...
@@ -307,9 +303,9 @@ def make_pseudo_beta(protein, prefix=''):
assert
prefix
in
[
''
,
'template_'
]
protein
[
prefix
+
'pseudo_beta'
],
protein
[
prefix
+
'pseudo_beta_mask'
]
=
(
pseudo_beta_fn
(
protein
[
'template_aatype'
if
prefix
else
'
all_atom_
aatype'
],
protein
[
'template_aatype'
if
prefix
else
'aatype'
],
protein
[
prefix
+
'all_atom_positions'
],
protein
[
'template_all_atom_mask
s
'
if
prefix
else
'all_atom_mask'
]))
protein
[
'template_all_atom_mask'
if
prefix
else
'all_atom_mask'
]))
return
protein
@
curry1
...
...
@@ -456,10 +452,12 @@ def make_msa_feat(protein):
protein
[
'target_feat'
]
=
torch
.
cat
(
target_feat
,
dim
=-
1
)
return
protein
@
curry1
def
select_feat
(
protein
,
feature_list
):
return
{
k
:
v
for
k
,
v
in
protein
.
items
()
if
k
in
feature_list
}
@
curry1
def
crop_templates
(
protein
,
max_templates
):
for
k
,
v
in
protein
.
items
():
...
...
@@ -467,72 +465,74 @@ def crop_templates(protein, max_templates):
protein
[
k
]
=
v
[:
max_templates
]
return
protein
def
make_atom14_masks
(
protein
):
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37
=
[]
restype_atom37_to_atom14
=
[]
restype_atom14_mask
=
[]
for
rt
in
r
esidue_constants
.
restypes
:
atom_names
=
r
esidue_constants
.
restype_name_to_atom14_names
[
r
esidue_constants
.
restype_1to3
[
rt
]
for
rt
in
r
c
.
restypes
:
atom_names
=
r
c
.
restype_name_to_atom14_names
[
r
c
.
restype_1to3
[
rt
]
]
restype_atom14_to_atom37
.
append
([
(
r
esidue_constants
.
atom_order
[
name
]
if
name
else
0
)
(
r
c
.
atom_order
[
name
]
if
name
else
0
)
for
name
in
atom_names
])
atom_name_to_idx14
=
{
name
:
i
for
i
,
name
in
enumerate
(
atom_names
)}
restype_atom37_to_atom14
.
append
([
(
atom_name_to_idx14
[
name
]
if
name
in
atom_name_to_idx14
else
0
)
for
name
in
r
esidue_constants
.
atom_types
for
name
in
r
c
.
atom_types
])
# Since all 14 atoms are not present in every residue, use this mask to
# tell which atom is there in this residue
restype_atom14_mask
.
append
([(
1.
if
name
else
0.
)
for
name
in
atom_names
])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37
.
append
([
0
]
*
14
)
restype_atom37_to_atom14
.
append
([
0
]
*
37
)
restype_atom14_mask
.
append
([
0.
]
*
14
)
restype_atom14_to_atom37
=
torch
.
tensor
(
restype_atom14_to_atom37
,
dtype
=
torch
.
int32
restype_atom14_to_atom37
,
dtype
=
torch
.
int32
,
device
=
protein
[
'aatype'
].
device
,
)
restype_atom37_to_atom14
=
torch
.
tensor
(
restype_atom37_to_atom14
,
dtype
=
torch
.
int32
restype_atom37_to_atom14
,
dtype
=
torch
.
int32
,
device
=
protein
[
'aatype'
].
device
,
)
restype_atom14_mask
=
torch
.
tensor
(
restype_atom14_mask
,
dtype
=
torch
.
float32
restype_atom14_mask
,
dtype
=
torch
.
float32
,
device
=
protein
[
'aatype'
].
device
,
)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37
=
torch
.
index_select
(
restype_atom14_to_atom37
,
0
,
protein
[
'aatype'
]
)
residx_atom14_mask
=
torch
.
index_select
(
restype_atom14_mask
,
0
,
protein
[
'aatype'
]
)
residx_atom14_to_atom37
=
restype_atom14_to_atom37
[
protein
[
'aatype'
]]
residx_atom14_mask
=
restype_atom14_mask
[
protein
[
'aatype'
]]
protein
[
'atom14_atom_exists'
]
=
residx_atom14_mask
protein
[
'residx_atom14_to_atom37'
]
=
residx_atom14_to_atom37
.
long
()
# create the gather indices for mapping back
residx_atom37_to_atom14
=
torch
.
index_select
(
restype_atom37_to_atom14
,
0
,
protein
[
'aatype'
]
)
residx_atom37_to_atom14
=
restype_atom37_to_atom14
[
protein
[
'aatype'
]]
protein
[
'residx_atom37_to_atom14'
]
=
residx_atom37_to_atom14
.
long
()
# create the corresponding mask
restype_atom37_mask
=
torch
.
zeros
([
21
,
37
],
dtype
=
torch
.
float32
)
for
restype
,
restype_letter
in
enumerate
(
residue_constants
.
restypes
):
restype_name
=
residue_constants
.
restype_1to3
[
restype_letter
]
atom_names
=
residue_constants
.
residue_atoms
[
restype_name
]
restype_atom37_mask
=
torch
.
zeros
(
[
21
,
37
],
dtype
=
torch
.
float32
,
device
=
protein
[
'aatype'
].
device
)
for
restype
,
restype_letter
in
enumerate
(
rc
.
restypes
):
restype_name
=
rc
.
restype_1to3
[
restype_letter
]
atom_names
=
rc
.
residue_atoms
[
restype_name
]
for
atom_name
in
atom_names
:
atom_type
=
r
esidue_constants
.
atom_order
[
atom_name
]
atom_type
=
r
c
.
atom_order
[
atom_name
]
restype_atom37_mask
[
restype
,
atom_type
]
=
1
residx_atom37_mask
=
torch
.
index_select
(
restype_atom37_mask
,
0
,
protein
[
'aatype'
]
)
residx_atom37_mask
=
restype_atom37_mask
[
protein
[
'aatype'
]]
protein
[
'atom37_atom_exists'
]
=
residx_atom37_mask
return
protein
...
...
@@ -543,3 +543,546 @@ def make_atom14_masks_np(batch):
out
=
make_atom14_masks
(
batch
)
out
=
tensor_tree_map
(
lambda
t
:
np
.
array
(
t
),
out
)
return
out
def
make_atom14_positions
(
protein
):
"""Constructs denser atom positions (14 dimensions instead of 37)."""
residx_atom14_mask
=
protein
[
"atom14_atom_exists"
]
residx_atom14_to_atom37
=
protein
[
"residx_atom14_to_atom37"
]
# Create a mask for known ground truth positions.
residx_atom14_gt_mask
=
residx_atom14_mask
*
batched_gather
(
protein
[
"all_atom_mask"
],
residx_atom14_to_atom37
,
dim
=-
1
,
no_batch_dims
=
len
(
protein
[
"all_atom_mask"
].
shape
[:
-
1
])
)
# Gather the ground truth positions.
residx_atom14_gt_positions
=
residx_atom14_gt_mask
[...,
None
]
*
(
batched_gather
(
protein
[
"all_atom_positions"
],
residx_atom14_to_atom37
,
dim
=-
2
,
no_batch_dims
=
len
(
protein
[
"all_atom_positions"
].
shape
[:
-
2
])
)
)
protein
[
"atom14_atom_exists"
]
=
residx_atom14_mask
protein
[
"atom14_gt_exists"
]
=
residx_atom14_gt_mask
protein
[
"atom14_gt_positions"
]
=
residx_atom14_gt_positions
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative ground truth coordinates where the naming is swapped
restype_3
=
[
rc
.
restype_1to3
[
res
]
for
res
in
rc
.
restypes
]
restype_3
+=
[
"UNK"
]
# Matrices for renaming ambiguous atoms.
all_matrices
=
{
res
:
torch
.
eye
(
14
,
dtype
=
protein
[
"all_atom_mask"
].
dtype
,
device
=
protein
[
"all_atom_mask"
].
device
)
for
res
in
restype_3
}
for
resname
,
swap
in
rc
.
residue_atom_renaming_swaps
.
items
():
correspondences
=
torch
.
arange
(
14
,
device
=
protein
[
"all_atom_mask"
].
device
)
for
source_atom_swap
,
target_atom_swap
in
swap
.
items
():
source_index
=
rc
.
restype_name_to_atom14_names
[
resname
].
index
(
source_atom_swap
)
target_index
=
rc
.
restype_name_to_atom14_names
[
resname
].
index
(
target_atom_swap
)
correspondences
[
source_index
]
=
target_index
correspondences
[
target_index
]
=
source_index
renaming_matrix
=
protein
[
"all_atom_mask"
].
new_zeros
((
14
,
14
))
for
index
,
correspondence
in
enumerate
(
correspondences
):
renaming_matrix
[
index
,
correspondence
]
=
1.
all_matrices
[
resname
]
=
renaming_matrix
renaming_matrices
=
torch
.
stack
(
[
all_matrices
[
restype
]
for
restype
in
restype_3
]
)
# Pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14).
renaming_transform
=
renaming_matrices
[
protein
[
"aatype"
]]
# Apply it to the ground truth positions. shape (num_res, 14, 3).
alternative_gt_positions
=
torch
.
einsum
(
"...rac,...rab->...rbc"
,
residx_atom14_gt_positions
,
renaming_transform
)
protein
[
"atom14_alt_gt_positions"
]
=
alternative_gt_positions
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position).
alternative_gt_mask
=
torch
.
einsum
(
"...ra,...rab->...rb"
,
residx_atom14_gt_mask
,
renaming_transform
)
protein
[
"atom14_alt_gt_exists"
]
=
alternative_gt_mask
# Create an ambiguous atoms mask. shape: (21, 14).
restype_atom14_is_ambiguous
=
protein
[
"all_atom_mask"
].
new_zeros
((
21
,
14
))
for
resname
,
swap
in
rc
.
residue_atom_renaming_swaps
.
items
():
for
atom_name1
,
atom_name2
in
swap
.
items
():
restype
=
rc
.
restype_order
[
rc
.
restype_3to1
[
resname
]]
atom_idx1
=
rc
.
restype_name_to_atom14_names
[
resname
].
index
(
atom_name1
)
atom_idx2
=
rc
.
restype_name_to_atom14_names
[
resname
].
index
(
atom_name2
)
restype_atom14_is_ambiguous
[
restype
,
atom_idx1
]
=
1
restype_atom14_is_ambiguous
[
restype
,
atom_idx2
]
=
1
# From this create an ambiguous_mask for the given sequence.
protein
[
"atom14_atom_is_ambiguous"
]
=
(
restype_atom14_is_ambiguous
[
protein
[
"aatype"
]]
)
return
protein
def
atom37_to_frames
(
protein
):
aatype
=
protein
[
"aatype"
]
all_atom_positions
=
protein
[
"all_atom_positions"
]
all_atom_mask
=
protein
[
"all_atom_mask"
]
batch_dims
=
len
(
aatype
.
shape
[:
-
1
])
restype_rigidgroup_base_atom_names
=
np
.
full
([
21
,
8
,
3
],
''
,
dtype
=
object
)
restype_rigidgroup_base_atom_names
[:,
0
,
:]
=
[
'C'
,
'CA'
,
'N'
]
restype_rigidgroup_base_atom_names
[:,
3
,
:]
=
[
'CA'
,
'C'
,
'O'
]
for
restype
,
restype_letter
in
enumerate
(
rc
.
restypes
):
resname
=
rc
.
restype_1to3
[
restype_letter
]
for
chi_idx
in
range
(
4
):
if
(
rc
.
chi_angles_mask
[
restype
][
chi_idx
]):
names
=
rc
.
chi_angles_atoms
[
resname
][
chi_idx
]
restype_rigidgroup_base_atom_names
[
restype
,
chi_idx
+
4
,
:
]
=
names
[
1
:]
restype_rigidgroup_mask
=
all_atom_mask
.
new_zeros
(
(
*
aatype
.
shape
[:
-
1
],
21
,
8
),
)
restype_rigidgroup_mask
[...,
0
]
=
1
restype_rigidgroup_mask
[...,
3
]
=
1
restype_rigidgroup_mask
[...,
:
20
,
4
:]
=
(
all_atom_mask
.
new_tensor
(
rc
.
chi_angles_mask
)
)
lookuptable
=
rc
.
atom_order
.
copy
()
lookuptable
[
''
]
=
0
lookup
=
np
.
vectorize
(
lambda
x
:
lookuptable
[
x
])
restype_rigidgroup_base_atom37_idx
=
lookup
(
restype_rigidgroup_base_atom_names
,
)
restype_rigidgroup_base_atom37_idx
=
aatype
.
new_tensor
(
restype_rigidgroup_base_atom37_idx
,
)
restype_rigidgroup_base_atom37_idx
=
(
restype_rigidgroup_base_atom37_idx
.
view
(
*
((
1
,)
*
batch_dims
),
*
restype_rigidgroup_base_atom37_idx
.
shape
)
)
residx_rigidgroup_base_atom37_idx
=
batched_gather
(
restype_rigidgroup_base_atom37_idx
,
aatype
,
dim
=-
3
,
no_batch_dims
=
batch_dims
,
)
base_atom_pos
=
batched_gather
(
all_atom_positions
,
residx_rigidgroup_base_atom37_idx
,
dim
=-
2
,
no_batch_dims
=
len
(
all_atom_positions
.
shape
[:
-
2
]),
)
gt_frames
=
T
.
from_3_points
(
p_neg_x_axis
=
base_atom_pos
[...,
0
,
:],
origin
=
base_atom_pos
[...,
1
,
:],
p_xy_plane
=
base_atom_pos
[...,
2
,
:],
eps
=
1e-8
,
)
group_exists
=
batched_gather
(
restype_rigidgroup_mask
,
aatype
,
dim
=-
2
,
no_batch_dims
=
batch_dims
,
)
gt_atoms_exist
=
batched_gather
(
all_atom_mask
,
residx_rigidgroup_base_atom37_idx
,
dim
=-
1
,
no_batch_dims
=
len
(
all_atom_mask
.
shape
[:
-
1
])
)
gt_exists
=
torch
.
min
(
gt_atoms_exist
,
dim
=-
1
)[
0
]
*
group_exists
rots
=
torch
.
eye
(
3
,
dtype
=
all_atom_mask
.
dtype
,
device
=
aatype
.
device
)
rots
=
torch
.
tile
(
rots
,
(
*
((
1
,)
*
batch_dims
),
8
,
1
,
1
))
rots
[...,
0
,
0
,
0
]
=
-
1
rots
[...,
0
,
2
,
2
]
=
-
1
gt_frames
=
gt_frames
.
compose
(
T
(
rots
,
None
))
restype_rigidgroup_is_ambiguous
=
all_atom_mask
.
new_zeros
(
*
((
1
,)
*
batch_dims
),
21
,
8
)
restype_rigidgroup_rots
=
torch
.
eye
(
3
,
dtype
=
all_atom_mask
.
dtype
,
device
=
aatype
.
device
)
restype_rigidgroup_rots
=
torch
.
tile
(
restype_rigidgroup_rots
,
(
*
((
1
,)
*
batch_dims
),
21
,
8
,
1
,
1
),
)
for
resname
,
_
in
rc
.
residue_atom_renaming_swaps
.
items
():
restype
=
rc
.
restype_order
[
rc
.
restype_3to1
[
resname
]
]
chi_idx
=
int
(
sum
(
rc
.
chi_angles_mask
[
restype
])
-
1
)
restype_rigidgroup_is_ambiguous
[...,
restype
,
chi_idx
+
4
]
=
1
restype_rigidgroup_rots
[...,
restype
,
chi_idx
+
4
,
1
,
1
]
=
-
1
restype_rigidgroup_rots
[...,
restype
,
chi_idx
+
4
,
2
,
2
]
=
-
1
residx_rigidgroup_is_ambiguous
=
batched_gather
(
restype_rigidgroup_is_ambiguous
,
aatype
,
dim
=-
2
,
no_batch_dims
=
batch_dims
,
)
residx_rigidgroup_ambiguity_rot
=
batched_gather
(
restype_rigidgroup_rots
,
aatype
,
dim
=-
4
,
no_batch_dims
=
batch_dims
,
)
alt_gt_frames
=
gt_frames
.
compose
(
T
(
residx_rigidgroup_ambiguity_rot
,
None
))
gt_frames_tensor
=
gt_frames
.
to_4x4
()
alt_gt_frames_tensor
=
alt_gt_frames
.
to_4x4
()
protein
[
'rigidgroups_gt_frames'
]
=
gt_frames_tensor
protein
[
'rigidgroups_gt_exists'
]
=
gt_exists
protein
[
'rigidgroups_group_exists'
]
=
group_exists
protein
[
'rigidgroups_group_is_ambiguous'
]
=
residx_rigidgroup_is_ambiguous
protein
[
'rigidgroups_alt_gt_frames'
]
=
alt_gt_frames_tensor
return
protein
def
get_chi_atom_indices
():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in rc.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices
=
[]
for
residue_name
in
rc
.
restypes
:
residue_name
=
rc
.
restype_1to3
[
residue_name
]
residue_chi_angles
=
rc
.
chi_angles_atoms
[
residue_name
]
atom_indices
=
[]
for
chi_angle
in
residue_chi_angles
:
atom_indices
.
append
(
[
rc
.
atom_order
[
atom
]
for
atom
in
chi_angle
])
for
_
in
range
(
4
-
len
(
atom_indices
)):
atom_indices
.
append
([
0
,
0
,
0
,
0
])
# For chi angles not defined on the AA.
chi_atom_indices
.
append
(
atom_indices
)
chi_atom_indices
.
append
([[
0
,
0
,
0
,
0
]]
*
4
)
# For UNKNOWN residue.
return
chi_atom_indices
@
curry1
def
atom37_to_torsion_angles
(
protein
,
prefix
=
''
,
):
"""
Convert coordinates to torsion angles.
This function is extremely sensitive to floating point imprecisions
and should be run with double precision whenever possible.
Args:
Dict containing:
* (prefix)aatype:
[*, N_res] residue indices
* (prefix)all_atom_positions:
[*, N_res, 37, 3] atom positions (in atom37
format)
* (prefix)all_atom_mask:
[*, N_res, 37] atom position mask
Returns:
The same dictionary updated with the following features:
"(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
Torsion angles
"(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
Alternate torsion angles (accounting for 180-degree symmetry)
"(prefix)torsion_angles_mask" ([*, N_res, 7])
Torsion angles mask
"""
aatype
=
protein
[
prefix
+
"aatype"
]
all_atom_positions
=
protein
[
prefix
+
"all_atom_positions"
]
all_atom_mask
=
protein
[
prefix
+
"all_atom_mask"
]
aatype
=
torch
.
clamp
(
aatype
,
max
=
20
)
pad
=
all_atom_positions
.
new_zeros
(
[
*
all_atom_positions
.
shape
[:
-
3
],
1
,
37
,
3
]
)
prev_all_atom_positions
=
torch
.
cat
(
[
pad
,
all_atom_positions
[...,
:
-
1
,
:,
:]],
dim
=-
3
)
pad
=
all_atom_mask
.
new_zeros
([
*
all_atom_mask
.
shape
[:
-
2
],
1
,
37
])
prev_all_atom_mask
=
torch
.
cat
([
pad
,
all_atom_mask
[...,
:
-
1
,
:]],
dim
=-
2
)
pre_omega_atom_pos
=
torch
.
cat
(
[
prev_all_atom_positions
[...,
1
:
3
,
:],
all_atom_positions
[...,
:
2
,
:]
],
dim
=-
2
)
phi_atom_pos
=
torch
.
cat
(
[
prev_all_atom_positions
[...,
2
:
3
,
:],
all_atom_positions
[...,
:
3
,
:]
],
dim
=-
2
)
psi_atom_pos
=
torch
.
cat
(
[
all_atom_positions
[...,
:
3
,
:],
all_atom_positions
[...,
4
:
5
,
:]
],
dim
=-
2
)
pre_omega_mask
=
(
torch
.
prod
(
prev_all_atom_mask
[...,
1
:
3
],
dim
=-
1
)
*
torch
.
prod
(
all_atom_mask
[...,
:
2
],
dim
=-
1
)
)
phi_mask
=
(
prev_all_atom_mask
[...,
2
]
*
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
)
psi_mask
=
(
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
*
all_atom_mask
[...,
4
]
)
chi_atom_indices
=
torch
.
as_tensor
(
get_chi_atom_indices
(),
device
=
aatype
.
device
)
atom_indices
=
chi_atom_indices
[...,
aatype
,
:,
:]
chis_atom_pos
=
batched_gather
(
all_atom_positions
,
atom_indices
,
-
2
,
len
(
atom_indices
.
shape
[:
-
2
])
)
chi_angles_mask
=
list
(
rc
.
chi_angles_mask
)
chi_angles_mask
.
append
([
0.
,
0.
,
0.
,
0.
])
chi_angles_mask
=
all_atom_mask
.
new_tensor
(
chi_angles_mask
)
chis_mask
=
chi_angles_mask
[
aatype
,
:]
chi_angle_atoms_mask
=
batched_gather
(
all_atom_mask
,
atom_indices
,
dim
=-
1
,
no_batch_dims
=
len
(
atom_indices
.
shape
[:
-
2
])
)
chi_angle_atoms_mask
=
torch
.
prod
(
chi_angle_atoms_mask
,
dim
=-
1
,
dtype
=
chi_angle_atoms_mask
.
dtype
)
chis_mask
=
chis_mask
*
chi_angle_atoms_mask
torsions_atom_pos
=
torch
.
cat
(
[
pre_omega_atom_pos
[...,
None
,
:,
:],
phi_atom_pos
[...,
None
,
:,
:],
psi_atom_pos
[...,
None
,
:,
:],
chis_atom_pos
,
],
dim
=-
3
)
torsion_angles_mask
=
torch
.
cat
(
[
pre_omega_mask
[...,
None
],
phi_mask
[...,
None
],
psi_mask
[...,
None
],
chis_mask
,
],
dim
=-
1
)
torsion_frames
=
T
.
from_3_points
(
torsions_atom_pos
[...,
1
,
:],
torsions_atom_pos
[...,
2
,
:],
torsions_atom_pos
[...,
0
,
:],
eps
=
1e-8
,
)
fourth_atom_rel_pos
=
torsion_frames
.
invert
().
apply
(
torsions_atom_pos
[...,
3
,
:]
)
torsion_angles_sin_cos
=
torch
.
stack
(
[
fourth_atom_rel_pos
[...,
2
],
fourth_atom_rel_pos
[...,
1
]],
dim
=-
1
)
denom
=
torch
.
sqrt
(
torch
.
sum
(
torch
.
square
(
torsion_angles_sin_cos
),
dim
=-
1
,
dtype
=
torsion_angles_sin_cos
.
dtype
,
keepdims
=
True
)
+
1e-8
)
torsion_angles_sin_cos
=
torsion_angles_sin_cos
/
denom
torsion_angles_sin_cos
=
torsion_angles_sin_cos
*
all_atom_mask
.
new_tensor
(
[
1.
,
1.
,
-
1.
,
1.
,
1.
,
1.
,
1.
],
)[((
None
,)
*
len
(
torsion_angles_sin_cos
.
shape
[:
-
2
]))
+
(
slice
(
None
),
None
)]
chi_is_ambiguous
=
torsion_angles_sin_cos
.
new_tensor
(
rc
.
chi_pi_periodic
,
)[
aatype
,
...]
mirror_torsion_angles
=
torch
.
cat
(
[
all_atom_mask
.
new_ones
(
*
aatype
.
shape
,
3
),
1.
-
2.
*
chi_is_ambiguous
],
dim
=-
1
)
alt_torsion_angles_sin_cos
=
(
torsion_angles_sin_cos
*
mirror_torsion_angles
[...,
None
]
)
protein
[
prefix
+
"torsion_angles_sin_cos"
]
=
torsion_angles_sin_cos
protein
[
prefix
+
"alt_torsion_angles_sin_cos"
]
=
alt_torsion_angles_sin_cos
protein
[
prefix
+
"torsion_angles_mask"
]
=
torsion_angles_mask
return
protein
def
get_backbone_frames
(
protein
):
# TODO: Verify that this is correct
protein
[
"backbone_affine_tensor"
]
=
(
protein
[
"rigidgroups_gt_frames"
][...,
0
,
:,
:]
)
protein
[
"backbone_affine_mask"
]
=
(
protein
[
"rigidgroups_gt_exists"
][...,
0
]
)
return
protein
def
get_chi_angles
(
protein
):
dtype
=
protein
[
"all_atom_mask"
].
dtype
protein
[
"chi_angles_sin_cos"
]
=
(
protein
[
"torsion_angles_sin_cos"
][...,
3
:,
:]
).
to
(
dtype
)
protein
[
"chi_mask"
]
=
protein
[
"torsion_angles_mask"
][...,
3
:].
to
(
dtype
)
return
protein
@
curry1
def
random_crop_to_size
(
protein
,
crop_size
,
max_templates
,
shape_schema
,
subsample_templates
=
False
,
seed
=
None
,
batch_mode
=
'clamped'
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
seq_length
=
protein
[
'seq_length'
]
if
'template_mask'
in
protein
:
num_templates
=
protein
[
'template_mask'
].
shape
[
-
1
]
else
:
num_templates
=
protein
[
'aatype'
].
new_zeros
((
1
,))
num_res_crop_size
=
min
(
seq_length
,
crop_size
)
# We want each ensemble to be cropped the same way
g
=
torch
.
Generator
(
device
=
protein
[
'seq_length'
].
device
)
if
(
seed
is
not
None
):
g
.
manual_seed
(
seed
)
def
_randint
(
lower
,
upper
):
return
int
(
torch
.
randint
(
lower
,
upper
,
(
1
,),
device
=
protein
[
'seq_length'
].
device
,
generator
=
g
)[
0
])
if
subsample_templates
:
templates_crop_start
=
_randint
(
0
,
num_templates
+
1
)
templates_select_indices
=
torch
.
randperm
(
num_templates
,
device
=
protein
[
'seq_length'
].
device
,
generator
=
g
)
num_templates_crop_size
=
min
(
num_templates
-
templates_crop_start
,
max_templates
)
else
:
templates_crop_start
=
0
num_templates_crop_size
=
num_templates
n
=
seq_length
-
num_res_crop_size
if
(
batch_mode
==
'clamped'
):
right_anchor
=
n
+
1
elif
(
batch_mode
==
'unclamped'
):
x
=
_randint
(
0
,
n
)
right_anchor
=
n
-
x
+
1
else
:
raise
ValueError
(
"Invalid batch mode"
)
num_res_crop_start
=
_randint
(
0
,
right_anchor
)
for
k
,
v
in
protein
.
items
():
if
(
k
not
in
shape_schema
or
(
'template'
not
in
k
and
NUM_RES
not
in
shape_schema
[
k
])
):
continue
# randomly permute the templates before cropping them.
if
k
.
startswith
(
'template'
)
and
subsample_templates
:
v
=
v
[
templates_select_indices
]
slices
=
[]
for
i
,
(
dim_size
,
dim
)
in
enumerate
(
zip
(
shape_schema
[
k
],
v
.
shape
)):
is_num_res
=
(
dim_size
==
NUM_RES
)
if
i
==
0
and
k
.
startswith
(
'template'
):
crop_size
=
num_templates_crop_size
crop_start
=
templates_crop_start
else
:
crop_start
=
num_res_crop_start
if
is_num_res
else
0
crop_size
=
num_res_crop_size
if
is_num_res
else
dim
slices
.
append
(
slice
(
crop_start
,
crop_start
+
crop_size
))
protein
[
k
]
=
v
[
slices
]
protein
[
'seq_length'
]
=
(
protein
[
'seq_length'
].
new_tensor
(
num_res_crop_size
)
)
return
protein
openfold/features/feature_pipeline.py
View file @
d48c052c
...
...
@@ -25,39 +25,67 @@ def np_to_tensor_dict(
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
tensor_dict
=
{
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
np_example
.
items
()
if
k
in
features
}
tensor_dict
=
{
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
np_example
.
items
()
if
k
in
features
}
return
tensor_dict
def
make_data_config
(
config
:
ml_collections
.
ConfigDict
,
mode
:
str
,
num_res
:
int
,
)
->
Tuple
[
ml_collections
.
ConfigDict
,
List
[
str
]]:
cfg
=
copy
.
deepcopy
(
config
.
data
)
)
->
Tuple
[
ml_collections
.
ConfigDict
,
List
[
str
]]:
cfg
=
copy
.
deepcopy
(
config
)
mode_cfg
=
cfg
[
mode
]
with
cfg
.
unlocked
():
if
(
mode_cfg
.
crop_size
is
None
):
mode_cfg
.
crop_size
=
num_res
feature_names
=
cfg
.
common
.
unsupervised_features
if
cfg
.
common
.
use_templates
:
feature_names
+=
cfg
.
common
.
template_features
with
cfg
.
unlock
ed
(
):
cfg
.
eval
.
crop_size
=
num_
res
if
(
cfg
[
mode
].
supervis
ed
):
feature_names
+=
cfg
.
common
.
supervised_featu
res
return
cfg
,
feature_names
def
np_example_to_features
(
np_example
:
FeatureDict
,
config
:
ml_collections
.
ConfigDict
,
random_seed
:
int
=
0
):
def
np_example_to_features
(
np_example
:
FeatureDict
,
config
:
ml_collections
.
ConfigDict
,
mode
:
str
,
batch_mode
:
str
,
):
np_example
=
dict
(
np_example
)
num_res
=
int
(
np_example
[
'seq_length'
][
0
])
cfg
,
feature_names
=
make_data_config
(
config
,
num_res
=
num_res
)
cfg
,
feature_names
=
make_data_config
(
config
,
mode
=
mode
,
num_res
=
num_res
)
if
'deletion_matrix_int'
in
np_example
:
np_example
[
'deletion_matrix'
]
=
(
np_example
.
pop
(
'deletion_matrix_int'
).
astype
(
np
.
float32
))
np_example
.
pop
(
'deletion_matrix_int'
).
astype
(
np
.
float32
)
)
if
batch_mode
==
'clamped'
:
np_example
[
'use_clamped_fape'
]
=
(
np
.
array
(
1.
).
astype
(
np
.
float32
)
)
elif
batch_mode
==
'unclamped'
:
np_example
[
'use_clamped_fape'
]
=
(
np
.
array
(
0.
).
astype
(
np
.
float32
)
)
torch
.
manual_seed
(
random_seed
)
tensor_dict
=
np_to_tensor_dict
(
np_example
=
np_example
,
features
=
feature_names
)
features
=
input_pipeline
.
process_tensors_from_config
(
tensor_dict
,
cfg
)
np_example
=
np_example
,
features
=
feature_names
)
with
torch
.
no_grad
():
features
=
input_pipeline
.
process_tensors_from_config
(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
batch_mode
=
batch_mode
,
)
return
{
k
:
v
for
k
,
v
in
features
.
items
()}
...
...
@@ -70,10 +98,13 @@ class FeaturePipeline:
self
.
params
=
params
def
process_features
(
self
,
raw_features
:
FeatureDict
,
random_seed
:
int
)
->
FeatureDict
:
raw_features
:
FeatureDict
,
mode
:
str
=
'train'
,
batch_mode
:
str
=
'clamped'
,
)
->
FeatureDict
:
return
np_example_to_features
(
np_example
=
raw_features
,
config
=
self
.
config
,
random_seed
=
random_seed
)
\ No newline at end of file
mode
=
mode
,
batch_mode
=
batch_mode
,
)
openfold/features/input_pipeline.py
View file @
d48c052c
from
functools
import
partial
import
torch
from
openfold.features
import
data_transforms
def
nonensembled_transform_fns
(
data_confi
g
):
def
nonensembled_transform_fns
(
common_cfg
,
mode_cf
g
):
"""Input pipeline data transformers that are not ensembled."""
common_cfg
=
data_config
.
common
transforms
=
[
data_transforms
.
cast_to_64bit_ints
,
data_transforms
.
correct_msa_restypes
,
...
...
@@ -23,23 +22,36 @@ def nonensembled_transform_fns(data_config):
data_transforms
.
make_template_mask
,
data_transforms
.
make_pseudo_beta
(
'template_'
)
])
if
(
common_cfg
.
use_template_torsion_angles
):
transforms
.
extend
([
data_transforms
.
atom37_to_torsion_angles
(
'template_'
),
])
transforms
.
extend
([
data_transforms
.
make_atom14_masks
,
])
if
(
mode_cfg
.
supervised
):
transforms
.
extend
([
data_transforms
.
make_atom14_positions
,
data_transforms
.
atom37_to_frames
,
data_transforms
.
atom37_to_torsion_angles
(
''
),
data_transforms
.
make_pseudo_beta
(
''
),
data_transforms
.
get_backbone_frames
,
data_transforms
.
get_chi_angles
,
])
return
transforms
def
ensembled_transform_fns
(
data_config
):
def
ensembled_transform_fns
(
common_cfg
,
mode_cfg
,
batch_mode
):
"""Input pipeline data transformers that can be ensembled and averaged."""
common_cfg
=
data_config
.
common
eval_cfg
=
data_config
.
eval
transforms
=
[]
if
common_cfg
.
reduce_msa_clusters_by_max_templates
:
pad_msa_clusters
=
eval
_cfg
.
max_msa_clusters
-
eval
_cfg
.
max_templates
pad_msa_clusters
=
mode
_cfg
.
max_msa_clusters
-
mode
_cfg
.
max_templates
else
:
pad_msa_clusters
=
eval
_cfg
.
max_msa_clusters
pad_msa_clusters
=
mode
_cfg
.
max_msa_clusters
max_msa_clusters
=
pad_msa_clusters
max_extra_msa
=
common_cfg
.
max_extra_msa
...
...
@@ -53,8 +65,10 @@ def ensembled_transform_fns(data_config):
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms
.
append
(
data_transforms
.
make_masked_msa
(
common_cfg
.
masked_msa
,
eval_cfg
.
masked_msa_replace_fraction
)
data_transforms
.
make_masked_msa
(
common_cfg
.
masked_msa
,
mode_cfg
.
masked_msa_replace_fraction
)
)
if
common_cfg
.
msa_cluster_features
:
...
...
@@ -69,44 +83,55 @@ def ensembled_transform_fns(data_config):
transforms
.
append
(
data_transforms
.
make_msa_feat
())
crop_feats
=
dict
(
eval
_cfg
.
feat
)
crop_feats
=
dict
(
common
_cfg
.
feat
)
if
eval
_cfg
.
fixed_size
:
if
mode
_cfg
.
fixed_size
:
transforms
.
append
(
data_transforms
.
select_feat
(
list
(
crop_feats
)))
transforms
.
append
(
data_transforms
.
random_crop_to_size
(
mode_cfg
.
crop_size
,
mode_cfg
.
max_templates
,
crop_feats
,
mode_cfg
.
subsample_templates
,
batch_mode
=
batch_mode
,
seed
=
torch
.
Generator
().
seed
()
))
transforms
.
append
(
data_transforms
.
make_fixed_size
(
crop_feats
,
pad_msa_clusters
,
common_cfg
.
max_extra_msa
,
eval
_cfg
.
crop_size
,
eval
_cfg
.
max_templates
mode
_cfg
.
crop_size
,
mode
_cfg
.
max_templates
))
else
:
transforms
.
append
(
data_transforms
.
crop_templates
(
eval_cfg
.
max_templates
))
transforms
.
append
(
data_transforms
.
crop_templates
(
mode_cfg
.
max_templates
)
)
return
transforms
def
process_tensors_from_config
(
tensors
,
data_config
):
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
,
batch_mode
=
'clamped'
):
"""Based on the config, apply filters and transformations to the data."""
def
wrap_ensemble_fn
(
data
,
i
):
"""Function to be mapped over the ensemble dimension."""
d
=
data
.
copy
()
fns
=
ensembled_transform_fns
(
data_config
)
fns
=
ensembled_transform_fns
(
common_cfg
,
mode_cfg
,
batch_mode
)
fn
=
compose
(
fns
)
d
[
'ensemble_index'
]
=
i
return
fn
(
d
)
eval_cfg
=
data_config
.
eval
tensors
=
compose
(
nonensembled_transform_fns
(
data_confi
g
)
nonensembled_transform_fns
(
common_cfg
,
mode_cf
g
)
)(
tensors
)
tensors_0
=
wrap_ensemble_fn
(
tensors
,
0
)
num_ensemble
=
eval
_cfg
.
num_ensemble
if
data_config
.
common
.
resample_msa_in_recycling
:
num_ensemble
=
mode
_cfg
.
num_ensemble
if
common
_cfg
.
resample_msa_in_recycling
:
# Separate batch per ensembling & recycling step.
num_ensemble
*=
data_config
.
common
.
num_recycle
+
1
num_ensemble
*=
common
_cfg
.
num_recycle
+
1
if
isinstance
(
num_ensemble
,
torch
.
Tensor
)
or
num_ensemble
>
1
:
tensors
=
map_fn
(
lambda
x
:
wrap_ensemble_fn
(
tensors
,
x
),
...
...
@@ -116,16 +141,20 @@ def process_tensors_from_config(tensors, data_config):
return
tensors
@
data_transforms
.
curry1
def
compose
(
x
,
fs
):
for
f
in
fs
:
x
=
f
(
x
)
return
x
def
map_fn
(
fun
,
x
):
ensembles
=
[
fun
(
elem
)
for
elem
in
x
]
features
=
ensembles
[
0
].
keys
()
ensembled_dict
=
{}
for
feat
in
features
:
ensembled_dict
[
feat
]
=
torch
.
stack
([
dict_i
[
feat
]
for
dict_i
in
ensembles
])
ensembled_dict
[
feat
]
=
torch
.
stack
(
[
dict_i
[
feat
]
for
dict_i
in
ensembles
],
dim
=-
1
)
return
ensembled_dict
openfold/features/mmcif_parsing.py
View file @
d48c052c
"""Parses the mmCIF file format."""
import
collections
import
dataclasses
import
io
import
json
import
logging
import
os
from
typing
import
Any
,
Mapping
,
Optional
,
Sequence
,
Tuple
from
absl
import
logging
from
Bio
import
PDB
from
Bio.Data
import
SCOPData
import
numpy
as
np
import
openfold.np.residue_constants
as
residue_constants
# Type aliases:
ChainId
=
str
...
...
@@ -369,3 +374,73 @@ def _get_protein_chains(
def
_is_set
(
data
:
str
)
->
bool
:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return
data
not
in
(
'.'
,
'?'
)
def get_atom_coords(
    mmcif_object: MmcifObject, chain_id: str
) -> Tuple[np.ndarray, np.ndarray]:
    """Extract atom37 coordinates and an occupancy mask for one chain.

    Returns a [num_res, atom_type_num, 3] position array and a
    [num_res, atom_type_num] mask, both indexed by SEQRES position.
    Residues absent from the structure are left as all-zero rows.
    Selenium atoms of MSE residues are stored in the sulphur ('SD') slot.
    """
    # Locate the right chain
    matching = [
        c for c in mmcif_object.structure.get_chains() if c.id == chain_id
    ]
    if len(matching) != 1:
        raise MultipleChainsError(
            f'Expected exactly one chain in structure with id {chain_id}.'
        )
    chain = matching[0]

    # Extract the coordinates
    num_res = len(mmcif_object.chain_to_seqres[chain_id])
    all_atom_positions = np.zeros(
        [num_res, residue_constants.atom_type_num, 3], dtype=np.float32
    )
    all_atom_mask = np.zeros(
        [num_res, residue_constants.atom_type_num], dtype=np.float32
    )
    for res_index in range(num_res):
        pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
        mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
        res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
        if not res_at_position.is_missing:
            res = chain[
                (
                    res_at_position.hetflag,
                    res_at_position.position.residue_number,
                    res_at_position.position.insertion_code,
                )
            ]
            for atom in res.get_atoms():
                atom_name = atom.get_name()
                x, y, z = atom.get_coord()
                if atom_name in residue_constants.atom_order.keys():
                    atom_idx = residue_constants.atom_order[atom_name]
                    pos[atom_idx] = [x, y, z]
                    mask[atom_idx] = 1.0
                elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
                    # Put the coords of the selenium atom in the sulphur column
                    sd_idx = residue_constants.atom_order['SD']
                    pos[sd_idx] = [x, y, z]
                    mask[sd_idx] = 1.0
        all_atom_positions[res_index] = pos
        all_atom_mask[res_index] = mask

    return all_atom_positions, all_atom_mask
def generate_mmcif_cache(mmcif_dir: str, out_path: str):
    """Summarize every .cif file in ``mmcif_dir`` into a JSON cache.

    For each entry that parses successfully, records its release date and
    chain count under the file's id. Files that fail to parse are logged
    and skipped. The resulting mapping is written to ``out_path`` as JSON.
    """
    data = {}
    for f in os.listdir(mmcif_dir):
        if not f.endswith('.cif'):
            continue
        with open(os.path.join(mmcif_dir, f), 'r') as fp:
            mmcif_string = fp.read()
        file_id = os.path.splitext(f)[0]
        parse_result = parse(file_id=file_id, mmcif_string=mmcif_string)
        if parse_result.mmcif_object is None:
            logging.warning(f'Could not parse {f}. Skipping...')
            continue
        mmcif = parse_result.mmcif_object
        data[file_id] = {
            'release_date': mmcif.header["release_date"],
            'no_chains': len(list(mmcif.structure.get_chains())),
        }

    with open(out_path, 'w') as fp:
        fp.write(json.dumps(data))
openfold/features/np/hhsearch.py
View file @
d48c052c
...
...
@@ -18,6 +18,7 @@ class HHSearch:
*
,
binary_path
:
str
,
databases
:
Sequence
[
str
],
n_cpu
:
int
=
2
,
maxseq
:
int
=
1_000_000
):
"""Initializes the Python HHsearch wrapper.
...
...
@@ -26,6 +27,7 @@ class HHSearch:
databases: A sequence of HHsearch database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
n_cpu: The number of CPUs to use
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
...
...
@@ -34,6 +36,7 @@ class HHSearch:
"""
self
.
binary_path
=
binary_path
self
.
databases
=
databases
self
.
n_cpu
=
n_cpu
self
.
maxseq
=
maxseq
for
database_path
in
self
.
databases
:
...
...
@@ -56,7 +59,8 @@ class HHSearch:
cmd
=
[
self
.
binary_path
,
'-i'
,
input_path
,
'-o'
,
hhr_path
,
'-maxseq'
,
str
(
self
.
maxseq
)
'-maxseq'
,
str
(
self
.
maxseq
),
'-cpu'
,
str
(
self
.
n_cpu
),
]
+
db_cmd
logging
.
info
(
'Launching subprocess "%s"'
,
' '
.
join
(
cmd
))
...
...
openfold/features/np/jackhmmer.py
View file @
d48c052c
...
...
@@ -3,14 +3,12 @@
from
concurrent
import
futures
import
glob
import
logging
import
os
import
subprocess
from
typing
import
Any
,
Callable
,
Mapping
,
Optional
,
Sequence
from
urllib
import
request
from
absl
import
logging
from
openfold.features.np
import
utils
...
...
openfold/features/np/utils.py
View file @
d48c052c
"""Common utilities for data pipeline tools."""
import
contextlib
import
datetime
import
shutil
import
tempfile
import
time
...
...
@@ -25,3 +26,9 @@ def timing(msg: str):
yield
toc
=
time
.
time
()
logging
.
info
(
'Finished %s in %.3f seconds'
,
msg
,
toc
-
tic
)
def to_date(s: str):
    """Convert the 'YYYY-MM-DD' prefix of ``s`` into a ``datetime.datetime``.

    Only the first ten characters are read, so longer timestamps parse too.
    """
    year, month, day = int(s[:4]), int(s[5:7]), int(s[8:10])
    return datetime.datetime(year=year, month=month, day=day)
openfold/features/templates.py
View file @
d48c052c
...
...
@@ -2,16 +2,17 @@
import
dataclasses
import
datetime
import
glob
import
json
import
logging
import
os
import
re
from
typing
import
Any
,
Dict
,
Mapping
,
Optional
,
Sequence
,
Tuple
from
absl
import
logging
import
numpy
as
np
from
openfold.features
import
parsers
,
mmcif_parsing
from
openfold.features.np
import
kalign
from
openfold.features.np.utils
import
to_date
from
openfold.np
import
residue_constants
...
...
@@ -74,7 +75,7 @@ class LengthError(PrefilterError):
TEMPLATE_FEATURES
=
{
'template_aatype'
:
np
.
int64
,
'template_all_atom_mask
s
'
:
np
.
float32
,
'template_all_atom_mask'
:
np
.
float32
,
'template_all_atom_positions'
:
np
.
float32
,
'template_domain_names'
:
np
.
object
,
'template_sequence'
:
np
.
object
,
...
...
@@ -133,23 +134,40 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
return
result
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
    """Cache the release date of every .cif file in ``mmcif_dir`` as JSON.

    Writes a {file_id: release_date} mapping to ``out_path``. Files that
    fail to parse are logged and skipped.
    """
    dates = {}
    for f in os.listdir(mmcif_dir):
        if f.endswith('.cif'):
            path = os.path.join(mmcif_dir, f)
            with open(path, 'r') as fp:
                mmcif_string = fp.read()
            file_id = os.path.splitext(f)[0]
            mmcif = mmcif_parsing.parse(
                file_id=file_id, mmcif_string=mmcif_string
            )
            if mmcif.mmcif_object is None:
                logging.warning(f'Failed to parse {f}. Skipping...')
                continue
            mmcif = mmcif.mmcif_object
            dates[file_id] = mmcif.header['release_date']

    # BUG FIX: the output file was opened with mode 'r' (read-only), so the
    # subsequent fp.write() raised. It must be opened for writing, matching
    # the sibling generate_mmcif_cache.
    with open(out_path, 'w') as fp:
        fp.write(json.dumps(dates))
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
    """Parses release dates file, returns a mapping from PDBs to release dates."""
    with open(path, 'r') as fp:
        cache = json.load(fp)
    # Each cache entry is a per-PDB dict; only its "release_date" field is
    # converted and kept.
    return {
        pdb: to_date(v)
        for pdb, entry in cache.items()
        for k, v in entry.items()
        if k == "release_date"
    }
def
_assess_hhsearch_hit
(
hit
:
parsers
.
TemplateHit
,
...
...
@@ -419,42 +437,14 @@ def _get_atom_positions(
auth_chain_id
:
str
,
max_ca_ca_distance
:
float
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""Gets atom positions and mask from a list of Biopython Residues."""
num_res
=
len
(
mmcif_object
.
chain_to_seqres
[
auth_chain_id
])
relevant_chains
=
[
c
for
c
in
mmcif_object
.
structure
.
get_chains
()
if
c
.
id
==
auth_chain_id
]
if
len
(
relevant_chains
)
!=
1
:
raise
MultipleChainsError
(
f
'Expected exactly one chain in structure with id
{
auth_chain_id
}
.'
)
chain
=
relevant_chains
[
0
]
all_positions
=
np
.
zeros
([
num_res
,
residue_constants
.
atom_type_num
,
3
])
all_positions_mask
=
np
.
zeros
([
num_res
,
residue_constants
.
atom_type_num
],
dtype
=
np
.
int64
)
for
res_index
in
range
(
num_res
):
pos
=
np
.
zeros
([
residue_constants
.
atom_type_num
,
3
],
dtype
=
np
.
float32
)
mask
=
np
.
zeros
([
residue_constants
.
atom_type_num
],
dtype
=
np
.
float32
)
res_at_position
=
mmcif_object
.
seqres_to_structure
[
auth_chain_id
][
res_index
]
if
not
res_at_position
.
is_missing
:
res
=
chain
[(
res_at_position
.
hetflag
,
res_at_position
.
position
.
residue_number
,
res_at_position
.
position
.
insertion_code
)]
for
atom
in
res
.
get_atoms
():
atom_name
=
atom
.
get_name
()
x
,
y
,
z
=
atom
.
get_coord
()
if
atom_name
in
residue_constants
.
atom_order
.
keys
():
pos
[
residue_constants
.
atom_order
[
atom_name
]]
=
[
x
,
y
,
z
]
mask
[
residue_constants
.
atom_order
[
atom_name
]]
=
1.0
elif
atom_name
.
upper
()
==
'SE'
and
res
.
get_resname
()
==
'MSE'
:
# Put the coordinates of the selenium atom in the sulphur column.
pos
[
residue_constants
.
atom_order
[
'SD'
]]
=
[
x
,
y
,
z
]
mask
[
residue_constants
.
atom_order
[
'SD'
]]
=
1.0
all_positions
[
res_index
]
=
pos
all_positions_mask
[
res_index
]
=
mask
coords_with_mask
=
mmcif_parsing
.
get_atom_coords
(
mmcif_object
=
mmcif_object
,
chain_id
=
auth_chain_id
)
all_atom_positions
,
all_atom_mask
=
coords_with_mask
_check_residue_distances
(
all_positions
,
all_positions_mask
,
max_ca_ca_distance
)
return
all_positions
,
all_positions_mask
all_atom_positions
,
all_atom_mask
,
max_ca_ca_distance
)
return
all_atom_positions
,
all_atom_mask
def
_extract_template_features
(
...
...
@@ -579,7 +569,7 @@ def _extract_template_features(
return
(
{
'template_all_atom_positions'
:
np
.
array
(
templates_all_atom_positions
),
'template_all_atom_mask
s
'
:
np
.
array
(
templates_all_atom_masks
),
'template_all_atom_mask'
:
np
.
array
(
templates_all_atom_masks
),
'template_sequence'
:
output_templates_sequence
.
encode
(),
'template_aatype'
:
np
.
array
(
templates_aatype
),
'template_domain_names'
:
f
'
{
pdb_id
.
lower
()
}
_
{
chain_id
}
'
.
encode
(),
...
...
openfold/model/model.py
View file @
d48c052c
...
...
@@ -19,7 +19,6 @@ import torch.nn as nn
from
openfold.utils.feats
import
(
pseudo_beta_fn
,
atom37_to_torsion_angles
,
build_extra_msa_feat
,
build_template_angle_feat
,
build_template_pair_feat
,
...
...
@@ -115,21 +114,16 @@ class AlphaFold(nn.Module):
batch
,
)
# Build template angle feats
angle_feats
=
atom37_to_torsion_angles
(
single_template_feats
[
"template_aatype"
],
single_template_feats
[
"template_all_atom_positions"
],
#.float(),
single_template_feats
[
"template_all_atom_masks"
],
#.float(),
eps
=
self
.
config
.
template
.
eps
,
)
template_angle_feat
=
build_template_angle_feat
(
angle_feats
,
single_template_feats
[
"template_aatype"
],
)
single_template_embeds
=
{}
if
(
self
.
config
.
template
.
embed_angles
):
template_angle_feat
=
build_template_angle_feat
(
single_template_feats
,
)
# [*, S_t, N, C_m]
a
=
self
.
template_angle_embedder
(
template_angle_feat
)
# [*, S_t, N, C_m]
a
=
self
.
template_angle_embedder
(
template_angle_feat
)
single_template_embeds
[
"angle"
]
=
a
# [*, S_t, N, N, C_t]
t
=
build_template_pair_feat
(
...
...
@@ -145,11 +139,11 @@ class AlphaFold(nn.Module):
_mask_trans
=
self
.
config
.
_mask_trans
)
template_embeds
.
append
({
"angle"
:
a
,
"pair"
:
t
,
"torsion_mask"
:
angle_feats
[
"torsion_angles_mask"
]
single_template_embeds
.
update
({
"pair"
:
t
,
})
template_embeds
.
append
(
single_template_embeds
)
template_embeds
=
dict_multimap
(
partial
(
torch
.
cat
,
dim
=
templ_dim
),
...
...
@@ -164,11 +158,15 @@ class AlphaFold(nn.Module):
)
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
])
>
0
)
return
{
"template_angle_embedding"
:
template_embeds
[
"angle"
],
ret
=
{}
if
(
self
.
config
.
template
.
embed_angles
):
ret
[
"template_angle_embedding"
]
=
template_embeds
[
"angle"
]
ret
.
update
({
"template_pair_embedding"
:
t
,
"torsion_angles_mask"
:
template_embeds
[
"torsion_mask"
],
}
})
return
ret
def
iteration
(
self
,
feats
,
m_1_prev
,
z_prev
,
x_prev
):
# Primary output dictionary
...
...
@@ -197,7 +195,7 @@ class AlphaFold(nn.Module):
)
# Inject information from previous recycling iterations
if
(
self
.
config
.
n
o_
cycle
s
>
1
):
if
(
self
.
config
.
n
um_re
cycle
>
0
):
# Initialize the recycling embeddings, if needs be
if
(
None
in
[
m_1_prev
,
z_prev
,
x_prev
]):
# [*, N, C_m]
...
...
@@ -241,7 +239,7 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings
if
(
self
.
config
.
template
.
enabled
):
template_feats
=
{
k
:
v
for
k
,
v
in
feats
.
items
()
if
"template_"
in
k
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
}
template_embeds
=
self
.
embed_templates
(
template_feats
,
...
...
@@ -261,7 +259,7 @@ class AlphaFold(nn.Module):
)
# [*, S, N]
torsion_angles_mask
=
template_
embeds
[
"
torsion_angles_mask"
]
torsion_angles_mask
=
feats
[
"
template_torsion_angles_mask"
]
msa_mask
=
torch
.
cat
(
[
feats
[
"msa_mask"
],
torsion_angles_mask
[...,
2
]],
axis
=-
2
)
...
...
@@ -374,7 +372,8 @@ class AlphaFold(nn.Module):
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
"template_all_atom_positions"
([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
...
...
@@ -392,13 +391,13 @@ class AlphaFold(nn.Module):
self
.
_disable_activation_checkpointing
()
# Main recycling loop
for
cycle_no
in
range
(
self
.
config
.
n
o_
cycle
s
):
for
cycle_no
in
range
(
self
.
config
.
n
um_re
cycle
+
1
):
# Select the features for the current recycling cycle
fetch_cur_batch
=
lambda
t
:
t
[...,
cycle_no
]
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter
=
(
cycle_no
==
(
self
.
config
.
n
o_cycles
-
1
)
)
is_final_iter
=
(
cycle_no
==
self
.
config
.
n
um_recycle
)
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
# Sidestep AMP bug discussed in pytorch issue #65766
if
(
is_final_iter
):
...
...
openfold/utils/exponential_moving_average.py
View file @
d48c052c
...
...
@@ -29,14 +29,15 @@ class ExponentialMovingAverage:
self
.
decay
=
decay
def
_update_state_dict_
(
self
,
update
,
state_dict
):
for
k
,
v
in
update
.
items
():
stored
=
state_dict
[
k
]
if
(
not
isinstance
(
v
,
torch
.
Tensor
)):
self
.
_update_state_dict_
(
v
,
stored
)
else
:
diff
=
stored
-
v
diff
*=
(
1
-
self
.
decay
)
stored
-=
diff
with
torch
.
no_grad
():
for
k
,
v
in
update
.
items
():
stored
=
state_dict
[
k
]
if
(
not
isinstance
(
v
,
torch
.
Tensor
)):
self
.
_update_state_dict_
(
v
,
stored
)
else
:
diff
=
stored
-
v
diff
*=
(
1
-
self
.
decay
)
stored
-=
diff
def
update
(
self
,
model
:
torch
.
nn
.
Module
)
->
None
:
"""
...
...
openfold/utils/feats.py
View file @
d48c052c
...
...
@@ -49,32 +49,6 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
return
pseudo_beta
def get_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
      A nested list of shape [residue_types=21, chis=4, atoms=4]. The residue
      types are in the order specified in rc.restypes + unknown residue type
      at the end. For chi angles which are not defined on the residue, the
      position indices are by default set to 0.
    """
    chi_atom_indices = []
    for one_letter in rc.restypes:
        three_letter = rc.restype_1to3[one_letter]
        atom_indices = [
            [rc.atom_order[atom] for atom in chi_angle]
            for chi_angle in rc.chi_angles_atoms[three_letter]
        ]
        # Pad to four chi angles; undefined ones point at atom 0.
        atom_indices += [[0, 0, 0, 0]] * (4 - len(atom_indices))
        chi_atom_indices.append(atom_indices)
    chi_atom_indices.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.
    return chi_atom_indices
def
atom14_to_atom37
(
atom14
,
batch
):
atom37_data
=
batched_gather
(
atom14
,
...
...
@@ -88,320 +62,13 @@ def atom14_to_atom37(atom14, batch):
return
atom37_data
def
atom37_to_torsion_angles
(
aatype
:
torch
.
Tensor
,
all_atom_positions
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
eps
:
float
=
1e-8
,
**
kwargs
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""
Convert coordinates to torsion angles.
This function is extremely sensitive to floating point imprecisions
and should be run with double precision whenever possible.
Args:
aatype:
[*, N_res] residue indices
all_atom_positions:
[*, N_res, 37, 3] atom positions (in atom37
format)
all_atom_mask:
[*, N_res, 37] atom position mask
Returns:
Dictionary of the following features:
"torsion_angles_sin_cos" ([*, N_res, 7, 2])
Torsion angles
"alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
Alternate torsion angles (accounting for 180-degree symmetry)
"torsion_angles_mask" ([*, N_res, 7])
Torsion angles mask
"""
aatype
=
torch
.
clamp
(
aatype
,
max
=
20
)
pad
=
all_atom_positions
.
new_zeros
(
[
*
all_atom_positions
.
shape
[:
-
3
],
1
,
37
,
3
]
)
prev_all_atom_positions
=
torch
.
cat
(
[
pad
,
all_atom_positions
[...,
:
-
1
,
:,
:]],
dim
=-
3
)
pad
=
all_atom_mask
.
new_zeros
([
*
all_atom_mask
.
shape
[:
-
2
],
1
,
37
])
prev_all_atom_mask
=
torch
.
cat
([
pad
,
all_atom_mask
[...,
:
-
1
,
:]],
dim
=-
2
)
pre_omega_atom_pos
=
torch
.
cat
(
[
prev_all_atom_positions
[...,
1
:
3
,
:],
all_atom_positions
[...,
:
2
,
:]
],
dim
=-
2
)
phi_atom_pos
=
torch
.
cat
(
[
prev_all_atom_positions
[...,
2
:
3
,
:],
all_atom_positions
[...,
:
3
,
:]
],
dim
=-
2
)
psi_atom_pos
=
torch
.
cat
(
[
all_atom_positions
[...,
:
3
,
:],
all_atom_positions
[...,
4
:
5
,
:]
],
dim
=-
2
)
pre_omega_mask
=
(
torch
.
prod
(
prev_all_atom_mask
[...,
1
:
3
],
dim
=-
1
)
*
torch
.
prod
(
all_atom_mask
[...,
:
2
],
dim
=-
1
)
)
phi_mask
=
(
prev_all_atom_mask
[...,
2
]
*
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
)
psi_mask
=
(
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
*
all_atom_mask
[...,
4
]
)
chi_atom_indices
=
torch
.
as_tensor
(
get_chi_atom_indices
(),
device
=
aatype
.
device
)
atom_indices
=
chi_atom_indices
[...,
aatype
,
:,
:]
chis_atom_pos
=
batched_gather
(
all_atom_positions
,
atom_indices
,
-
2
,
len
(
atom_indices
.
shape
[:
-
2
])
)
chi_angles_mask
=
list
(
rc
.
chi_angles_mask
)
chi_angles_mask
.
append
([
0.
,
0.
,
0.
,
0.
])
chi_angles_mask
=
all_atom_mask
.
new_tensor
(
chi_angles_mask
)
chis_mask
=
chi_angles_mask
[
aatype
,
:]
chi_angle_atoms_mask
=
batched_gather
(
all_atom_mask
,
atom_indices
,
dim
=-
1
,
no_batch_dims
=
len
(
atom_indices
.
shape
[:
-
2
])
)
chi_angle_atoms_mask
=
torch
.
prod
(
chi_angle_atoms_mask
,
dim
=-
1
,
dtype
=
chi_angle_atoms_mask
.
dtype
)
chis_mask
=
chis_mask
*
chi_angle_atoms_mask
torsions_atom_pos
=
torch
.
cat
(
[
pre_omega_atom_pos
[...,
None
,
:,
:],
phi_atom_pos
[...,
None
,
:,
:],
psi_atom_pos
[...,
None
,
:,
:],
chis_atom_pos
,
],
dim
=-
3
)
torsion_angles_mask
=
torch
.
cat
(
[
pre_omega_mask
[...,
None
],
phi_mask
[...,
None
],
psi_mask
[...,
None
],
chis_mask
,
],
dim
=-
1
)
torsion_frames
=
T
.
from_3_points
(
torsions_atom_pos
[...,
1
,
:],
torsions_atom_pos
[...,
2
,
:],
torsions_atom_pos
[...,
0
,
:],
eps
=
eps
,
)
fourth_atom_rel_pos
=
torsion_frames
.
invert
().
apply
(
torsions_atom_pos
[...,
3
,
:]
)
torsion_angles_sin_cos
=
torch
.
stack
(
[
fourth_atom_rel_pos
[...,
2
],
fourth_atom_rel_pos
[...,
1
]],
dim
=-
1
)
denom
=
torch
.
sqrt
(
torch
.
sum
(
torch
.
square
(
torsion_angles_sin_cos
),
dim
=-
1
,
dtype
=
torsion_angles_sin_cos
.
dtype
,
keepdims
=
True
)
+
eps
)
torsion_angles_sin_cos
=
torsion_angles_sin_cos
/
denom
torsion_angles_sin_cos
=
torsion_angles_sin_cos
*
all_atom_mask
.
new_tensor
(
[
1.
,
1.
,
-
1.
,
1.
,
1.
,
1.
,
1.
],
)[((
None
,)
*
len
(
torsion_angles_sin_cos
.
shape
[:
-
2
]))
+
(
slice
(
None
),
None
)]
chi_is_ambiguous
=
torsion_angles_sin_cos
.
new_tensor
(
rc
.
chi_pi_periodic
,
)[
aatype
,
...]
mirror_torsion_angles
=
torch
.
cat
(
[
all_atom_mask
.
new_ones
(
*
aatype
.
shape
,
3
),
1.
-
2.
*
chi_is_ambiguous
],
dim
=-
1
)
def
build_template_angle_feat
(
template_feats
):
template_aatype
=
template_feats
[
"template_aatype"
]
torsion_angles_sin_cos
=
template_feats
[
"template_torsion_angles_sin_cos"
]
alt_torsion_angles_sin_cos
=
(
torsion_angles_sin_cos
*
mirror_torsion_angles
[...,
None
]
)
return
{
"torsion_angles_sin_cos"
:
torsion_angles_sin_cos
,
"alt_torsion_angles_sin_cos"
:
alt_torsion_angles_sin_cos
,
"torsion_angles_mask"
:
torsion_angles_mask
,
}
def
atom37_to_frames
(
aatype
:
torch
.
Tensor
,
all_atom_positions
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
eps
:
float
,
**
kwargs
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
batch_dims
=
len
(
aatype
.
shape
[:
-
1
])
restype_rigidgroup_base_atom_names
=
np
.
full
([
21
,
8
,
3
],
''
,
dtype
=
object
)
restype_rigidgroup_base_atom_names
[:,
0
,
:]
=
[
'C'
,
'CA'
,
'N'
]
restype_rigidgroup_base_atom_names
[:,
3
,
:]
=
[
'CA'
,
'C'
,
'O'
]
for
restype
,
restype_letter
in
enumerate
(
rc
.
restypes
):
resname
=
rc
.
restype_1to3
[
restype_letter
]
for
chi_idx
in
range
(
4
):
if
(
rc
.
chi_angles_mask
[
restype
][
chi_idx
]):
names
=
rc
.
chi_angles_atoms
[
resname
][
chi_idx
]
restype_rigidgroup_base_atom_names
[
restype
,
chi_idx
+
4
,
:
]
=
names
[
1
:]
restype_rigidgroup_mask
=
all_atom_mask
.
new_zeros
(
(
*
aatype
.
shape
[:
-
1
],
21
,
8
),
)
restype_rigidgroup_mask
[...,
0
]
=
1
restype_rigidgroup_mask
[...,
3
]
=
1
restype_rigidgroup_mask
[...,
:
20
,
4
:]
=
(
all_atom_mask
.
new_tensor
(
rc
.
chi_angles_mask
)
)
lookuptable
=
rc
.
atom_order
.
copy
()
lookuptable
[
''
]
=
0
lookup
=
np
.
vectorize
(
lambda
x
:
lookuptable
[
x
])
restype_rigidgroup_base_atom37_idx
=
lookup
(
restype_rigidgroup_base_atom_names
,
)
restype_rigidgroup_base_atom37_idx
=
aatype
.
new_tensor
(
restype_rigidgroup_base_atom37_idx
,
)
restype_rigidgroup_base_atom37_idx
=
(
restype_rigidgroup_base_atom37_idx
.
view
(
*
((
1
,)
*
batch_dims
),
*
restype_rigidgroup_base_atom37_idx
.
shape
)
)
residx_rigidgroup_base_atom37_idx
=
batched_gather
(
restype_rigidgroup_base_atom37_idx
,
aatype
,
dim
=-
3
,
no_batch_dims
=
batch_dims
,
template_feats
[
"template_alt_torsion_angles_sin_cos"
]
)
base_atom_pos
=
batched_gather
(
all_atom_positions
,
residx_rigidgroup_base_atom37_idx
,
dim
=-
2
,
no_batch_dims
=
len
(
all_atom_positions
.
shape
[:
-
2
]),
)
gt_frames
=
T
.
from_3_points
(
p_neg_x_axis
=
base_atom_pos
[...,
0
,
:],
origin
=
base_atom_pos
[...,
1
,
:],
p_xy_plane
=
base_atom_pos
[...,
2
,
:],
eps
=
eps
,
)
group_exists
=
batched_gather
(
restype_rigidgroup_mask
,
aatype
,
dim
=-
2
,
no_batch_dims
=
batch_dims
,
)
gt_atoms_exist
=
batched_gather
(
all_atom_mask
,
residx_rigidgroup_base_atom37_idx
,
dim
=-
1
,
no_batch_dims
=
len
(
all_atom_mask
.
shape
[:
-
1
])
)
gt_exists
=
torch
.
min
(
gt_atoms_exist
,
dim
=-
1
)[
0
]
*
group_exists
rots
=
torch
.
eye
(
3
,
dtype
=
all_atom_mask
.
dtype
,
device
=
aatype
.
device
)
rots
=
torch
.
tile
(
rots
,
(
*
((
1
,)
*
batch_dims
),
8
,
1
,
1
))
rots
[...,
0
,
0
,
0
]
=
-
1
rots
[...,
0
,
2
,
2
]
=
-
1
gt_frames
=
gt_frames
.
compose
(
T
(
rots
,
None
))
restype_rigidgroup_is_ambiguous
=
all_atom_mask
.
new_zeros
(
*
((
1
,)
*
batch_dims
),
21
,
8
)
restype_rigidgroup_rots
=
torch
.
eye
(
3
,
dtype
=
all_atom_mask
.
dtype
,
device
=
aatype
.
device
)
restype_rigidgroup_rots
=
torch
.
tile
(
restype_rigidgroup_rots
,
(
*
((
1
,)
*
batch_dims
),
21
,
8
,
1
,
1
),
)
for
resname
,
_
in
rc
.
residue_atom_renaming_swaps
.
items
():
restype
=
rc
.
restype_order
[
rc
.
restype_3to1
[
resname
]
]
chi_idx
=
int
(
sum
(
rc
.
chi_angles_mask
[
restype
])
-
1
)
restype_rigidgroup_is_ambiguous
[...,
restype
,
chi_idx
+
4
]
=
1
restype_rigidgroup_rots
[...,
restype
,
chi_idx
+
4
,
1
,
1
]
=
-
1
restype_rigidgroup_rots
[...,
restype
,
chi_idx
+
4
,
2
,
2
]
=
-
1
residx_rigidgroup_is_ambiguous
=
batched_gather
(
restype_rigidgroup_is_ambiguous
,
aatype
,
dim
=-
2
,
no_batch_dims
=
batch_dims
,
)
residx_rigidgroup_ambiguity_rot
=
batched_gather
(
restype_rigidgroup_rots
,
aatype
,
dim
=-
4
,
no_batch_dims
=
batch_dims
,
)
alt_gt_frames
=
gt_frames
.
compose
(
T
(
residx_rigidgroup_ambiguity_rot
,
None
))
gt_frames_tensor
=
gt_frames
.
to_4x4
()
alt_gt_frames_tensor
=
alt_gt_frames
.
to_4x4
()
return
{
'rigidgroups_gt_frames'
:
gt_frames_tensor
,
'rigidgroups_gt_exists'
:
gt_exists
,
'rigidgroups_group_exists'
:
group_exists
,
'rigidgroups_group_is_ambiguous'
:
residx_rigidgroup_is_ambiguous
,
'rigidgroups_alt_gt_frames'
:
alt_gt_frames_tensor
,
}
def
build_template_angle_feat
(
angle_feats
,
template_aatype
):
torsion_angles_sin_cos
=
angle_feats
[
"torsion_angles_sin_cos"
]
alt_torsion_angles_sin_cos
=
angle_feats
[
"alt_torsion_angles_sin_cos"
]
torsion_angles_mask
=
angle_feats
[
"torsion_angles_mask"
]
torsion_angles_mask
=
template_feats
[
"template_torsion_angles_mask"
]
template_angle_feat
=
torch
.
cat
(
[
nn
.
functional
.
one_hot
(
template_aatype
,
22
),
...
...
@@ -465,7 +132,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
eps
+
torch
.
sum
(
affine_vec
**
2
,
dim
=-
1
)
)
t_aa_masks
=
batch
[
"template_all_atom_mask
s
"
]
t_aa_masks
=
batch
[
"template_all_atom_mask"
]
template_mask
=
(
t_aa_masks
[...,
n
]
*
t_aa_masks
[...,
ca
]
*
t_aa_masks
[...,
c
]
)
...
...
@@ -534,53 +201,6 @@ def build_msa_feat(batch):
return
batch
def build_ambiguity_feats(
    batch: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """
    Compute features required by compute_renamed_ground_truth (Alg. 26)
    Args:
        batch:
            str/tensor dictionary containing:
                * atom14_gt_positions: [*, N, 14, 3] ground truth pos.
                * atom14_gt_exists: [*, N, 14] atom mask
                * aatype: [*, N] residue indices
    Returns:
        str/tensor dictionary containing:
            * atom14_atom_is_ambiguous: [*, N, 14] mask of ambiguous atoms
            * atom14_alt_gt_positions: [*, N, 14, 3] renamed positions
            * atom14_alt_gt_exists: [*, N, 14] renamed atom mask
    """
    # BUG FIX: the return annotation was "-> None" although the function
    # returns a feature dict (and callers do batch.update(...) on it).
    gt_positions = batch["atom14_gt_positions"]
    ambiguous_atoms = gt_positions.new_tensor(
        rc.restype_atom14_ambiguous_atoms
    )
    atom14_atom_is_ambiguous = ambiguous_atoms[batch["aatype"], ...]

    # Swap pairs of ambiguous positions
    swap_idx = rc.restype_atom14_ambiguous_atoms_swap_idx
    swap_mat = np.eye(swap_idx.shape[-1])[swap_idx]  # one-hot swap_idx
    swap_mat = gt_positions.new_tensor(swap_mat)
    swap_mat = swap_mat[batch["aatype"], ...]

    atom14_alt_gt_positions = torch.sum(
        gt_positions[..., None, :] * swap_mat[..., None], dim=-3
    )
    atom14_alt_gt_exists = torch.sum(
        batch["atom14_gt_exists"][..., None] * swap_mat, dim=-2
    )

    return {
        "atom14_atom_is_ambiguous": atom14_atom_is_ambiguous,
        "atom14_alt_gt_positions": atom14_alt_gt_positions,
        "atom14_alt_gt_exists": atom14_alt_gt_exists,
    }
def
torsion_angles_to_frames
(
t
:
T
,
alpha
:
torch
.
Tensor
,
...
...
openfold/utils/loss.py
View file @
d48c052c
...
...
@@ -18,6 +18,7 @@ import ml_collections
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
torch.distributions.bernoulli
import
Bernoulli
from
typing
import
Dict
,
Optional
,
Tuple
from
openfold.np
import
residue_constants
...
...
@@ -117,7 +118,9 @@ def compute_fape(
return
normed_error
# DISCREPANCY: figure out if loss clamping happens in 90% of each bach or in 90% of batches
# DISCREPANCY: From the way this function is written, it's possible that
# DeepMind clamped 90% of individual residue losses, not 90% of all batches.
# We defer to the text, which seems to imply the latter.
def
backbone_loss
(
backbone_affine_tensor
:
torch
.
Tensor
,
backbone_affine_mask
:
torch
.
Tensor
,
...
...
@@ -130,7 +133,7 @@ def backbone_loss(
)
->
torch
.
Tensor
:
pred_aff
=
T
.
from_tensor
(
traj
)
gt_aff
=
T
.
from_tensor
(
backbone_affine_tensor
)
fape_loss
=
compute_fape
(
pred_aff
,
gt_aff
[...,
None
,
:],
...
...
@@ -142,7 +145,6 @@ def backbone_loss(
length_scale
=
loss_unit_distance
,
eps
=
eps
,
)
if
(
use_clamped_fape
is
not
None
):
unclamped_fape_loss
=
compute_fape
(
pred_aff
,
...
...
@@ -157,12 +159,12 @@ def backbone_loss(
)
fape_loss
=
(
fape_loss
*
use_clamped_fape
+
fape_loss
*
use_clamped_fape
+
unclamped_fape_loss
*
(
1
-
use_clamped_fape
)
)
# Take the mean over the layer dimension
fape_loss
=
torch
.
mean
(
fape_loss
,
dim
=
0
)
fape_loss
=
torch
.
mean
(
fape_loss
,
dim
=
-
1
)
return
fape_loss
...
...
@@ -1453,7 +1455,7 @@ class AlphaFoldLoss(nn.Module):
super
(
AlphaFoldLoss
,
self
).
__init__
()
self
.
config
=
config
def
forward
(
self
,
out
,
batch
):
def
forward
(
self
,
out
,
batch
):
if
(
"violation"
not
in
out
.
keys
()
and
self
.
config
.
violation
.
weight
):
out
[
"violation"
]
=
find_structural_violations
(
batch
,
...
...
@@ -1461,41 +1463,12 @@ class AlphaFoldLoss(nn.Module):
**
self
.
config
.
violation
,
)
if
(
"atom14_atom_is_ambiguous"
not
in
batch
.
keys
()):
batch
.
update
(
feats
.
build_ambiguity_feats
(
batch
))
if
(
"renamed_atom14_gt_positions"
not
in
out
.
keys
()):
batch
.
update
(
compute_renamed_ground_truth
(
batch
,
out
[
"sm"
][
"positions"
][
-
1
],
))
if
(
"backbone_affine_tensor"
not
in
batch
.
keys
()):
batch
.
update
(
feats
.
atom37_to_frames
(
eps
=
self
.
config
.
eps
,
**
batch
))
# TODO: Verify that this is correct
batch
[
"backbone_affine_tensor"
]
=
(
batch
[
"rigidgroups_gt_frames"
][...,
0
,
:,
:]
)
batch
[
"backbone_affine_mask"
]
=
(
batch
[
"rigidgroups_gt_exists"
][...,
0
]
)
if
(
"chi_angles_sin_cos"
not
in
batch
.
keys
()):
with
torch
.
no_grad
():
batch
.
update
(
feats
.
atom37_to_torsion_angles
(
aatype
=
batch
[
"aatype"
],
all_atom_positions
=
batch
[
"all_atom_positions"
].
double
(),
all_atom_mask
=
batch
[
"all_atom_mask"
].
double
(),
eps
=
self
.
config
.
eps
,
))
# TODO: Verify that this is correct
batch
[
"chi_angles_sin_cos"
]
=
(
batch
[
"torsion_angles_sin_cos"
][...,
3
:,
:]
).
to
(
batch
[
"all_atom_mask"
].
dtype
)
batch
[
"chi_mask"
]
=
batch
[
"torsion_angles_mask"
][...,
3
:].
to
(
batch
[
"all_atom_mask"
].
dtype
)
loss_fns
=
{
"distogram"
:
lambda
:
distogram_loss
(
...
...
run_pretrained_openfold.py
View file @
d48c052c
...
...
@@ -15,17 +15,17 @@
import
argparse
from
datetime
import
date
import
pickle
import
logging
import
os
# A hack to get OpenMM and PyTorch to peacefully coexist
os
.
environ
[
"OPENMM_DEFAULT_PLATFORM"
]
=
"OpenCL"
import
pickle
import
random
import
sys
from
openfold.features
import
templates
,
feature_pipeline
from
openfold.features.np
import
data_pipeline
from
openfold.features
import
templates
,
feature_pipeline
,
data_pipeline
import
time
...
...
@@ -43,28 +43,29 @@ from openfold.utils.tensor_utils import (
tensor_tree_map
,
)
MAX_TEMPLATE_HITS
=
20
from
scripts.utils
import
add_data_args
def
main
(
args
):
config
=
model_config
(
args
.
model_name
)
model
=
AlphaFold
(
config
.
model
)
model
=
model
.
eval
()
import_jax_weights_
(
model
,
args
.
param_path
)
model
=
model
.
to
(
args
.
device
)
model
=
model
.
to
(
args
.
model_
device
)
# FEATURE COLLECTION AND PROCESSING
use_small_bfd
=
args
.
preset
==
"reduced_dbs"
num_ensemble
=
1
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
max_hits
=
MAX_TEMPLATE_HITS
,
max_hits
=
args
.
max_template_hits
,
kalign_binary_path
=
args
.
kalign_binary_path
,
release_dates_path
=
None
,
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
)
use_small_bfd
=
(
args
.
bfd_database_path
is
None
)
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
...
...
@@ -76,6 +77,7 @@ def main(args):
small_bfd_database_path
=
args
.
small_bfd_database_path
,
pdb70_database_path
=
args
.
pdb70_database_path
,
use_small_bfd
=
use_small_bfd
,
no_cpus
=
args
.
cpus
,
)
data_processor
=
data_pipeline
.
DataPipeline
(
...
...
@@ -87,7 +89,7 @@ def main(args):
random_seed
=
args
.
data_random_seed
if
random_seed
is
None
:
random_seed
=
random
.
randrange
(
sys
.
maxsize
)
config
.
data
.
eval
.
num_ensemble
=
num_ensemble
config
.
data
.
predict
.
num_ensemble
=
num_ensemble
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
)
if
not
os
.
path
.
exists
(
output_dir_base
):
os
.
makedirs
(
output_dir_base
)
...
...
@@ -95,7 +97,7 @@ def main(args):
if
not
os
.
path
.
exists
(
alignment_dir
):
os
.
makedirs
(
alignment_dir
)
print
(
"Generating features..."
)
logging
.
info
(
"Generating features..."
)
alignment_runner
.
run
(
args
.
fasta_path
,
alignment_dir
)
...
...
@@ -105,42 +107,20 @@ def main(args):
)
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
random_seed
feature_dict
,
mode
=
'predict'
,
)
for
k
,
v
in
processed_feature_dict
.
items
():
print
(
k
)
print
(
v
.
shape
)
print
(
"Executing model..."
)
logging
.
info
(
"Executing model..."
)
batch
=
processed_feature_dict
with
torch
.
no_grad
():
batch
=
{
k
:
torch
.
as_tensor
(
v
,
device
=
args
.
device
)
k
:
torch
.
as_tensor
(
v
,
device
=
args
.
model_
device
)
for
k
,
v
in
batch
.
items
()
}
longs
=
[
"aatype"
,
"template_aatype"
,
"extra_msa"
,
"residx_atom37_to_atom14"
,
"residx_atom14_to_atom37"
,
"true_msa"
,
"residue_index"
,
]
for
l
in
longs
:
batch
[
l
]
=
batch
[
l
].
long
()
# Move the recycling dimension to the end
move_dim
=
lambda
t
:
t
.
permute
(
*
range
(
len
(
t
.
shape
))[
1
:],
0
)
batch
=
tensor_tree_map
(
move_dim
,
batch
)
make_contig
=
lambda
t
:
t
.
contiguous
()
batch
=
tensor_tree_map
(
make_contig
,
batch
)
t
=
time
.
time
()
out
=
model
(
batch
)
print
(
f
"Inference time:
{
time
.
time
()
-
t
}
"
)
logging
.
info
(
f
"Inference time:
{
time
.
time
()
-
t
}
"
)
# Toss out the recycling dimensions --- we don't need them anymore
batch
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
batch
)
...
...
@@ -158,9 +138,7 @@ def main(args):
result
=
out
,
b_factors
=
plddt_b_factors
)
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"7"
amber_relaxer
=
relax
.
AmberRelaxation
(
**
config
.
relax
)
...
...
@@ -168,7 +146,7 @@ def main(args):
# Relax the prediction.
t
=
time
.
time
()
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
print
(
f
"Relaxation time:
{
time
.
time
()
-
t
}
"
)
logging
.
info
(
f
"Relaxation time:
{
time
.
time
()
-
t
}
"
)
# Save the relaxed PDB.
relaxed_output_path
=
os
.
path
.
join
(
...
...
@@ -183,53 +161,14 @@ if __name__ == "__main__":
parser
.
add_argument
(
"fasta_path"
,
type
=
str
,
)
parser
.
add_argument
(
'uniref90_database_path'
,
type
=
str
,
)
parser
.
add_argument
(
'mgnify_database_path'
,
type
=
str
,
)
parser
.
add_argument
(
'pdb70_database_path'
,
type
=
str
,
)
parser
.
add_argument
(
'template_mmcif_dir'
,
type
=
str
,
)
parser
.
add_argument
(
'--uniclust30_database_path'
,
type
=
str
,
)
parser
.
add_argument
(
'--bfd_database_path'
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
'--small_bfd_database_path'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--jackhmmer_binary_path'
,
type
=
str
,
default
=
'/usr/bin/jackhmmer'
)
parser
.
add_argument
(
'--hhblits_binary_path'
,
type
=
str
,
default
=
'/usr/bin/hhblits'
)
parser
.
add_argument
(
'--hhsearch_binary_path'
,
type
=
str
,
default
=
'/usr/bin/hhsearch'
)
parser
.
add_argument
(
'--kalign_binary_path'
,
type
=
str
,
default
=
'/usr/bin/kalign'
)
parser
.
add_argument
(
'--max_template_date'
,
type
=
str
,
default
=
date
.
today
().
strftime
(
"%Y-%m-%d"
),
)
parser
.
add_argument
(
'--obsolete_pdbs_path'
,
type
=
str
,
default
=
None
)
add_data_args
(
parser
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
os
.
getcwd
(),
help
=
"""Name of the directory in which to output the prediction"""
,
required
=
True
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"cpu"
,
"--
model_
device"
,
type
=
str
,
default
=
"cpu"
,
help
=
"""Name of the device on which to run the model. Any valid torch
device name is accepted (e.g. "cpu", "cuda:0")"""
)
...
...
@@ -244,6 +183,10 @@ if __name__ == "__main__":
automatically according to the model name from
openfold/resources/params"""
)
parser
.
add_argument
(
"--cpus"
,
type
=
int
,
default
=
4
,
help
=
"""Number of CPUs to use to run alignment tools"""
)
parser
.
add_argument
(
'--preset'
,
type
=
str
,
default
=
'full_dbs'
,
choices
=
(
'reduced_dbs'
,
'full_dbs'
)
...
...
scripts/build_deepspeed_config.py
View file @
d48c052c
...
...
@@ -15,6 +15,7 @@
import
argparse
import
json
parser
=
argparse
.
ArgumentParser
(
description
=
'''Outputs a DeepSpeed
configuration file to
stdout'''
)
...
...
scripts/precompute_alignments.py
0 → 100644
View file @
d48c052c
import
argparse
import
logging
import
os
import
tempfile
import
openfold.features.mmcif_parsing
as
mmcif_parsing
from
openfold.features.data_pipeline
import
AlignmentRunner
from
scripts.utils
import
add_data_args
def _parse_fasta(fasta_str):
    """Parse FASTA-formatted text into a list of sequences (stdlib-only).

    Description lines (starting with ">") delimit records; sequence lines
    belonging to one record are concatenated. Sequences are returned in
    file order. Lines appearing before any ">" header are ignored.
    """
    records = []
    for line in fasta_str.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith('>'):
            records.append([])
        elif records:
            records[-1].append(line)
    return [''.join(chunks) for chunks in records]


def main(args):
    """Precompute alignments for every mmCIF/FASTA file in args.input_dir.

    For each input file, one alignment directory per sequence is created
    under args.output_dir (named "<file_id>" for FASTA inputs and
    "<file_id>_<chain_id>" for mmCIF chains). Inputs whose alignment
    directory already exists are skipped.
    """
    # Build the alignment tool runner once; it is reused for every sequence.
    alignment_runner = AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        hhsearch_binary_path=args.hhsearch_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        small_bfd_database_path=args.small_bfd_database_path,
        pdb70_database_path=args.pdb70_database_path,
        use_small_bfd=args.bfd_database_path is None,
        no_cpus=args.cpus,
    )

    for f in os.listdir(args.input_dir):
        path = os.path.join(args.input_dir, f)
        is_mmcif = f.endswith('.cif')
        is_fasta = f.endswith('.fasta')
        file_id = os.path.splitext(f)[0]

        # Maps alignment-directory name -> sequence for this input file.
        seqs = {}
        if is_mmcif:
            with open(path, 'r') as fp:
                mmcif_str = fp.read()
            mmcif = mmcif_parsing.parse(
                file_id=file_id, mmcif_string=mmcif_str
            )
            if mmcif.mmcif_object is None:
                logging.warning(f'Failed to parse {f}...')
                if args.raise_errors:
                    raise list(mmcif.errors.values())[0]
                continue
            mmcif = mmcif.mmcif_object
            # One alignment directory per chain: "<file_id>_<chain_id>"
            for chain_id, seq in mmcif.chain_to_seqres.items():
                seqs['_'.join([file_id, chain_id])] = seq
        elif is_fasta:
            with open(path, 'r') as fp:
                fasta_str = fp.read()
            # BUG FIX: the original called parsers.parse_fasta() but never
            # imported "parsers", raising NameError on any FASTA input.
            # Parse with the stdlib-only helper above instead.
            input_seqs = _parse_fasta(fasta_str)
            if len(input_seqs) != 1:
                msg = f'More than one input_sequence found in {f}'
                if args.raise_errors:
                    raise ValueError(msg)
                logging.warning(msg)
            if not input_seqs:
                # Nothing usable in the file; skip it rather than crash.
                continue
            seqs[file_id] = input_seqs[0]
        else:
            # Not a recognized input format
            continue

        for name, seq in seqs.items():
            alignment_dir = os.path.join(args.output_dir, name)
            if os.path.isdir(alignment_dir):
                logging.info(f'{f} has already been processed. Skipping...')
                continue
            os.makedirs(alignment_dir)

            if is_fasta:
                # BUG FIX: pass the full path, not the bare filename "f",
                # so the tools find the file regardless of the CWD.
                alignment_runner.run(path, alignment_dir)
            else:
                # mmCIF chain: materialize the sequence as a temp FASTA file.
                fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
                try:
                    with os.fdopen(fd, 'w') as fp:
                        fp.write(f'>query\n{seq}')
                    alignment_runner.run(fasta_path, alignment_dir)
                finally:
                    # Clean up the temp file even if the tools fail.
                    os.remove(fasta_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_dir", type=str,
        help="Path to directory containing mmCIF and/or FASTA files"
    )
    parser.add_argument(
        "output_dir", type=str,
        help="Directory in which to output alignments"
    )
    add_data_args(parser)
    # BUG FIX: the original used type=bool, which treats ANY non-empty
    # string (including "False") as True. A store_true flag gives correct
    # boolean semantics while keeping the default of False.
    parser.add_argument(
        "--raise_errors", action="store_true", default=False,
        help="Whether to crash on parsing errors"
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="Number of CPUs to use"
    )

    args = parser.parse_args()

    main(args)
scripts/utils.py
0 → 100644
View file @
d48c052c
import
argparse
from
datetime
import
date
def add_data_args(parser: argparse.ArgumentParser):
    """Register the database/tool-path arguments shared by the alignment
    scripts on *parser*.

    The five required paths are positional (in a fixed order); the remaining
    databases, tool binaries, and template-search settings are optional
    flags with sensible defaults.
    """
    # Required database / directory paths (positional, fixed order).
    for required_path in (
        'uniref90_database_path',
        'mgnify_database_path',
        'pdb70_database_path',
        'template_mmcif_dir',
        'uniclust30_database_path',
    ):
        parser.add_argument(required_path, type=str)

    # Optional database paths (None means the database is not used).
    parser.add_argument('--bfd_database_path', type=str, default=None)
    parser.add_argument('--small_bfd_database_path', type=str, default=None)

    # Alignment-tool binaries, defaulting to standard install locations.
    tool_defaults = {
        'jackhmmer': '/usr/bin/jackhmmer',
        'hhblits': '/usr/bin/hhblits',
        'hhsearch': '/usr/bin/hhsearch',
        'kalign': '/usr/bin/kalign',
    }
    for tool, default_binary in tool_defaults.items():
        parser.add_argument(
            f'--{tool}_binary_path', type=str, default=default_binary
        )

    # Template-search settings.
    parser.add_argument(
        '--max_template_date',
        type=str,
        default=date.today().strftime("%Y-%m-%d"),
    )
    parser.add_argument('--max_template_hits', type=int, default=20)
    parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
tests/compare_utils.py
View file @
d48c052c
import
os
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"4,"
import
importlib
import
pkgutil
import
sys
...
...
tests/test_feats.py
View file @
d48c052c
...
...
@@ -16,6 +16,7 @@ import torch
import
numpy
as
np
import
unittest
import
openfold.features.data_transforms
as
data_transforms
from
openfold.np.residue_constants
import
(
restype_rigid_group_default_frame
,
restype_atom14_to_rigid_group
,
...
...
@@ -168,7 +169,7 @@ class TestFeats(unittest.TestCase):
to_tensor
=
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
)).
cuda
()
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
out_repro
=
feat
s
.
atom37_to_frames
(
eps
=
1e-8
,
**
batch
)
out_repro
=
data_transform
s
.
atom37_to_frames
(
batch
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
for
k
,
v
in
out_gt
.
items
():
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment