Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
d48c052c
"...models/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "2400fdf22fb3279f0a274d4678d13bf547315276"
Commit
d48c052c
authored
Oct 15, 2021
by
Gustaf Ahdritz
Browse files
Add training parsers
parent
eeda001c
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1676 additions
and
919 deletions
+1676
-919
openfold/config.py
openfold/config.py
+281
-233
openfold/features/data_pipeline.py
openfold/features/data_pipeline.py
+336
-0
openfold/features/data_transforms.py
openfold/features/data_transforms.py
+590
-47
openfold/features/feature_pipeline.py
openfold/features/feature_pipeline.py
+48
-17
openfold/features/input_pipeline.py
openfold/features/input_pipeline.py
+53
-24
openfold/features/mmcif_parsing.py
openfold/features/mmcif_parsing.py
+77
-2
openfold/features/np/hhsearch.py
openfold/features/np/hhsearch.py
+5
-1
openfold/features/np/jackhmmer.py
openfold/features/np/jackhmmer.py
+1
-3
openfold/features/np/utils.py
openfold/features/np/utils.py
+7
-0
openfold/features/templates.py
openfold/features/templates.py
+43
-53
openfold/model/model.py
openfold/model/model.py
+28
-29
openfold/utils/exponential_moving_average.py
openfold/utils/exponential_moving_average.py
+9
-8
openfold/utils/feats.py
openfold/utils/feats.py
+6
-386
openfold/utils/loss.py
openfold/utils/loss.py
+8
-35
run_pretrained_openfold.py
run_pretrained_openfold.py
+23
-80
scripts/build_deepspeed_config.py
scripts/build_deepspeed_config.py
+1
-0
scripts/precompute_alignments.py
scripts/precompute_alignments.py
+107
-0
scripts/utils.py
scripts/utils.py
+48
-0
tests/compare_utils.py
tests/compare_utils.py
+3
-0
tests/test_feats.py
tests/test_feats.py
+2
-1
No files found.
openfold/config.py
View file @
d48c052c
...
...
@@ -6,42 +6,42 @@ def set_inf(c, inf):
for
k
,
v
in
c
.
items
():
if
(
isinstance
(
v
,
mlc
.
ConfigDict
)):
set_inf
(
v
,
inf
)
elif
(
k
==
"
inf
"
):
elif
(
k
==
'
inf
'
):
c
[
k
]
=
inf
def
model_config
(
name
,
train
=
False
,
low_prec
=
False
):
c
=
copy
.
deepcopy
(
config
)
if
(
name
==
"
model_1
"
):
if
(
name
==
'
model_1
'
):
pass
elif
(
name
==
"
model_2
"
):
elif
(
name
==
'
model_2
'
):
pass
elif
(
name
==
"
model_3
"
):
elif
(
name
==
'
model_3
'
):
c
.
model
.
template
.
enabled
=
False
elif
(
name
==
"
model_4
"
):
elif
(
name
==
'
model_4
'
):
c
.
model
.
template
.
enabled
=
False
elif
(
name
==
"
model_5
"
):
elif
(
name
==
'
model_5
'
):
c
.
model
.
template
.
enabled
=
False
elif
(
name
==
"
model_1_ptm
"
):
elif
(
name
==
'
model_1_ptm
'
):
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
(
name
==
"
model_2_ptm
"
):
elif
(
name
==
'
model_2_ptm
'
):
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
(
name
==
"
model_3_ptm
"
):
elif
(
name
==
'
model_3_ptm
'
):
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
(
name
==
"
model_4_ptm
"
):
elif
(
name
==
'
model_4_ptm
'
):
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
(
name
==
"
model_5_ptm
"
):
elif
(
name
==
'
model_5_ptm
'
):
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
else
:
raise
ValueError
(
"
Invalid model name
"
)
raise
ValueError
(
'
Invalid model name
'
)
if
(
train
):
c
.
globals
.
blocks_per_ckpt
=
1
...
...
@@ -65,6 +65,9 @@ blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size
=
mlc
.
FieldReference
(
4
,
field_type
=
int
)
aux_distogram_bins
=
mlc
.
FieldReference
(
64
,
field_type
=
int
)
eps
=
mlc
.
FieldReference
(
1e-8
,
field_type
=
float
)
num_recycle
=
mlc
.
FieldReference
(
3
,
field_type
=
int
)
templates_enabled
=
mlc
.
FieldReference
(
True
,
field_type
=
bool
)
embed_template_torsion_angles
=
mlc
.
FieldReference
(
True
,
field_type
=
bool
)
NUM_RES
=
'num residues placeholder'
NUM_MSA_SEQ
=
'msa placeholder'
...
...
@@ -74,29 +77,7 @@ NUM_TEMPLATES = 'num templates placeholder'
config
=
mlc
.
ConfigDict
({
'data'
:
{
'common'
:
{
'masked_msa'
:
{
'profile_prob'
:
0.1
,
'same_prob'
:
0.1
,
'uniform_prob'
:
0.1
},
'max_extra_msa'
:
1024
,
'msa_cluster_features'
:
True
,
'num_recycle'
:
3
,
'reduce_msa_clusters_by_max_templates'
:
False
,
'resample_msa_in_recycling'
:
True
,
'template_features'
:
[
'template_all_atom_positions'
,
'template_sum_probs'
,
'template_aatype'
,
'template_all_atom_masks'
,
# 'template_domain_names'
],
'unsupervised_features'
:
[
'aatype'
,
'residue_index'
,
'msa'
,
# 'sequence', #'domain_name',
'num_alignments'
,
'seq_length'
,
'between_segment_residues'
,
'deletion_matrix'
],
'use_templates'
:
True
,
},
'eval'
:
{
'batch_modes'
:
[(
'clamped'
,
0.9
),
(
'unclamped'
,
0.1
)],
'feat'
:
{
'aatype'
:
[
NUM_RES
],
'all_atom_mask'
:
[
NUM_RES
,
None
],
...
...
@@ -110,7 +91,7 @@ config = mlc.ConfigDict({
'atom14_gt_positions'
:
[
NUM_RES
,
None
,
None
],
'atom37_atom_exists'
:
[
NUM_RES
,
None
],
'backbone_affine_mask'
:
[
NUM_RES
],
'backbone_affine_tensor'
:
[
NUM_RES
,
None
],
'backbone_affine_tensor'
:
[
NUM_RES
,
None
,
None
],
'bert_mask'
:
[
NUM_MSA_SEQ
,
NUM_RES
],
'chi_angles'
:
[
NUM_RES
,
None
],
'chi_mask'
:
[
NUM_RES
,
None
],
...
...
@@ -125,266 +106,333 @@ config = mlc.ConfigDict({
'msa_row_mask'
:
[
NUM_MSA_SEQ
],
'pseudo_beta'
:
[
NUM_RES
,
None
],
'pseudo_beta_mask'
:
[
NUM_RES
],
'random_crop_to_size_seed'
:
[
None
],
'residue_index'
:
[
NUM_RES
],
'residx_atom14_to_atom37'
:
[
NUM_RES
,
None
],
'residx_atom37_to_atom14'
:
[
NUM_RES
,
None
],
'resolution'
:
[],
'rigidgroups_alt_gt_frames'
:
[
NUM_RES
,
None
,
None
],
'rigidgroups_alt_gt_frames'
:
[
NUM_RES
,
None
,
None
,
None
],
'rigidgroups_group_exists'
:
[
NUM_RES
,
None
],
'rigidgroups_group_is_ambiguous'
:
[
NUM_RES
,
None
],
'rigidgroups_gt_exists'
:
[
NUM_RES
,
None
],
'rigidgroups_gt_frames'
:
[
NUM_RES
,
None
,
None
],
'rigidgroups_gt_frames'
:
[
NUM_RES
,
None
,
None
,
None
],
'seq_length'
:
[],
'seq_mask'
:
[
NUM_RES
],
'target_feat'
:
[
NUM_RES
,
None
],
'template_aatype'
:
[
NUM_TEMPLATES
,
NUM_RES
],
'template_all_atom_masks'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
'template_all_atom_positions'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
'template_all_atom_mask'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
'template_all_atom_positions'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
'template_alt_torsion_angles_sin_cos'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
'template_backbone_affine_mask'
:
[
NUM_TEMPLATES
,
NUM_RES
],
'template_backbone_affine_tensor'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
'template_mask'
:
[
NUM_TEMPLATES
],
'template_pseudo_beta'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
'template_pseudo_beta_mask'
:
[
NUM_TEMPLATES
,
NUM_RES
],
'template_sum_probs'
:
[
NUM_TEMPLATES
,
None
],
'true_msa'
:
[
NUM_MSA_SEQ
,
NUM_RES
]
'template_torsion_angles_mask'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
'template_torsion_angles_sin_cos'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
'true_msa'
:
[
NUM_MSA_SEQ
,
NUM_RES
],
'use_clamped_fape'
:
[],
},
'masked_msa'
:
{
'profile_prob'
:
0.1
,
'same_prob'
:
0.1
,
'uniform_prob'
:
0.1
},
'max_extra_msa'
:
1024
,
'msa_cluster_features'
:
True
,
'num_recycle'
:
num_recycle
,
'reduce_msa_clusters_by_max_templates'
:
False
,
'resample_msa_in_recycling'
:
True
,
'template_features'
:
[
'template_all_atom_positions'
,
'template_sum_probs'
,
'template_aatype'
,
'template_all_atom_mask'
,
],
'unsupervised_features'
:
[
'aatype'
,
'residue_index'
,
'msa'
,
'num_alignments'
,
'seq_length'
,
'between_segment_residues'
,
'deletion_matrix'
],
'use_templates'
:
templates_enabled
,
'use_template_torsion_angles'
:
embed_template_torsion_angles
,
'supervised_features'
:
[
'all_atom_mask'
,
'all_atom_positions'
,
'resolution'
,
'use_clamped_fape'
,
],
},
'predict'
:
{
'fixed_size'
:
True
,
'subsample_templates'
:
False
,
# We want top templates.
'masked_msa_replace_fraction'
:
0.15
,
'max_msa_clusters'
:
512
,
'max_templates'
:
4
,
'num_ensemble'
:
1
,
'crop'
:
False
,
'crop_size'
:
None
,
'supervised'
:
False
,
},
'eval'
:
{
'fixed_size'
:
True
,
'subsample_templates'
:
False
,
# We want top templates.
'masked_msa_replace_fraction'
:
0.15
,
'max_msa_clusters'
:
512
,
'max_templates'
:
4
,
'num_ensemble'
:
1
,
'crop'
:
False
,
'crop_size'
:
None
,
'supervised'
:
True
,
},
'train'
:
{
'fixed_size'
:
True
,
'subsample_templates'
:
True
,
'masked_msa_replace_fraction'
:
0.15
,
'max_msa_clusters'
:
512
,
'max_templates'
:
4
,
'num_ensemble'
:
1
,
'crop'
:
True
,
'crop_size'
:
256
,
'supervised'
:
True
,
},
'data_module'
:
{
'use_small_bfd'
:
False
,
'data_loaders'
:
{
'batch_size'
:
1
,
'num_workers'
:
1
,
},
}
},
# Recurring FieldReferences that can be changed globally here
"
globals
"
:
{
"
blocks_per_ckpt
"
:
blocks_per_ckpt
,
"
chunk_size
"
:
chunk_size
,
"
c_z
"
:
c_z
,
"
c_m
"
:
c_m
,
"
c_t
"
:
c_t
,
"
c_e
"
:
c_e
,
"
c_s
"
:
c_s
,
"
eps
"
:
eps
,
'
globals
'
:
{
'
blocks_per_ckpt
'
:
blocks_per_ckpt
,
'
chunk_size
'
:
chunk_size
,
'
c_z
'
:
c_z
,
'
c_m
'
:
c_m
,
'
c_t
'
:
c_t
,
'
c_e
'
:
c_e
,
'
c_s
'
:
c_s
,
'
eps
'
:
eps
,
},
"
model
"
:
{
"no_cycles"
:
4
,
"
_mask_trans
"
:
False
,
"
input_embedder
"
:
{
"
tf_dim
"
:
22
,
"
msa_dim
"
:
49
,
"
c_z
"
:
c_z
,
"
c_m
"
:
c_m
,
"
relpos_k
"
:
32
,
'
model
'
:
{
'num_recycle'
:
num_recycle
,
'
_mask_trans
'
:
False
,
'
input_embedder
'
:
{
'
tf_dim
'
:
22
,
'
msa_dim
'
:
49
,
'
c_z
'
:
c_z
,
'
c_m
'
:
c_m
,
'
relpos_k
'
:
32
,
},
"
recycling_embedder
"
:
{
"
c_z
"
:
c_z
,
"
c_m
"
:
c_m
,
"
min_bin
"
:
3.25
,
"
max_bin
"
:
20.75
,
"
no_bins
"
:
15
,
"
inf
"
:
1e8
,
'
recycling_embedder
'
:
{
'
c_z
'
:
c_z
,
'
c_m
'
:
c_m
,
'
min_bin
'
:
3.25
,
'
max_bin
'
:
20.75
,
'
no_bins
'
:
15
,
'
inf
'
:
1e8
,
},
"
template
"
:
{
"
distogram
"
:
{
"
min_bin
"
:
3.25
,
"
max_bin
"
:
50.75
,
"
no_bins
"
:
39
,
'
template
'
:
{
'
distogram
'
:
{
'
min_bin
'
:
3.25
,
'
max_bin
'
:
50.75
,
'
no_bins
'
:
39
,
},
"
template_angle_embedder
"
:
{
'
template_angle_embedder
'
:
{
# DISCREPANCY: c_in is supposed to be 51.
"
c_in
"
:
57
,
"
c_out
"
:
c_m
,
'
c_in
'
:
57
,
'
c_out
'
:
c_m
,
},
"
template_pair_embedder
"
:
{
"
c_in
"
:
88
,
"
c_out
"
:
c_t
,
'
template_pair_embedder
'
:
{
'
c_in
'
:
88
,
'
c_out
'
:
c_t
,
},
"
template_pair_stack
"
:
{
"
c_t
"
:
c_t
,
'
template_pair_stack
'
:
{
'
c_t
'
:
c_t
,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"
c_hidden_tri_att
"
:
16
,
"
c_hidden_tri_mul
"
:
64
,
"
no_blocks
"
:
2
,
"
no_heads
"
:
4
,
"
pair_transition_n
"
:
2
,
"
dropout_rate
"
:
0.25
,
"
blocks_per_ckpt
"
:
blocks_per_ckpt
,
"
chunk_size
"
:
chunk_size
,
"
inf
"
:
1e5
,
#1e9,
'
c_hidden_tri_att
'
:
16
,
'
c_hidden_tri_mul
'
:
64
,
'
no_blocks
'
:
2
,
'
no_heads
'
:
4
,
'
pair_transition_n
'
:
2
,
'
dropout_rate
'
:
0.25
,
'
blocks_per_ckpt
'
:
blocks_per_ckpt
,
'
chunk_size
'
:
chunk_size
,
'
inf
'
:
1e5
,
#1e9,
},
"
template_pointwise_attention
"
:
{
"
c_t
"
:
c_t
,
"
c_z
"
:
c_z
,
'
template_pointwise_attention
'
:
{
'
c_t
'
:
c_t
,
'
c_z
'
:
c_z
,
# DISCREPANCY: c_hidden here is given in the supplement as 64.
# It's actually 16.
"
c_hidden
"
:
16
,
"
no_heads
"
:
4
,
"
chunk_size
"
:
chunk_size
,
"
inf
"
:
1e5
,
#1e9,
'
c_hidden
'
:
16
,
'
no_heads
'
:
4
,
'
chunk_size
'
:
chunk_size
,
'
inf
'
:
1e5
,
#1e9,
},
"
inf
"
:
1e5
,
#1e9,
"
eps
"
:
eps
,
#1e-6,
"
enabled
"
:
True
,
"
embed_angles
"
:
True
,
'
inf
'
:
1e5
,
#1e9,
'
eps
'
:
eps
,
#1e-6,
'
enabled
'
:
templates_enabled
,
'
embed_angles
'
:
embed_template_torsion_angles
,
},
"
extra_msa
"
:
{
"
extra_msa_embedder
"
:
{
"
c_in
"
:
25
,
"
c_out
"
:
c_e
,
'
extra_msa
'
:
{
'
extra_msa_embedder
'
:
{
'
c_in
'
:
25
,
'
c_out
'
:
c_e
,
},
"
extra_msa_stack
"
:
{
"
c_m
"
:
c_e
,
"
c_z
"
:
c_z
,
"
c_hidden_msa_att
"
:
8
,
"
c_hidden_opm
"
:
32
,
"
c_hidden_mul
"
:
128
,
"
c_hidden_pair_att
"
:
32
,
"
no_heads_msa
"
:
8
,
"
no_heads_pair
"
:
4
,
"
no_blocks
"
:
4
,
"
transition_n
"
:
4
,
"
msa_dropout
"
:
0.15
,
"
pair_dropout
"
:
0.25
,
"
blocks_per_ckpt
"
:
blocks_per_ckpt
,
"
chunk_size
"
:
chunk_size
,
"
inf
"
:
1e5
,
#1e9,
"
eps
"
:
eps
,
#1e-10,
'
extra_msa_stack
'
:
{
'
c_m
'
:
c_e
,
'
c_z
'
:
c_z
,
'
c_hidden_msa_att
'
:
8
,
'
c_hidden_opm
'
:
32
,
'
c_hidden_mul
'
:
128
,
'
c_hidden_pair_att
'
:
32
,
'
no_heads_msa
'
:
8
,
'
no_heads_pair
'
:
4
,
'
no_blocks
'
:
4
,
'
transition_n
'
:
4
,
'
msa_dropout
'
:
0.15
,
'
pair_dropout
'
:
0.25
,
'
blocks_per_ckpt
'
:
blocks_per_ckpt
,
'
chunk_size
'
:
chunk_size
,
'
inf
'
:
1e5
,
#1e9,
'
eps
'
:
eps
,
#1e-10,
},
"
enabled
"
:
True
,
'
enabled
'
:
True
,
},
"
evoformer_stack
"
:
{
"
c_m
"
:
c_m
,
"
c_z
"
:
c_z
,
"
c_hidden_msa_att
"
:
32
,
"
c_hidden_opm
"
:
32
,
"
c_hidden_mul
"
:
128
,
"
c_hidden_pair_att
"
:
32
,
"
c_s
"
:
c_s
,
"
no_heads_msa
"
:
8
,
"
no_heads_pair
"
:
4
,
"
no_blocks
"
:
48
,
"
transition_n
"
:
4
,
"
msa_dropout
"
:
0.15
,
"
pair_dropout
"
:
0.25
,
"
blocks_per_ckpt
"
:
blocks_per_ckpt
,
"
chunk_size
"
:
chunk_size
,
"
inf
"
:
1e5
,
#1e9,
"
eps
"
:
eps
,
#1e-10,
'
evoformer_stack
'
:
{
'
c_m
'
:
c_m
,
'
c_z
'
:
c_z
,
'
c_hidden_msa_att
'
:
32
,
'
c_hidden_opm
'
:
32
,
'
c_hidden_mul
'
:
128
,
'
c_hidden_pair_att
'
:
32
,
'
c_s
'
:
c_s
,
'
no_heads_msa
'
:
8
,
'
no_heads_pair
'
:
4
,
'
no_blocks
'
:
48
,
'
transition_n
'
:
4
,
'
msa_dropout
'
:
0.15
,
'
pair_dropout
'
:
0.25
,
'
blocks_per_ckpt
'
:
blocks_per_ckpt
,
'
chunk_size
'
:
chunk_size
,
'
inf
'
:
1e5
,
#1e9,
'
eps
'
:
eps
,
#1e-10,
},
"
structure_module
"
:
{
"
c_s
"
:
c_s
,
"
c_z
"
:
c_z
,
"
c_ipa
"
:
16
,
"
c_resnet
"
:
128
,
"
no_heads_ipa
"
:
12
,
"
no_qk_points
"
:
4
,
"
no_v_points
"
:
8
,
"
dropout_rate
"
:
0.1
,
"
no_blocks
"
:
8
,
"
no_transition_layers
"
:
1
,
"
no_resnet_blocks
"
:
2
,
"
no_angles
"
:
7
,
"
trans_scale_factor
"
:
10
,
"
epsilon
"
:
eps
,
#1e-12,
"
inf
"
:
1e5
,
'
structure_module
'
:
{
'
c_s
'
:
c_s
,
'
c_z
'
:
c_z
,
'
c_ipa
'
:
16
,
'
c_resnet
'
:
128
,
'
no_heads_ipa
'
:
12
,
'
no_qk_points
'
:
4
,
'
no_v_points
'
:
8
,
'
dropout_rate
'
:
0.1
,
'
no_blocks
'
:
8
,
'
no_transition_layers
'
:
1
,
'
no_resnet_blocks
'
:
2
,
'
no_angles
'
:
7
,
'
trans_scale_factor
'
:
10
,
'
epsilon
'
:
eps
,
#1e-12,
'
inf
'
:
1e5
,
},
"
heads
"
:
{
"
lddt
"
:
{
"
no_bins
"
:
50
,
"
c_in
"
:
c_s
,
"
c_hidden
"
:
128
,
'
heads
'
:
{
'
lddt
'
:
{
'
no_bins
'
:
50
,
'
c_in
'
:
c_s
,
'
c_hidden
'
:
128
,
},
"
distogram
"
:
{
"
c_z
"
:
c_z
,
"
no_bins
"
:
aux_distogram_bins
,
'
distogram
'
:
{
'
c_z
'
:
c_z
,
'
no_bins
'
:
aux_distogram_bins
,
},
"
tm
"
:
{
"
c_z
"
:
c_z
,
"
no_bins
"
:
aux_distogram_bins
,
"
enabled
"
:
False
,
'
tm
'
:
{
'
c_z
'
:
c_z
,
'
no_bins
'
:
aux_distogram_bins
,
'
enabled
'
:
False
,
},
"
masked_msa
"
:
{
"
c_m
"
:
c_m
,
"
c_out
"
:
23
,
'
masked_msa
'
:
{
'
c_m
'
:
c_m
,
'
c_out
'
:
23
,
},
"
experimentally_resolved
"
:
{
"
c_s
"
:
c_s
,
"
c_out
"
:
37
,
'
experimentally_resolved
'
:
{
'
c_s
'
:
c_s
,
'
c_out
'
:
37
,
},
},
},
"
relax
"
:
{
"
max_iterations
"
:
0
,
# no max
"
tolerance
"
:
2.39
,
"
stiffness
"
:
10.0
,
"
max_outer_iterations
"
:
20
,
"
exclude_residues
"
:
[],
'
relax
'
:
{
'
max_iterations
'
:
0
,
# no max
'
tolerance
'
:
2.39
,
'
stiffness
'
:
10.0
,
'
max_outer_iterations
'
:
20
,
'
exclude_residues
'
:
[],
},
"
loss
"
:
{
"
distogram
"
:
{
"
min_bin
"
:
2.3125
,
"
max_bin
"
:
21.6875
,
"
no_bins
"
:
64
,
"
eps
"
:
eps
,
#1e-6,
"
weight
"
:
0.3
,
'
loss
'
:
{
'
distogram
'
:
{
'
min_bin
'
:
2.3125
,
'
max_bin
'
:
21.6875
,
'
no_bins
'
:
64
,
'
eps
'
:
eps
,
#1e-6,
'
weight
'
:
0.3
,
},
"
experimentally_resolved
"
:
{
"
eps
"
:
eps
,
#1e-8,
"
min_resolution
"
:
0.1
,
"
max_resolution
"
:
3.0
,
"
weight
"
:
0.
,
'
experimentally_resolved
'
:
{
'
eps
'
:
eps
,
#1e-8,
'
min_resolution
'
:
0.1
,
'
max_resolution
'
:
3.0
,
'
weight
'
:
0.
,
},
"
fape
"
:
{
"
backbone
"
:
{
"
clamp_distance
"
:
10.
,
"
loss_unit_distance
"
:
10.
,
"
weight
"
:
0.5
,
'
fape
'
:
{
'
backbone
'
:
{
'
clamp_distance
'
:
10.
,
'
loss_unit_distance
'
:
10.
,
'
weight
'
:
0.5
,
},
"
sidechain
"
:
{
"
clamp_distance
"
:
10.
,
"
length_scale
"
:
10.
,
"
weight
"
:
0.5
,
'
sidechain
'
:
{
'
clamp_distance
'
:
10.
,
'
length_scale
'
:
10.
,
'
weight
'
:
0.5
,
},
"
eps
"
:
1e-4
,
"
weight
"
:
1.0
,
'
eps
'
:
1e-4
,
'
weight
'
:
1.0
,
},
"
lddt
"
:
{
"
min_resolution
"
:
0.1
,
"
max_resolution
"
:
3.0
,
"
cutoff
"
:
15.
,
"
no_bins
"
:
50
,
"
eps
"
:
eps
,
#1e-10,
"
weight
"
:
0.01
,
'
lddt
'
:
{
'
min_resolution
'
:
0.1
,
'
max_resolution
'
:
3.0
,
'
cutoff
'
:
15.
,
'
no_bins
'
:
50
,
'
eps
'
:
eps
,
#1e-10,
'
weight
'
:
0.01
,
},
"
masked_msa
"
:
{
"
eps
"
:
eps
,
#1e-8,
"
weight
"
:
2.0
,
'
masked_msa
'
:
{
'
eps
'
:
eps
,
#1e-8,
'
weight
'
:
2.0
,
},
"
supervised_chi
"
:
{
"
chi_weight
"
:
0.5
,
"
angle_norm_weight
"
:
0.01
,
"
eps
"
:
eps
,
#1e-6,
"
weight
"
:
1.0
,
'
supervised_chi
'
:
{
'
chi_weight
'
:
0.5
,
'
angle_norm_weight
'
:
0.01
,
'
eps
'
:
eps
,
#1e-6,
'
weight
'
:
1.0
,
},
"
violation
"
:
{
"
violation_tolerance_factor
"
:
12.0
,
"
clash_overlap_tolerance
"
:
1.5
,
"
eps
"
:
eps
,
#1e-6,
"
weight
"
:
0.
,
'
violation
'
:
{
'
violation_tolerance_factor
'
:
12.0
,
'
clash_overlap_tolerance
'
:
1.5
,
'
eps
'
:
eps
,
#1e-6,
'
weight
'
:
0.
,
},
"
tm
"
:
{
"
max_bin
"
:
31
,
"
no_bins
"
:
64
,
"
min_resolution
"
:
0.1
,
"
max_resolution
"
:
3.0
,
"
eps
"
:
eps
,
#1e-8,
"
weight
"
:
0.
,
'
tm
'
:
{
'
max_bin
'
:
31
,
'
no_bins
'
:
64
,
'
min_resolution
'
:
0.1
,
'
max_resolution
'
:
3.0
,
'
eps
'
:
eps
,
#1e-8,
'
weight
'
:
0.
,
},
"eps"
:
eps
,
'eps'
:
eps
,
},
'ema'
:
{
'decay'
:
0.999
},
})
openfold/features/
np/
data_pipeline.py
→
openfold/features/data_pipeline.py
View file @
d48c052c
import
os
import
datetime
import
numpy
as
np
from
typing
import
Mapping
,
Optional
,
Sequence
from
typing
import
Mapping
,
Optional
,
Sequence
,
Any
from
openfold.features
import
templates
,
parsers
from
openfold.features
import
templates
,
parsers
,
mmcif_parsing
from
openfold.features.np
import
jackhmmer
,
hhblits
,
hhsearch
from
openfold.features.np.utils
import
to_date
from
openfold.np
import
residue_constants
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
def
make_sequence_features
(
sequence
:
str
,
description
:
str
,
num_res
:
int
)
->
FeatureDict
:
def
make_sequence_features
(
sequence
:
str
,
description
:
str
,
num_res
:
int
)
->
FeatureDict
:
"""Construct a feature dict of sequence features."""
features
=
{}
features
[
'aatype'
]
=
residue_constants
.
sequence_to_onehot
(
...
...
@@ -19,13 +25,50 @@ def make_sequence_features(sequence: str, description: str, num_res: int) -> Fea
map_unknown_to_x
=
True
)
features
[
'between_segment_residues'
]
=
np
.
zeros
((
num_res
,),
dtype
=
np
.
int32
)
features
[
'domain_name'
]
=
np
.
array
([
description
.
encode
(
'utf-8'
)],
dtype
=
np
.
object_
)
features
[
'domain_name'
]
=
np
.
array
(
[
description
.
encode
(
'utf-8'
)],
dtype
=
np
.
object_
)
features
[
'residue_index'
]
=
np
.
array
(
range
(
num_res
),
dtype
=
np
.
int32
)
features
[
'seq_length'
]
=
np
.
array
([
num_res
]
*
num_res
,
dtype
=
np
.
int32
)
features
[
'sequence'
]
=
np
.
array
([
sequence
.
encode
(
'utf-8'
)],
dtype
=
np
.
object_
)
features
[
'sequence'
]
=
np
.
array
(
[
sequence
.
encode
(
'utf-8'
)],
dtype
=
np
.
object_
)
return
features
def
make_mmcif_features
(
mmcif_object
:
mmcif_parsing
.
MmcifObject
,
chain_id
:
str
)
->
FeatureDict
:
input_sequence
=
mmcif_object
.
chain_to_seqres
[
chain_id
]
description
=
'_'
.
join
([
mmcif_object
.
file_id
,
chain_id
])
num_res
=
len
(
input_sequence
)
mmcif_feats
=
{}
mmcif_feats
.
update
(
make_sequence_features
(
sequence
=
input_sequence
,
description
=
description
,
num_res
=
num_res
,
))
all_atom_positions
,
all_atom_mask
=
mmcif_parsing
.
get_atom_coords
(
mmcif_object
=
mmcif_object
,
chain_id
=
chain_id
)
mmcif_feats
[
"all_atom_positions"
]
=
all_atom_positions
mmcif_feats
[
"all_atom_mask"
]
=
all_atom_mask
mmcif_feats
[
"resolution"
]
=
np
.
array
(
[
mmcif_object
.
header
[
"resolution"
]],
dtype
=
np
.
float32
)
mmcif_feats
[
"release_date"
]
=
np
.
array
(
[
mmcif_object
.
header
[
"release_date"
].
encode
(
'utf-8'
)],
dtype
=
np
.
object_
)
return
mmcif_feats
def
make_msa_features
(
msas
:
Sequence
[
Sequence
[
str
]],
deletion_matrices
:
Sequence
[
parsers
.
DeletionMatrix
])
->
FeatureDict
:
...
...
@@ -58,9 +101,9 @@ def make_msa_features(
)
return
features
class
DataPipeline
:
"""Runs the alignment tools and assembles the input features."""
class
AlignmentRunner
:
""" Runs alignment tools and saves the results """
def
__init__
(
self
,
jackhmmer_binary_path
:
str
,
hhblits_binary_path
:
str
,
...
...
@@ -71,106 +114,158 @@ class DataPipeline:
uniclust30_database_path
:
Optional
[
str
],
small_bfd_database_path
:
Optional
[
str
],
pdb70_database_path
:
str
,
template_featurizer
:
templates
.
TemplateHitFeaturizer
,
use_small_bfd
:
bool
,
mgnify_max_hits
:
int
=
501
,
uniref_max_hits
:
int
=
10000
no_cpus
:
int
,
uniref_max_hits
:
int
=
10000
,
mgnify_max_hits
:
int
=
5000
,
):
"""Constructs a feature dict for a given FASTA file."""
self
.
_use_small_bfd
=
use_small_bfd
self
.
jackhmmer_uniref90_runner
=
jackhmmer
.
Jackhmmer
(
binary_path
=
jackhmmer_binary_path
,
database_path
=
uniref90_database_path
database_path
=
uniref90_database_path
,
n_cpu
=
no_cpus
,
)
if
use_small_bfd
:
self
.
jackhmmer_small_bfd_runner
=
jackhmmer
.
Jackhmmer
(
binary_path
=
jackhmmer_binary_path
,
database_path
=
small_bfd_database_path
database_path
=
small_bfd_database_path
,
n_cpu
=
no_cpus
,
)
else
:
self
.
hhblits_bfd_uniclust_runner
=
hhblits
.
HHBlits
(
binary_path
=
hhblits_binary_path
,
databases
=
[
bfd_database_path
,
uniclust30_database_path
]
databases
=
[
bfd_database_path
,
uniclust30_database_path
],
n_cpu
=
no_cpus
,
)
self
.
jackhmmer_mgnify_runner
=
jackhmmer
.
Jackhmmer
(
binary_path
=
jackhmmer_binary_path
,
database_path
=
mgnify_database_path
database_path
=
mgnify_database_path
,
n_cpu
=
no_cpus
,
)
self
.
hhsearch_pdb70_runner
=
hhsearch
.
HHSearch
(
binary_path
=
hhsearch_binary_path
,
databases
=
[
pdb70_database_path
]
)
self
.
template_featurizer
=
template_featurizer
self
.
mgnify_max_hits
=
mgnify_max_hits
self
.
uniref_max_hits
=
uniref_max_hits
self
.
mgnify_max_hits
=
mgnify_max_hits
def
process
(
self
,
input_fasta_path
:
str
,
msa_output_dir
:
str
)
->
FeatureDict
:
"""Runs alignment tools on the input sequence and creates features."""
with
open
(
input_fasta_path
)
as
f
:
input_fasta_str
=
f
.
read
()
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
input_fasta_str
)
if
len
(
input_seqs
)
!=
1
:
raise
ValueError
(
f
'More than one input sequence found in
{
input_fasta_path
}
.'
)
input_sequence
=
input_seqs
[
0
]
input_description
=
input_descs
[
0
]
num_res
=
len
(
input_sequence
)
jackhmmer_uniref90_result
=
self
.
jackhmmer_uniref90_runner
.
query
(
input_fasta_path
)[
0
]
jackhmmer_mgnify_result
=
self
.
jackhmmer_mgnify_runner
.
query
(
input_fasta_path
)[
0
]
def
run
(
self
,
fasta_path
:
str
,
output_dir
:
str
,
):
"""Runs alignment tools on a sequence"""
jackhmmer_uniref90_result
=
self
.
jackhmmer_uniref90_runner
.
query
(
fasta_path
)[
0
]
uniref90_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
jackhmmer_uniref90_result
[
'sto'
],
max_sequences
=
self
.
uniref_max_hits
)
hhsearch_result
=
self
.
hhsearch_pdb70_runner
.
query
(
uniref90_msa_as_a3m
)
uniref90_out_path
=
os
.
path
.
join
(
msa_output_dir
,
'uniref90_hits.sto'
)
uniref90_out_path
=
os
.
path
.
join
(
output_dir
,
'uniref90_hits.a3m'
)
with
open
(
uniref90_out_path
,
'w'
)
as
f
:
f
.
write
(
jackhmmer_uniref90_result
[
'sto'
]
)
f
.
write
(
uniref90_msa_as_a3m
)
mgnify_out_path
=
os
.
path
.
join
(
msa_output_dir
,
'mgnify_hits.so'
)
jackhmmer_mgnify_result
=
self
.
jackhmmer_mgnify_runner
.
query
(
fasta_path
)[
0
]
mgnify_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
jackhmmer_mgnify_result
[
'sto'
],
max_sequences
=
self
.
mgnify_max_hits
)
mgnify_out_path
=
os
.
path
.
join
(
output_dir
,
'mgnify_hits.a3m'
)
with
open
(
mgnify_out_path
,
'w'
)
as
f
:
f
.
write
(
jackhmmer_mgnify_result
[
'sto'
]
)
f
.
write
(
mgnify_msa_as_a3m
)
pdb70_out_path
=
os
.
path
.
join
(
msa_output_dir
,
'pdb70_hits.hhr'
)
hhsearch_result
=
self
.
hhsearch_pdb70_runner
.
query
(
uniref90_msa_as_a3m
)
pdb70_out_path
=
os
.
path
.
join
(
output_dir
,
'pdb70_hits.hhr'
)
with
open
(
pdb70_out_path
,
'w'
)
as
f
:
f
.
write
(
hhsearch_result
)
uniref90_msa
,
uniref90_deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
jackhmmer_uniref90_result
[
'sto'
]
)
mgnify_msa
,
mgnify_deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
jackhmmer_mgnify_result
[
'sto'
]
)
hhsearch_hits
=
parsers
.
parse_hhr
(
hhsearch_result
)
mgnify_msa
=
mgnify_msa
[:
self
.
mgnify_max_hits
]
mgnify_deletion_matrix
=
mgnify_deletion_matrix
[:
self
.
mgnify_max_hits
]
if
self
.
_use_small_bfd
:
jackhmmer_small_bfd_result
=
self
.
jackhmmer_small_bfd_runner
.
query
(
input_
fasta_path
)[
0
]
bfd_out_path
=
os
.
path
.
join
(
msa_
output_dir
,
'small_bfd_hits.
a3m
'
)
jackhmmer_small_bfd_result
=
self
.
jackhmmer_small_bfd_runner
.
query
(
fasta_path
)[
0
]
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
'small_bfd_hits.
sto
'
)
with
open
(
bfd_out_path
,
'w'
)
as
f
:
f
.
write
(
jackhmmer_small_bfd_result
[
'sto'
])
else
:
hhblits_bfd_uniclust_result
=
self
.
hhblits_bfd_uniclust_runner
.
query
(
fasta_path
)
if
(
output_dir
is
not
None
):
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
'bfd_uniclust_hits.a3m'
)
with
open
(
bfd_out_path
,
'w'
)
as
f
:
f
.
write
(
hhblits_bfd_uniclust_result
[
'a3m'
])
bfd_msa
,
bfd_deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
jackhmmer_small_bfd_result
[
'sto'
]
class
DataPipeline
:
"""Assembles input features."""
def
__init__
(
self
,
template_featurizer
:
templates
.
TemplateHitFeaturizer
,
use_small_bfd
:
bool
,
):
self
.
template_featurizer
=
template_featurizer
self
.
use_small_bfd
=
use_small_bfd
def
_parse_alignment_output
(
self
,
alignment_dir
:
str
,
)
->
Mapping
[
str
,
Any
]:
uniref90_out_path
=
os
.
path
.
join
(
alignment_dir
,
'uniref90_hits.a3m'
)
with
open
(
uniref90_out_path
,
'r'
)
as
f
:
uniref90_msa
,
uniref90_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
()
)
mgnify_out_path
=
os
.
path
.
join
(
alignment_dir
,
'mgnify_hits.a3m'
)
with
open
(
mgnify_out_path
,
'r'
)
as
f
:
mgnify_msa
,
mgnify_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
()
)
else
:
hhblits_bfd_uniclust_result
=
self
.
hhblits_bfd_uniclust_runner
.
query
(
input_fasta_path
)
bfd_out_path
=
os
.
path
.
join
(
msa_output_dir
,
'bfd_uniclust_hits.a3m'
)
with
open
(
bfd_out_path
,
'w'
)
as
f
:
f
.
write
(
hhblits_bfd_uniclust_result
[
'a3m'
])
bfd_msa
,
bfd_deletion_matrix
=
parsers
.
parse_a3m
(
hhblits_bfd_uniclust_result
[
'a3m'
]
pdb70_out_path
=
os
.
path
.
join
(
alignment_dir
,
'pdb70_hits.hhr'
)
with
open
(
pdb70_out_path
,
'r'
)
as
f
:
hhsearch_hits
=
parsers
.
parse_hhr
(
f
.
read
()
)
if
(
self
.
use_small_bfd
):
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
'small_bfd_hits.sto'
)
with
open
(
bfd_out_path
,
'r'
)
as
f
:
bfd_msa
,
bfd_deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
f
.
read
()
)
else
:
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
'bfd_uniclust_hits.a3m'
)
with
open
(
bfd_out_path
,
'r'
)
as
f
:
bfd_msa
,
bfd_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
()
)
return
{
'uniref90_msa'
:
uniref90_msa
,
'uniref90_deletion_matrix'
:
uniref90_deletion_matrix
,
'mgnify_msa'
:
mgnify_msa
,
'mgnify_deletion_matrix'
:
mgnify_deletion_matrix
,
'hhsearch_hits'
:
hhsearch_hits
,
'bfd_msa'
:
bfd_msa
,
'bfd_deletion_matrix'
:
bfd_deletion_matrix
,
}
def
process_fasta
(
self
,
fasta_path
:
str
,
alignment_dir
:
str
,
)
->
FeatureDict
:
"""Assembles features for a single sequence in a FASTA file"""
with
open
(
fasta_path
)
as
f
:
fasta_str
=
f
.
read
()
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
fasta_str
)
if
len
(
input_seqs
)
!=
1
:
raise
ValueError
(
f
'More than one input sequence found in
{
fasta_path
}
.'
)
input_sequence
=
input_seqs
[
0
]
input_description
=
input_descs
[
0
]
num_res
=
len
(
input_sequence
)
alignments
=
self
.
_parse_alignment_output
(
alignment_dir
)
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
query_pdb_code
=
None
,
query_release_date
=
None
,
hits
=
hhsearch_hits
hits
=
alignments
[
'
hhsearch_hits
'
]
)
sequence_features
=
make_sequence_features
(
...
...
@@ -180,9 +275,62 @@ class DataPipeline:
)
msa_features
=
make_msa_features
(
msas
=
(
uniref90_msa
,
bfd_msa
,
mgnify_msa
),
deletion_matrices
=
(
uniref90_deletion_matrix
,
bfd_deletion_matrix
,
mgnify_deletion_matrix
)
msas
=
(
alignments
[
'uniref90_msa'
],
alignments
[
'bfd_msa'
],
alignments
[
'mgnify_msa'
]
),
deletion_matrices
=
(
alignments
[
'uniref90_deletion_matrix'
],
alignments
[
'bfd_deletion_matrix'
],
alignments
[
'mgnify_deletion_matrix'
]
)
)
return
{
**
sequence_features
,
**
msa_features
,
**
templates_result
.
features
}
def
process_mmcif
(
self
,
mmcif
:
mmcif_parsing
.
MmcifObject
,
# parsing is expensive, so no path
alignment_dir
:
str
,
chain_id
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
"""
Assembles features for a specific chain in an mmCIF object.
If chain_id is None, it is assumed that there is only one chain
in the object. Otherwise, a ValueError is thrown.
"""
if
(
chain_id
is
None
):
chains
=
mmcif
.
structure
.
get_chains
()
chain
=
next
(
chains
,
None
)
if
(
chain
is
None
):
raise
ValueError
(
'No chains in mmCIF file'
)
chain_id
=
chain
.
id
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
alignments
=
self
.
_parse_alignment_output
(
alignment_dir
)
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
query_pdb_code
=
None
,
query_release_date
=
to_date
(
mmcif
.
header
[
"release_date"
]),
hits
=
alignments
[
'hhsearch_hits'
]
)
msa_features
=
make_msa_features
(
msas
=
(
alignments
[
'uniref90_msa'
],
alignments
[
'bfd_msa'
],
alignments
[
'mgnify_msa'
]
),
deletion_matrices
=
(
alignments
[
'uniref90_deletion_matrix'
],
alignments
[
'bfd_deletion_matrix'
],
alignments
[
'mgnify_deletion_matrix'
]
)
)
return
{
**
mmcif_feats
,
**
templates_result
.
features
,
**
msa_features
}
openfold/features/data_transforms.py
View file @
d48c052c
...
...
@@ -6,8 +6,10 @@ import torch
from
operator
import
add
from
openfold.config
import
NUM_RES
,
NUM_EXTRA_SEQ
,
NUM_TEMPLATES
,
NUM_MSA_SEQ
from
openfold.np
import
residue_constants
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
from
openfold.np
import
residue_constants
as
rc
from
openfold.utils.affine_utils
import
T
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
,
batched_gather
MSA_FEATURE_NAMES
=
[
'msa'
,
'deletion_matrix'
,
'msa_mask'
,
'msa_row_mask'
,
'bert_mask'
,
'true_msa'
...
...
@@ -59,7 +61,7 @@ def fix_templates_aatype(protein):
num_templates
=
protein
[
'template_aatype'
].
shape
[
0
]
protein
[
'template_aatype'
]
=
torch
.
argmax
(
protein
[
'template_aatype'
],
dim
=-
1
)
# Map hhsearch-aatype to our aatype.
new_order_list
=
r
esidue_constants
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order_list
=
r
c
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order
=
torch
.
tensor
(
new_order_list
,
dtype
=
torch
.
int64
).
expand
(
num_templates
,
-
1
)
...
...
@@ -69,8 +71,8 @@ def fix_templates_aatype(protein):
return
protein
def
correct_msa_restypes
(
protein
):
"""Correct MSA restype to have the same order as r
esidue_constants
."""
new_order_list
=
r
esidue_constants
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
"""Correct MSA restype to have the same order as r
c
."""
new_order_list
=
r
c
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order
=
torch
.
tensor
(
[
new_order_list
]
*
protein
[
'msa'
].
shape
[
1
],
dtype
=
protein
[
'msa'
].
dtype
).
transpose
(
0
,
1
)
...
...
@@ -93,7 +95,7 @@ def squeeze_features(protein):
for
k
in
[
'domain_name'
,
'msa'
,
'num_alignments'
,
'seq_length'
,
'sequence'
,
'superfamily'
,
'deletion_matrix'
,
'resolution'
,
'between_segment_residues'
,
'residue_index'
,
'template_all_atom_mask
s
'
]:
'between_segment_residues'
,
'residue_index'
,
'template_all_atom_mask'
]:
if
k
in
protein
:
final_dim
=
protein
[
k
].
shape
[
-
1
]
if
isinstance
(
final_dim
,
int
)
and
final_dim
==
1
:
...
...
@@ -104,12 +106,6 @@ def squeeze_features(protein):
protein
[
k
]
=
protein
[
k
][
0
]
return
protein
def
make_protein_crop_to_size_seed
(
protein
):
protein
[
'random_crop_to_size_seed'
]
=
torch
.
distributions
.
Uniform
(
low
=
torch
.
int32
,
high
=
torch
.
int32
).
sample
((
2
)
)
return
protein
@
curry1
def
randomly_replace_msa_with_unknown
(
protein
,
replace_proportion
):
"""Replace a portion of the MSA with 'X'."""
...
...
@@ -284,19 +280,19 @@ def make_msa_mask(protein):
protein
[
'msa_row_mask'
]
=
torch
.
ones
(
protein
[
'msa'
].
shape
[
0
],
dtype
=
torch
.
float32
)
return
protein
def
pseudo_beta_fn
(
aatype
,
all_atom_positions
,
all_atom_mask
s
):
def
pseudo_beta_fn
(
aatype
,
all_atom_positions
,
all_atom_mask
):
"""Create pseudo beta features."""
is_gly
=
torch
.
eq
(
aatype
,
r
esidue_constants
.
restype_order
[
'G'
])
ca_idx
=
r
esidue_constants
.
atom_order
[
'CA'
]
cb_idx
=
r
esidue_constants
.
atom_order
[
'CB'
]
is_gly
=
torch
.
eq
(
aatype
,
r
c
.
restype_order
[
'G'
])
ca_idx
=
r
c
.
atom_order
[
'CA'
]
cb_idx
=
r
c
.
atom_order
[
'CB'
]
pseudo_beta
=
torch
.
where
(
torch
.
tile
(
is_gly
[...,
None
],
[
1
]
*
len
(
is_gly
.
shape
)
+
[
3
]),
all_atom_positions
[...,
ca_idx
,
:],
all_atom_positions
[...,
cb_idx
,
:])
if
all_atom_mask
s
is
not
None
:
if
all_atom_mask
is
not
None
:
pseudo_beta_mask
=
torch
.
where
(
is_gly
,
all_atom_mask
s
[...,
ca_idx
],
all_atom_mask
s
[...,
cb_idx
])
is_gly
,
all_atom_mask
[...,
ca_idx
],
all_atom_mask
[...,
cb_idx
])
return
pseudo_beta
,
pseudo_beta_mask
else
:
return
pseudo_beta
...
...
@@ -307,9 +303,9 @@ def make_pseudo_beta(protein, prefix=''):
assert
prefix
in
[
''
,
'template_'
]
protein
[
prefix
+
'pseudo_beta'
],
protein
[
prefix
+
'pseudo_beta_mask'
]
=
(
pseudo_beta_fn
(
protein
[
'template_aatype'
if
prefix
else
'
all_atom_
aatype'
],
protein
[
'template_aatype'
if
prefix
else
'aatype'
],
protein
[
prefix
+
'all_atom_positions'
],
protein
[
'template_all_atom_mask
s
'
if
prefix
else
'all_atom_mask'
]))
protein
[
'template_all_atom_mask'
if
prefix
else
'all_atom_mask'
]))
return
protein
@
curry1
...
...
@@ -456,10 +452,12 @@ def make_msa_feat(protein):
protein
[
'target_feat'
]
=
torch
.
cat
(
target_feat
,
dim
=-
1
)
return
protein
@curry1
def select_feat(protein, feature_list):
    """Keep only the features whose names appear in `feature_list`."""
    selected = {}
    for name, value in protein.items():
        if name in feature_list:
            selected[name] = value
    return selected
@
curry1
def
crop_templates
(
protein
,
max_templates
):
for
k
,
v
in
protein
.
items
():
...
...
@@ -467,72 +465,74 @@ def crop_templates(protein, max_templates):
protein
[
k
]
=
v
[:
max_templates
]
return
protein
def
make_atom14_masks
(
protein
):
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37
=
[]
restype_atom37_to_atom14
=
[]
restype_atom14_mask
=
[]
for
rt
in
r
esidue_constants
.
restypes
:
atom_names
=
r
esidue_constants
.
restype_name_to_atom14_names
[
r
esidue_constants
.
restype_1to3
[
rt
]
for
rt
in
r
c
.
restypes
:
atom_names
=
r
c
.
restype_name_to_atom14_names
[
r
c
.
restype_1to3
[
rt
]
]
restype_atom14_to_atom37
.
append
([
(
r
esidue_constants
.
atom_order
[
name
]
if
name
else
0
)
(
r
c
.
atom_order
[
name
]
if
name
else
0
)
for
name
in
atom_names
])
atom_name_to_idx14
=
{
name
:
i
for
i
,
name
in
enumerate
(
atom_names
)}
restype_atom37_to_atom14
.
append
([
(
atom_name_to_idx14
[
name
]
if
name
in
atom_name_to_idx14
else
0
)
for
name
in
r
esidue_constants
.
atom_types
for
name
in
r
c
.
atom_types
])
# Since all 14 atoms are not present in every residue, use this mask to
# tell which atom is there in this residue
restype_atom14_mask
.
append
([(
1.
if
name
else
0.
)
for
name
in
atom_names
])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37
.
append
([
0
]
*
14
)
restype_atom37_to_atom14
.
append
([
0
]
*
37
)
restype_atom14_mask
.
append
([
0.
]
*
14
)
restype_atom14_to_atom37
=
torch
.
tensor
(
restype_atom14_to_atom37
,
dtype
=
torch
.
int32
restype_atom14_to_atom37
,
dtype
=
torch
.
int32
,
device
=
protein
[
'aatype'
].
device
,
)
restype_atom37_to_atom14
=
torch
.
tensor
(
restype_atom37_to_atom14
,
dtype
=
torch
.
int32
restype_atom37_to_atom14
,
dtype
=
torch
.
int32
,
device
=
protein
[
'aatype'
].
device
,
)
restype_atom14_mask
=
torch
.
tensor
(
restype_atom14_mask
,
dtype
=
torch
.
float32
restype_atom14_mask
,
dtype
=
torch
.
float32
,
device
=
protein
[
'aatype'
].
device
,
)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37
=
torch
.
index_select
(
restype_atom14_to_atom37
,
0
,
protein
[
'aatype'
]
)
residx_atom14_mask
=
torch
.
index_select
(
restype_atom14_mask
,
0
,
protein
[
'aatype'
]
)
residx_atom14_to_atom37
=
restype_atom14_to_atom37
[
protein
[
'aatype'
]]
residx_atom14_mask
=
restype_atom14_mask
[
protein
[
'aatype'
]]
protein
[
'atom14_atom_exists'
]
=
residx_atom14_mask
protein
[
'residx_atom14_to_atom37'
]
=
residx_atom14_to_atom37
.
long
()
# create the gather indices for mapping back
residx_atom37_to_atom14
=
torch
.
index_select
(
restype_atom37_to_atom14
,
0
,
protein
[
'aatype'
]
)
residx_atom37_to_atom14
=
restype_atom37_to_atom14
[
protein
[
'aatype'
]]
protein
[
'residx_atom37_to_atom14'
]
=
residx_atom37_to_atom14
.
long
()
# create the corresponding mask
restype_atom37_mask
=
torch
.
zeros
([
21
,
37
],
dtype
=
torch
.
float32
)
for
restype
,
restype_letter
in
enumerate
(
residue_constants
.
restypes
):
restype_name
=
residue_constants
.
restype_1to3
[
restype_letter
]
atom_names
=
residue_constants
.
residue_atoms
[
restype_name
]
restype_atom37_mask
=
torch
.
zeros
(
[
21
,
37
],
dtype
=
torch
.
float32
,
device
=
protein
[
'aatype'
].
device
)
for
restype
,
restype_letter
in
enumerate
(
rc
.
restypes
):
restype_name
=
rc
.
restype_1to3
[
restype_letter
]
atom_names
=
rc
.
residue_atoms
[
restype_name
]
for
atom_name
in
atom_names
:
atom_type
=
r
esidue_constants
.
atom_order
[
atom_name
]
atom_type
=
r
c
.
atom_order
[
atom_name
]
restype_atom37_mask
[
restype
,
atom_type
]
=
1
residx_atom37_mask
=
torch
.
index_select
(
restype_atom37_mask
,
0
,
protein
[
'aatype'
]
)
residx_atom37_mask
=
restype_atom37_mask
[
protein
[
'aatype'
]]
protein
[
'atom37_atom_exists'
]
=
residx_atom37_mask
return
protein
...
...
@@ -543,3 +543,546 @@ def make_atom14_masks_np(batch):
out
=
make_atom14_masks
(
batch
)
out
=
tensor_tree_map
(
lambda
t
:
np
.
array
(
t
),
out
)
return
out
def make_atom14_positions(protein):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Reads "atom14_atom_exists", "residx_atom14_to_atom37", "all_atom_mask",
    "all_atom_positions" and "aatype"; writes the atom14 ground-truth
    position/mask features, including naming-swapped alternatives for
    residues whose atom naming is ambiguous.
    """
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        protein["all_atom_mask"],
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(protein["all_atom_mask"].shape[:-1])
    )

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            protein["all_atom_positions"],
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(protein["all_atom_positions"].shape[:-2])
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms. Residues without swaps keep the
    # 14x14 identity; swapped residues get a permutation matrix below.
    all_matrices = {
        res: torch.eye(
            14,
            dtype=protein["all_atom_mask"].dtype,
            device=protein["all_atom_mask"].device
        ) for res in restype_3
    }
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        # Build the atom14-index permutation implied by each name swap pair.
        correspondences = torch.arange(
            14, device=protein["all_atom_mask"].device
        )
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = rc.restype_name_to_atom14_names[resname].index(
                source_atom_swap
            )
            target_index = rc.restype_name_to_atom14_names[resname].index(
                target_atom_swap
            )
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.
        all_matrices[resname] = renaming_matrix
    # [21, 14, 14] stack, one renaming matrix per residue type (incl. UNK).
    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[protein["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    protein["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )
    protein["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask. shape: (21, 14).
    restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        for atom_name1, atom_name2 in swap.items():
            restype = rc.restype_order[rc.restype_3to1[resname]]
            atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
                atom_name1
            )
            atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
                atom_name2
            )
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = (
        restype_atom14_is_ambiguous[protein["aatype"]]
    )

    return protein
def atom37_to_frames(protein):
    """Computes per-residue rigid-group frames from atom37 coordinates.

    For every residue, 8 rigid groups are considered: group 0 is the backbone
    frame (defined by C, CA, N), group 3 is defined by CA, C, O, and groups
    4-7 are the frames of the (up to four) chi angles. Groups 1 and 2 are
    left empty here.

    Reads "aatype" ([*, N_res]), "all_atom_positions" ([*, N_res, 37, 3])
    and "all_atom_mask" ([*, N_res, 37]); writes "rigidgroups_gt_frames",
    "rigidgroups_gt_exists", "rigidgroups_group_exists",
    "rigidgroups_group_is_ambiguous" and "rigidgroups_alt_gt_frames".
    """
    aatype = protein["aatype"]
    all_atom_positions = protein["all_atom_positions"]
    all_atom_mask = protein["all_atom_mask"]

    batch_dims = len(aatype.shape[:-1])

    # Names of the 3 atoms that define each of the 8 rigid groups, per
    # residue type (21 = 20 amino acids + UNK).
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)
    restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
    restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']

    for restype, restype_letter in enumerate(rc.restypes):
        resname = rc.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if (rc.chi_angles_mask[restype][chi_idx]):
                names = rc.chi_angles_atoms[resname][chi_idx]
                # Chi frame k is defined by atoms 1-3 of its dihedral quad.
                restype_rigidgroup_base_atom_names[
                    restype, chi_idx + 4, :
                ] = names[1:]

    # Which of the 8 groups exist at all for each residue type: backbone and
    # group 3 always do; chi groups follow rc.chi_angles_mask (never for UNK).
    restype_rigidgroup_mask = all_atom_mask.new_zeros(
        (*aatype.shape[:-1], 21, 8),
    )
    restype_rigidgroup_mask[..., 0] = 1
    restype_rigidgroup_mask[..., 3] = 1
    restype_rigidgroup_mask[..., :20, 4:] = (
        all_atom_mask.new_tensor(rc.chi_angles_mask)
    )

    # Map atom names -> atom37 indices; '' (unset slots) maps to index 0.
    lookuptable = rc.atom_order.copy()
    lookuptable[''] = 0
    lookup = np.vectorize(lambda x: lookuptable[x])
    restype_rigidgroup_base_atom37_idx = lookup(
        restype_rigidgroup_base_atom_names,
    )
    restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
        restype_rigidgroup_base_atom37_idx,
    )
    # Prepend singleton batch dims so the table broadcasts over the batch.
    restype_rigidgroup_base_atom37_idx = (
        restype_rigidgroup_base_atom37_idx.view(
            *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
        )
    )

    # Per-residue atom37 indices of the 3 base atoms of each group.
    residx_rigidgroup_base_atom37_idx = batched_gather(
        restype_rigidgroup_base_atom37_idx,
        aatype,
        dim=-3,
        no_batch_dims=batch_dims,
    )

    # Gather the actual 3D positions of those base atoms.
    base_atom_pos = batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=-2,
        no_batch_dims=len(all_atom_positions.shape[:-2]),
    )

    # Build a frame from the 3 base atoms of each group.
    gt_frames = T.from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        eps=1e-8,
    )

    group_exists = batched_gather(
        restype_rigidgroup_mask,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    # A frame's ground truth exists only if all 3 of its base atoms exist
    # and the group is defined for the residue type.
    gt_atoms_exist = batched_gather(
        all_atom_mask,
        residx_rigidgroup_base_atom37_idx,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1])
    )
    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

    # Fixed rotation applied to the backbone frame (group 0): flip the
    # x and z axes.
    rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1

    gt_frames = gt_frames.compose(T(rots, None))

    # Mark the last chi group of residues with ambiguous atom naming and
    # prepare the 180-degree rotation (flip y and z) for their alt frames.
    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), 21, 8
    )
    restype_rigidgroup_rots = torch.eye(
        3, dtype=all_atom_mask.dtype, device=aatype.device
    )
    restype_rigidgroup_rots = torch.tile(
        restype_rigidgroup_rots,
        (*((1,) * batch_dims), 21, 8, 1, 1),
    )

    for resname, _ in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[
            rc.restype_3to1[resname]
        ]
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1

    residx_rigidgroup_is_ambiguous = batched_gather(
        restype_rigidgroup_is_ambiguous,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )
    residx_rigidgroup_ambiguity_rot = batched_gather(
        restype_rigidgroup_rots,
        aatype,
        dim=-4,
        no_batch_dims=batch_dims,
    )

    # Alternative frames for ambiguous groups (identity rotation elsewhere).
    alt_gt_frames = gt_frames.compose(
        T(residx_rigidgroup_ambiguity_rot, None)
    )

    gt_frames_tensor = gt_frames.to_4x4()
    alt_gt_frames_tensor = alt_gt_frames.to_4x4()

    protein['rigidgroups_gt_frames'] = gt_frames_tensor
    protein['rigidgroups_gt_exists'] = gt_exists
    protein['rigidgroups_group_exists'] = group_exists
    protein['rigidgroups_group_is_ambiguous'] = residx_rigidgroup_is_ambiguous
    protein['rigidgroups_alt_gt_frames'] = alt_gt_frames_tensor

    return protein
def get_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
      A nested list of shape [residue_types=21, chis=4, atoms=4]. The residue
      types are in the order specified in rc.restypes, with the unknown
      residue type at the end. For chi angles which are not defined on a
      residue, the atom indices default to 0.
    """
    chi_atom_indices = []
    for one_letter in rc.restypes:
        three_letter = rc.restype_1to3[one_letter]
        chi_angle_defs = rc.chi_angles_atoms[three_letter]
        rows = [
            [rc.atom_order[atom] for atom in chi_angle]
            for chi_angle in chi_angle_defs
        ]
        # Pad residues that have fewer than four chi angles with zero rows.
        while len(rows) < 4:
            rows.append([0, 0, 0, 0])
        chi_atom_indices.append(rows)

    # The UNKNOWN residue type has no chi angles at all.
    chi_atom_indices.append([[0, 0, 0, 0]] * 4)

    return chi_atom_indices
@curry1
def atom37_to_torsion_angles(
    protein,
    prefix='',
):
    """
    Convert coordinates to torsion angles.

    This function is extremely sensitive to floating point imprecisions
    and should be run with double precision whenever possible.

    Args:
        Dict containing:
            * (prefix)aatype:
                [*, N_res] residue indices
            * (prefix)all_atom_positions:
                [*, N_res, 37, 3] atom positions (in atom37
                format)
            * (prefix)all_atom_mask:
                [*, N_res, 37] atom position mask
    Returns:
        The same dictionary updated with the following features:

        "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Torsion angles
        "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Alternate torsion angles (accounting for 180-degree symmetry)
        "(prefix)torsion_angles_mask" ([*, N_res, 7])
            Torsion angles mask
    """
    aatype = protein[prefix + "aatype"]
    all_atom_positions = protein[prefix + "all_atom_positions"]
    all_atom_mask = protein[prefix + "all_atom_mask"]

    # Clamp residue indices into [0, 20]; 20 is the UNK slot at the end of
    # the per-restype tables built below.
    aatype = torch.clamp(aatype, max=20)

    # Positions/masks of the *previous* residue, zero-padded at the front,
    # needed for the pre-omega and phi dihedrals.
    pad = all_atom_positions.new_zeros(
        [*all_atom_positions.shape[:-3], 1, 37, 3]
    )
    prev_all_atom_positions = torch.cat(
        [pad, all_atom_positions[..., :-1, :, :]], dim=-3
    )

    pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
    prev_all_atom_mask = torch.cat(
        [pad, all_atom_mask[..., :-1, :]], dim=-2
    )

    # Atom quadruples defining each backbone dihedral. The atom37 slices
    # select fixed low indices — presumably N/CA/C/CB/O ordering of
    # rc.atom_order; confirm against residue_constants.
    pre_omega_atom_pos = torch.cat(
        [
            prev_all_atom_positions[..., 1:3, :],
            all_atom_positions[..., :2, :]
        ],
        dim=-2
    )
    phi_atom_pos = torch.cat(
        [
            prev_all_atom_positions[..., 2:3, :],
            all_atom_positions[..., :3, :]
        ],
        dim=-2
    )
    psi_atom_pos = torch.cat(
        [
            all_atom_positions[..., :3, :],
            all_atom_positions[..., 4:5, :]
        ],
        dim=-2
    )

    # A dihedral is valid only if all four of its atoms are present.
    pre_omega_mask = (
        torch.prod(prev_all_atom_mask[..., 1:3], dim=-1)
        * torch.prod(all_atom_mask[..., :2], dim=-1)
    )
    phi_mask = (
        prev_all_atom_mask[..., 2]
        * torch.prod(
            all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
        )
    )
    psi_mask = (
        torch.prod(
            all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
        )
        * all_atom_mask[..., 4]
    )

    # [21, 4, 4] table of atom37 indices for each chi angle, per restype.
    chi_atom_indices = torch.as_tensor(
        get_chi_atom_indices(), device=aatype.device
    )

    atom_indices = chi_atom_indices[..., aatype, :, :]
    chis_atom_pos = batched_gather(
        all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
    )

    chi_angles_mask = list(rc.chi_angles_mask)
    chi_angles_mask.append([0., 0., 0., 0.])  # UNK has no chi angles.
    chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)

    chis_mask = chi_angles_mask[aatype, :]

    chi_angle_atoms_mask = batched_gather(
        all_atom_mask,
        atom_indices,
        dim=-1,
        no_batch_dims=len(atom_indices.shape[:-2])
    )
    chi_angle_atoms_mask = torch.prod(
        chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
    )
    chis_mask = chis_mask * chi_angle_atoms_mask

    # Stack the 7 torsions per residue: [pre-omega, phi, psi, chi1..chi4].
    torsions_atom_pos = torch.cat(
        [
            pre_omega_atom_pos[..., None, :, :],
            phi_atom_pos[..., None, :, :],
            psi_atom_pos[..., None, :, :],
            chis_atom_pos,
        ],
        dim=-3
    )

    torsion_angles_mask = torch.cat(
        [
            pre_omega_mask[..., None],
            phi_mask[..., None],
            psi_mask[..., None],
            chis_mask,
        ],
        dim=-1
    )

    # Frame from atoms 1-3 of each quadruple; the fourth atom's position in
    # that frame yields the dihedral's sin/cos.
    torsion_frames = T.from_3_points(
        torsions_atom_pos[..., 1, :],
        torsions_atom_pos[..., 2, :],
        torsions_atom_pos[..., 0, :],
        eps=1e-8,
    )

    fourth_atom_rel_pos = torsion_frames.invert().apply(
        torsions_atom_pos[..., 3, :]
    )

    torsion_angles_sin_cos = torch.stack(
        [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
    )

    # Normalize (sin, cos) to unit norm; eps guards zero-length vectors.
    denom = torch.sqrt(
        torch.sum(
            torch.square(torsion_angles_sin_cos),
            dim=-1,
            dtype=torsion_angles_sin_cos.dtype,
            keepdims=True
        ) + 1e-8
    )
    torsion_angles_sin_cos = torsion_angles_sin_cos / denom

    # Sign convention: negate the third of the seven angles (psi).
    torsion_angles_sin_cos = (
        torsion_angles_sin_cos * all_atom_mask.new_tensor(
            [1., 1., -1., 1., 1., 1., 1.],
        )[
            ((None,) * len(torsion_angles_sin_cos.shape[:-2]))
            + (slice(None), None)
        ]
    )

    # Chis that are pi-periodic (per rc.chi_pi_periodic) get mirrored
    # alternate angles; the three backbone torsions never do.
    chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
        rc.chi_pi_periodic,
    )[aatype, ...]

    mirror_torsion_angles = torch.cat(
        [
            all_atom_mask.new_ones(*aatype.shape, 3),
            1. - 2. * chi_is_ambiguous
        ],
        dim=-1
    )

    alt_torsion_angles_sin_cos = (
        torsion_angles_sin_cos * mirror_torsion_angles[..., None]
    )

    protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
    protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
    protein[prefix + "torsion_angles_mask"] = torsion_angles_mask

    return protein
def get_backbone_frames(protein):
    """Expose rigid group 0 (the backbone frame) as standalone features.

    Copies the first rigid-group entries of "rigidgroups_gt_frames" and
    "rigidgroups_gt_exists" into "backbone_affine_tensor" and
    "backbone_affine_mask" respectively, and returns the same dict.
    """
    # TODO: Verify that this is correct
    frames = protein["rigidgroups_gt_frames"]
    exists = protein["rigidgroups_gt_exists"]
    protein["backbone_affine_tensor"] = frames[..., 0, :, :]
    protein["backbone_affine_mask"] = exists[..., 0]
    return protein
def get_chi_angles(protein):
    """Slice the four chi angles out of the seven per-residue torsions.

    The torsion order is [pre-omega, phi, psi, chi1..chi4], so indices 3:
    are the chis. Both outputs are cast to the dtype of "all_atom_mask".
    """
    target_dtype = protein["all_atom_mask"].dtype
    chi_sin_cos = protein["torsion_angles_sin_cos"][..., 3:, :]
    chi_mask = protein["torsion_angles_mask"][..., 3:]
    protein["chi_angles_sin_cos"] = chi_sin_cos.to(target_dtype)
    protein["chi_mask"] = chi_mask.to(target_dtype)
    return protein
@curry1
def random_crop_to_size(
    protein,
    crop_size,
    max_templates,
    shape_schema,
    subsample_templates=False,
    seed=None,
    batch_mode='clamped'
):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Args:
        protein: feature dict of tensors.
        crop_size: maximum number of residues retained along NUM_RES dims.
        max_templates: maximum number of templates retained.
        shape_schema: maps feature name -> per-dim size placeholders; dims
            equal to NUM_RES (and leading template dims) are cropped.
        subsample_templates: if True, templates are randomly permuted and a
            random window of at most `max_templates` of them is kept.
        seed: optional seed so every ensemble is cropped the same way.
        batch_mode: 'clamped' draws the crop start from the full legal
            range; 'unclamped' first draws a random right-side reduction.

    Returns:
        The same dict with cropped features and 'seq_length' updated to the
        cropped size.

    Raises:
        ValueError: if `batch_mode` is neither 'clamped' nor 'unclamped'.
    """
    seq_length = protein['seq_length']
    if 'template_mask' in protein:
        num_templates = protein['template_mask'].shape[-1]
    else:
        # NOTE(review): this branch yields a tensor rather than an int —
        # presumably downstream tensor/int interop is relied on; confirm.
        num_templates = protein['aatype'].new_zeros((1,))
    num_res_crop_size = min(seq_length, crop_size)

    # We want each ensemble to be cropped the same way
    g = torch.Generator(device=protein['seq_length'].device)
    if seed is not None:
        g.manual_seed(seed)

    def _randint(lower, upper):
        # Draw a single int uniformly from [lower, upper) using the shared
        # generator so repeated calls are reproducible under `seed`.
        return int(torch.randint(
            lower,
            upper,
            (1,),
            device=protein['seq_length'].device,
            generator=g
        )[0])

    if subsample_templates:
        templates_crop_start = _randint(0, num_templates + 1)
        templates_select_indices = torch.randperm(
            num_templates, device=protein['seq_length'].device, generator=g
        )
        num_templates_crop_size = min(
            num_templates - templates_crop_start, max_templates
        )
    else:
        templates_crop_start = 0
        num_templates_crop_size = num_templates

    n = seq_length - num_res_crop_size
    if batch_mode == 'clamped':
        right_anchor = n + 1
    elif batch_mode == 'unclamped':
        # Guard n == 0: torch.randint requires low < high, and there is
        # nothing to shift when the sequence already fits.
        x = _randint(0, n) if n > 0 else 0
        right_anchor = n - x + 1
    else:
        raise ValueError("Invalid batch mode")
    num_res_crop_start = _randint(0, right_anchor)

    for k, v in protein.items():
        if (
            k not in shape_schema
            or ('template' not in k and NUM_RES not in shape_schema[k])
        ):
            continue

        # randomly permute the templates before cropping them.
        if k.startswith('template') and subsample_templates:
            v = v[templates_select_indices]

        slices = []
        for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
            is_num_res = (dim_size == NUM_RES)
            if i == 0 and k.startswith('template'):
                # Leading template dimension uses the template crop window.
                # Local names avoid clobbering the `crop_size` parameter.
                dim_crop_size = num_templates_crop_size
                dim_crop_start = templates_crop_start
            else:
                dim_crop_start = num_res_crop_start if is_num_res else 0
                dim_crop_size = num_res_crop_size if is_num_res else dim
            slices.append(
                slice(dim_crop_start, dim_crop_start + dim_crop_size)
            )
        # Index with a tuple of slices: indexing a tensor with a plain
        # Python list of slices is deprecated and rejected by modern PyTorch.
        protein[k] = v[tuple(slices)]

    protein['seq_length'] = (
        protein['seq_length'].new_tensor(num_res_crop_size)
    )

    return protein
openfold/features/feature_pipeline.py
View file @
d48c052c
...
...
@@ -25,39 +25,67 @@ def np_to_tensor_dict(
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
tensor_dict
=
{
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
np_example
.
items
()
if
k
in
features
}
tensor_dict
=
{
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
np_example
.
items
()
if
k
in
features
}
return
tensor_dict
def
make_data_config
(
config
:
ml_collections
.
ConfigDict
,
mode
:
str
,
num_res
:
int
,
)
->
Tuple
[
ml_collections
.
ConfigDict
,
List
[
str
]]:
cfg
=
copy
.
deepcopy
(
config
.
data
)
)
->
Tuple
[
ml_collections
.
ConfigDict
,
List
[
str
]]:
cfg
=
copy
.
deepcopy
(
config
)
mode_cfg
=
cfg
[
mode
]
with
cfg
.
unlocked
():
if
(
mode_cfg
.
crop_size
is
None
):
mode_cfg
.
crop_size
=
num_res
feature_names
=
cfg
.
common
.
unsupervised_features
if
cfg
.
common
.
use_templates
:
feature_names
+=
cfg
.
common
.
template_features
with
cfg
.
unlock
ed
(
):
cfg
.
eval
.
crop_size
=
num_
res
if
(
cfg
[
mode
].
supervis
ed
):
feature_names
+=
cfg
.
common
.
supervised_featu
res
return
cfg
,
feature_names
def
np_example_to_features
(
np_example
:
FeatureDict
,
config
:
ml_collections
.
ConfigDict
,
random_seed
:
int
=
0
):
def
np_example_to_features
(
np_example
:
FeatureDict
,
config
:
ml_collections
.
ConfigDict
,
mode
:
str
,
batch_mode
:
str
,
):
np_example
=
dict
(
np_example
)
num_res
=
int
(
np_example
[
'seq_length'
][
0
])
cfg
,
feature_names
=
make_data_config
(
config
,
num_res
=
num_res
)
cfg
,
feature_names
=
make_data_config
(
config
,
mode
=
mode
,
num_res
=
num_res
)
if
'deletion_matrix_int'
in
np_example
:
np_example
[
'deletion_matrix'
]
=
(
np_example
.
pop
(
'deletion_matrix_int'
).
astype
(
np
.
float32
))
np_example
.
pop
(
'deletion_matrix_int'
).
astype
(
np
.
float32
)
)
if
batch_mode
==
'clamped'
:
np_example
[
'use_clamped_fape'
]
=
(
np
.
array
(
1.
).
astype
(
np
.
float32
)
)
elif
batch_mode
==
'unclamped'
:
np_example
[
'use_clamped_fape'
]
=
(
np
.
array
(
0.
).
astype
(
np
.
float32
)
)
torch
.
manual_seed
(
random_seed
)
tensor_dict
=
np_to_tensor_dict
(
np_example
=
np_example
,
features
=
feature_names
)
features
=
input_pipeline
.
process_tensors_from_config
(
tensor_dict
,
cfg
)
np_example
=
np_example
,
features
=
feature_names
)
with
torch
.
no_grad
():
features
=
input_pipeline
.
process_tensors_from_config
(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
batch_mode
=
batch_mode
,
)
return
{
k
:
v
for
k
,
v
in
features
.
items
()}
...
...
@@ -70,10 +98,13 @@ class FeaturePipeline:
self
.
params
=
params
def
process_features
(
self
,
raw_features
:
FeatureDict
,
random_seed
:
int
)
->
FeatureDict
:
raw_features
:
FeatureDict
,
mode
:
str
=
'train'
,
batch_mode
:
str
=
'clamped'
,
)
->
FeatureDict
:
return
np_example_to_features
(
np_example
=
raw_features
,
config
=
self
.
config
,
random_seed
=
random_seed
)
\ No newline at end of file
mode
=
mode
,
batch_mode
=
batch_mode
,
)
openfold/features/input_pipeline.py
View file @
d48c052c
from
functools
import
partial
import
torch
from
openfold.features
import
data_transforms
def
nonensembled_transform_fns
(
data_confi
g
):
def
nonensembled_transform_fns
(
common_cfg
,
mode_cf
g
):
"""Input pipeline data transformers that are not ensembled."""
common_cfg
=
data_config
.
common
transforms
=
[
data_transforms
.
cast_to_64bit_ints
,
data_transforms
.
correct_msa_restypes
,
...
...
@@ -23,23 +22,36 @@ def nonensembled_transform_fns(data_config):
data_transforms
.
make_template_mask
,
data_transforms
.
make_pseudo_beta
(
'template_'
)
])
if
(
common_cfg
.
use_template_torsion_angles
):
transforms
.
extend
([
data_transforms
.
atom37_to_torsion_angles
(
'template_'
),
])
transforms
.
extend
([
data_transforms
.
make_atom14_masks
,
])
if
(
mode_cfg
.
supervised
):
transforms
.
extend
([
data_transforms
.
make_atom14_positions
,
data_transforms
.
atom37_to_frames
,
data_transforms
.
atom37_to_torsion_angles
(
''
),
data_transforms
.
make_pseudo_beta
(
''
),
data_transforms
.
get_backbone_frames
,
data_transforms
.
get_chi_angles
,
])
return
transforms
def
ensembled_transform_fns
(
data_config
):
def
ensembled_transform_fns
(
common_cfg
,
mode_cfg
,
batch_mode
):
"""Input pipeline data transformers that can be ensembled and averaged."""
common_cfg
=
data_config
.
common
eval_cfg
=
data_config
.
eval
transforms
=
[]
if
common_cfg
.
reduce_msa_clusters_by_max_templates
:
pad_msa_clusters
=
eval
_cfg
.
max_msa_clusters
-
eval
_cfg
.
max_templates
pad_msa_clusters
=
mode
_cfg
.
max_msa_clusters
-
mode
_cfg
.
max_templates
else
:
pad_msa_clusters
=
eval
_cfg
.
max_msa_clusters
pad_msa_clusters
=
mode
_cfg
.
max_msa_clusters
max_msa_clusters
=
pad_msa_clusters
max_extra_msa
=
common_cfg
.
max_extra_msa
...
...
@@ -53,8 +65,10 @@ def ensembled_transform_fns(data_config):
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms
.
append
(
data_transforms
.
make_masked_msa
(
common_cfg
.
masked_msa
,
eval_cfg
.
masked_msa_replace_fraction
)
data_transforms
.
make_masked_msa
(
common_cfg
.
masked_msa
,
mode_cfg
.
masked_msa_replace_fraction
)
)
if
common_cfg
.
msa_cluster_features
:
...
...
@@ -69,44 +83,55 @@ def ensembled_transform_fns(data_config):
transforms
.
append
(
data_transforms
.
make_msa_feat
())
crop_feats
=
dict
(
eval
_cfg
.
feat
)
crop_feats
=
dict
(
common
_cfg
.
feat
)
if
eval
_cfg
.
fixed_size
:
if
mode
_cfg
.
fixed_size
:
transforms
.
append
(
data_transforms
.
select_feat
(
list
(
crop_feats
)))
transforms
.
append
(
data_transforms
.
random_crop_to_size
(
mode_cfg
.
crop_size
,
mode_cfg
.
max_templates
,
crop_feats
,
mode_cfg
.
subsample_templates
,
batch_mode
=
batch_mode
,
seed
=
torch
.
Generator
().
seed
()
))
transforms
.
append
(
data_transforms
.
make_fixed_size
(
crop_feats
,
pad_msa_clusters
,
common_cfg
.
max_extra_msa
,
eval
_cfg
.
crop_size
,
eval
_cfg
.
max_templates
mode
_cfg
.
crop_size
,
mode
_cfg
.
max_templates
))
else
:
transforms
.
append
(
data_transforms
.
crop_templates
(
eval_cfg
.
max_templates
))
transforms
.
append
(
data_transforms
.
crop_templates
(
mode_cfg
.
max_templates
)
)
return
transforms
def
process_tensors_from_config
(
tensors
,
data_config
):
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
,
batch_mode
=
'clamped'
):
"""Based on the config, apply filters and transformations to the data."""
def
wrap_ensemble_fn
(
data
,
i
):
"""Function to be mapped over the ensemble dimension."""
d
=
data
.
copy
()
fns
=
ensembled_transform_fns
(
data_config
)
fns
=
ensembled_transform_fns
(
common_cfg
,
mode_cfg
,
batch_mode
)
fn
=
compose
(
fns
)
d
[
'ensemble_index'
]
=
i
return
fn
(
d
)
eval_cfg
=
data_config
.
eval
tensors
=
compose
(
nonensembled_transform_fns
(
data_confi
g
)
nonensembled_transform_fns
(
common_cfg
,
mode_cf
g
)
)(
tensors
)
tensors_0
=
wrap_ensemble_fn
(
tensors
,
0
)
num_ensemble
=
eval
_cfg
.
num_ensemble
if
data_config
.
common
.
resample_msa_in_recycling
:
num_ensemble
=
mode
_cfg
.
num_ensemble
if
common
_cfg
.
resample_msa_in_recycling
:
# Separate batch per ensembling & recycling step.
num_ensemble
*=
data_config
.
common
.
num_recycle
+
1
num_ensemble
*=
common
_cfg
.
num_recycle
+
1
if
isinstance
(
num_ensemble
,
torch
.
Tensor
)
or
num_ensemble
>
1
:
tensors
=
map_fn
(
lambda
x
:
wrap_ensemble_fn
(
tensors
,
x
),
...
...
@@ -116,16 +141,20 @@ def process_tensors_from_config(tensors, data_config):
return
tensors
@data_transforms.curry1
def compose(x, fs):
    """Thread `x` through every transform in `fs`, in order."""
    out = x
    for transform in fs:
        out = transform(out)
    return out
def
map_fn
(
fun
,
x
):
ensembles
=
[
fun
(
elem
)
for
elem
in
x
]
features
=
ensembles
[
0
].
keys
()
ensembled_dict
=
{}
for
feat
in
features
:
ensembled_dict
[
feat
]
=
torch
.
stack
([
dict_i
[
feat
]
for
dict_i
in
ensembles
])
ensembled_dict
[
feat
]
=
torch
.
stack
(
[
dict_i
[
feat
]
for
dict_i
in
ensembles
],
dim
=-
1
)
return
ensembled_dict
openfold/features/mmcif_parsing.py
View file @
d48c052c
"""Parses the mmCIF file format."""
import
collections
import
dataclasses
import
io
import
json
import
logging
import
os
from
typing
import
Any
,
Mapping
,
Optional
,
Sequence
,
Tuple
from
absl
import
logging
from
Bio
import
PDB
from
Bio.Data
import
SCOPData
import
numpy
as
np
import
openfold.np.residue_constants
as
residue_constants
# Type aliases:
ChainId
=
str
...
...
@@ -369,3 +374,73 @@ def _get_protein_chains(
def
_is_set
(
data
:
str
)
->
bool
:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return
data
not
in
(
'.'
,
'?'
)
def get_atom_coords(
    mmcif_object: MmcifObject, chain_id: str
) -> Tuple[np.ndarray, np.ndarray]:
    """Extracts atom37-format coordinates and an atom mask for one chain.

    Args:
        mmcif_object: A parsed mmCIF structure (with ``structure``,
            ``chain_to_seqres`` and ``seqres_to_structure`` attributes).
        chain_id: Id of the chain to extract.

    Returns:
        A tuple of float32 arrays:
            all_atom_positions: [num_res, atom_type_num, 3] coordinates.
            all_atom_mask: [num_res, atom_type_num], 1.0 where the atom was
                resolved in the structure, 0.0 otherwise.

    Raises:
        MultipleChainsError: If the structure does not contain exactly one
            chain with the given id.
    """
    # Locate the right chain
    chains = list(mmcif_object.structure.get_chains())
    relevant_chains = [c for c in chains if c.id == chain_id]
    if len(relevant_chains) != 1:
        raise MultipleChainsError(
            f'Expected exactly one chain in structure with id {chain_id}.'
        )
    chain = relevant_chains[0]

    # Extract the coordinates
    num_res = len(mmcif_object.chain_to_seqres[chain_id])
    all_atom_positions = np.zeros(
        [num_res, residue_constants.atom_type_num, 3], dtype=np.float32
    )
    all_atom_mask = np.zeros(
        [num_res, residue_constants.atom_type_num], dtype=np.float32
    )
    for res_index in range(num_res):
        pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
        mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
        # Map the seqres index to the structure residue; missing residues
        # keep all-zero positions and mask.
        res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
        if not res_at_position.is_missing:
            res = chain[(
                res_at_position.hetflag,
                res_at_position.position.residue_number,
                res_at_position.position.insertion_code,
            )]
            for atom in res.get_atoms():
                atom_name = atom.get_name()
                x, y, z = atom.get_coord()
                if atom_name in residue_constants.atom_order.keys():
                    pos[residue_constants.atom_order[atom_name]] = [x, y, z]
                    mask[residue_constants.atom_order[atom_name]] = 1.0
                elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
                    # Put the coords of the selenium atom in the sulphur column
                    pos[residue_constants.atom_order['SD']] = [x, y, z]
                    mask[residue_constants.atom_order['SD']] = 1.0
        all_atom_positions[res_index] = pos
        all_atom_mask[res_index] = mask

    return all_atom_positions, all_atom_mask
def generate_mmcif_cache(mmcif_dir: str, out_path: str):
    """Writes a JSON cache of per-file metadata for every .cif in a directory.

    For each parseable mmCIF file, stores its release date and chain count
    keyed by the file's basename. Files that fail to parse are skipped with
    a warning.
    """
    data = {}
    for f in os.listdir(mmcif_dir):
        if not f.endswith('.cif'):
            continue
        with open(os.path.join(mmcif_dir, f), 'r') as fp:
            mmcif_string = fp.read()
        file_id = os.path.splitext(f)[0]
        mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
        if mmcif.mmcif_object is None:
            logging.warning(f'Could not parse {f}. Skipping...')
            continue
        mmcif = mmcif.mmcif_object
        data[file_id] = {
            'release_date': mmcif.header["release_date"],
            'no_chains': len(list(mmcif.structure.get_chains())),
        }

    with open(out_path, 'w') as fp:
        fp.write(json.dumps(data))
openfold/features/np/hhsearch.py
View file @
d48c052c
...
...
@@ -18,6 +18,7 @@ class HHSearch:
*
,
binary_path
:
str
,
databases
:
Sequence
[
str
],
n_cpu
:
int
=
2
,
maxseq
:
int
=
1_000_000
):
"""Initializes the Python HHsearch wrapper.
...
...
@@ -26,6 +27,7 @@ class HHSearch:
databases: A sequence of HHsearch database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
n_cpu: The number of CPUs to use
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
...
...
@@ -34,6 +36,7 @@ class HHSearch:
"""
self
.
binary_path
=
binary_path
self
.
databases
=
databases
self
.
n_cpu
=
n_cpu
self
.
maxseq
=
maxseq
for
database_path
in
self
.
databases
:
...
...
@@ -56,7 +59,8 @@ class HHSearch:
cmd
=
[
self
.
binary_path
,
'-i'
,
input_path
,
'-o'
,
hhr_path
,
'-maxseq'
,
str
(
self
.
maxseq
)
'-maxseq'
,
str
(
self
.
maxseq
),
'-cpu'
,
str
(
self
.
n_cpu
),
]
+
db_cmd
logging
.
info
(
'Launching subprocess "%s"'
,
' '
.
join
(
cmd
))
...
...
openfold/features/np/jackhmmer.py
View file @
d48c052c
...
...
@@ -3,14 +3,12 @@
from
concurrent
import
futures
import
glob
import
logging
import
os
import
subprocess
from
typing
import
Any
,
Callable
,
Mapping
,
Optional
,
Sequence
from
urllib
import
request
from
absl
import
logging
from
openfold.features.np
import
utils
...
...
openfold/features/np/utils.py
View file @
d48c052c
"""Common utilities for data pipeline tools."""
import
contextlib
import
datetime
import
shutil
import
tempfile
import
time
...
...
@@ -25,3 +26,9 @@ def timing(msg: str):
yield
toc
=
time
.
time
()
logging
.
info
(
'Finished %s in %.3f seconds'
,
msg
,
toc
-
tic
)
def to_date(s: str) -> datetime.datetime:
    """Parse the 'YYYY-MM-DD' prefix of *s* into a (midnight) datetime."""
    year, month, day = int(s[:4]), int(s[5:7]), int(s[8:10])
    return datetime.datetime(year=year, month=month, day=day)
openfold/features/templates.py
View file @
d48c052c
...
...
@@ -2,16 +2,17 @@
import
dataclasses
import
datetime
import
glob
import
json
import
logging
import
os
import
re
from
typing
import
Any
,
Dict
,
Mapping
,
Optional
,
Sequence
,
Tuple
from
absl
import
logging
import
numpy
as
np
from
openfold.features
import
parsers
,
mmcif_parsing
from
openfold.features.np
import
kalign
from
openfold.features.np.utils
import
to_date
from
openfold.np
import
residue_constants
...
...
@@ -74,7 +75,7 @@ class LengthError(PrefilterError):
TEMPLATE_FEATURES
=
{
'template_aatype'
:
np
.
int64
,
'template_all_atom_mask
s
'
:
np
.
float32
,
'template_all_atom_mask'
:
np
.
float32
,
'template_all_atom_positions'
:
np
.
float32
,
'template_domain_names'
:
np
.
object
,
'template_sequence'
:
np
.
object
,
...
...
@@ -133,23 +134,40 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
return
result
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
    """Writes a JSON file mapping mmCIF basenames to their release dates.

    Every .cif file in ``mmcif_dir`` is parsed; files that fail to parse are
    skipped with a warning. The resulting dict is dumped as JSON to
    ``out_path``.
    """
    dates = {}
    for f in os.listdir(mmcif_dir):
        if(f.endswith('.cif')):
            path = os.path.join(mmcif_dir, f)
            with open(path, 'r') as fp:
                mmcif_string = fp.read()
            file_id = os.path.splitext(f)[0]
            mmcif = mmcif_parsing.parse(
                file_id=file_id, mmcif_string=mmcif_string
            )
            if(mmcif.mmcif_object is None):
                logging.warning(f'Failed to parse {f}. Skipping...')
                continue
            mmcif = mmcif.mmcif_object
            release_date = mmcif.header['release_date']
            dates[file_id] = release_date

    # BUG FIX: the output file was previously opened with mode 'r'
    # (read-only), so fp.write() raised io.UnsupportedOperation. Open for
    # writing, matching generate_mmcif_cache.
    with open(out_path, 'w') as fp:
        fp.write(json.dumps(dates))
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
    """Parses release dates file, returns a mapping from PDBs to release dates.

    Two formats are supported:
      * ``.txt``: one ``pdb_id: YYYY-MM-DD`` entry per line.
      * anything else is treated as the JSON cache format produced by
        generate_release_dates_cache / generate_mmcif_cache, i.e.
        ``{pdb_id: {"release_date": ..., ...}}``.
    """
    if path.endswith('txt'):
        release_dates = {}
        with open(path, 'r') as f:
            for line in f:
                pdb_id, date = line.split(':')
                date = date.strip()
                # Python 3.6 doesn't have datetime.date.fromisoformat() which is about
                # 90x faster than strptime. However, splitting the string manually is
                # about 10x faster than strptime.
                release_dates[pdb_id.strip()] = datetime.datetime(
                    year=int(date[:4]),
                    month=int(date[5:7]),
                    day=int(date[8:10]),
                )
        return release_dates

    # FIX: previously the non-txt branch unconditionally raised ValueError,
    # leaving the JSON-cache handling below unreachable. Fall through to the
    # JSON format instead.
    with open(path, 'r') as fp:
        data = json.load(fp)

    return {
        pdb: to_date(v)
        for pdb, d in data.items()
        for k, v in d.items()
        if k == "release_date"
    }
def
_assess_hhsearch_hit
(
hit
:
parsers
.
TemplateHit
,
...
...
@@ -419,42 +437,14 @@ def _get_atom_positions(
auth_chain_id
:
str
,
max_ca_ca_distance
:
float
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""Gets atom positions and mask from a list of Biopython Residues."""
num_res
=
len
(
mmcif_object
.
chain_to_seqres
[
auth_chain_id
])
relevant_chains
=
[
c
for
c
in
mmcif_object
.
structure
.
get_chains
()
if
c
.
id
==
auth_chain_id
]
if
len
(
relevant_chains
)
!=
1
:
raise
MultipleChainsError
(
f
'Expected exactly one chain in structure with id
{
auth_chain_id
}
.'
)
chain
=
relevant_chains
[
0
]
all_positions
=
np
.
zeros
([
num_res
,
residue_constants
.
atom_type_num
,
3
])
all_positions_mask
=
np
.
zeros
([
num_res
,
residue_constants
.
atom_type_num
],
dtype
=
np
.
int64
)
for
res_index
in
range
(
num_res
):
pos
=
np
.
zeros
([
residue_constants
.
atom_type_num
,
3
],
dtype
=
np
.
float32
)
mask
=
np
.
zeros
([
residue_constants
.
atom_type_num
],
dtype
=
np
.
float32
)
res_at_position
=
mmcif_object
.
seqres_to_structure
[
auth_chain_id
][
res_index
]
if
not
res_at_position
.
is_missing
:
res
=
chain
[(
res_at_position
.
hetflag
,
res_at_position
.
position
.
residue_number
,
res_at_position
.
position
.
insertion_code
)]
for
atom
in
res
.
get_atoms
():
atom_name
=
atom
.
get_name
()
x
,
y
,
z
=
atom
.
get_coord
()
if
atom_name
in
residue_constants
.
atom_order
.
keys
():
pos
[
residue_constants
.
atom_order
[
atom_name
]]
=
[
x
,
y
,
z
]
mask
[
residue_constants
.
atom_order
[
atom_name
]]
=
1.0
elif
atom_name
.
upper
()
==
'SE'
and
res
.
get_resname
()
==
'MSE'
:
# Put the coordinates of the selenium atom in the sulphur column.
pos
[
residue_constants
.
atom_order
[
'SD'
]]
=
[
x
,
y
,
z
]
mask
[
residue_constants
.
atom_order
[
'SD'
]]
=
1.0
all_positions
[
res_index
]
=
pos
all_positions_mask
[
res_index
]
=
mask
coords_with_mask
=
mmcif_parsing
.
get_atom_coords
(
mmcif_object
=
mmcif_object
,
chain_id
=
auth_chain_id
)
all_atom_positions
,
all_atom_mask
=
coords_with_mask
_check_residue_distances
(
all_positions
,
all_positions_mask
,
max_ca_ca_distance
)
return
all_positions
,
all_positions_mask
all_atom_positions
,
all_atom_mask
,
max_ca_ca_distance
)
return
all_atom_positions
,
all_atom_mask
def
_extract_template_features
(
...
...
@@ -579,7 +569,7 @@ def _extract_template_features(
return
(
{
'template_all_atom_positions'
:
np
.
array
(
templates_all_atom_positions
),
'template_all_atom_mask
s
'
:
np
.
array
(
templates_all_atom_masks
),
'template_all_atom_mask'
:
np
.
array
(
templates_all_atom_masks
),
'template_sequence'
:
output_templates_sequence
.
encode
(),
'template_aatype'
:
np
.
array
(
templates_aatype
),
'template_domain_names'
:
f
'
{
pdb_id
.
lower
()
}
_
{
chain_id
}
'
.
encode
(),
...
...
openfold/model/model.py
View file @
d48c052c
...
...
@@ -19,7 +19,6 @@ import torch.nn as nn
from
openfold.utils.feats
import
(
pseudo_beta_fn
,
atom37_to_torsion_angles
,
build_extra_msa_feat
,
build_template_angle_feat
,
build_template_pair_feat
,
...
...
@@ -115,21 +114,16 @@ class AlphaFold(nn.Module):
batch
,
)
# Build template angle feats
angle_feats
=
atom37_to_torsion_angles
(
single_template_feats
[
"template_aatype"
],
single_template_feats
[
"template_all_atom_positions"
],
#.float(),
single_template_feats
[
"template_all_atom_masks"
],
#.float(),
eps
=
self
.
config
.
template
.
eps
,
)
template_angle_feat
=
build_template_angle_feat
(
angle_feats
,
single_template_feats
[
"template_aatype"
],
)
single_template_embeds
=
{}
if
(
self
.
config
.
template
.
embed_angles
):
template_angle_feat
=
build_template_angle_feat
(
single_template_feats
,
)
# [*, S_t, N, C_m]
a
=
self
.
template_angle_embedder
(
template_angle_feat
)
# [*, S_t, N, C_m]
a
=
self
.
template_angle_embedder
(
template_angle_feat
)
single_template_embeds
[
"angle"
]
=
a
# [*, S_t, N, N, C_t]
t
=
build_template_pair_feat
(
...
...
@@ -145,11 +139,11 @@ class AlphaFold(nn.Module):
_mask_trans
=
self
.
config
.
_mask_trans
)
template_embeds
.
append
({
"angle"
:
a
,
"pair"
:
t
,
"torsion_mask"
:
angle_feats
[
"torsion_angles_mask"
]
single_template_embeds
.
update
({
"pair"
:
t
,
})
template_embeds
.
append
(
single_template_embeds
)
template_embeds
=
dict_multimap
(
partial
(
torch
.
cat
,
dim
=
templ_dim
),
...
...
@@ -164,11 +158,15 @@ class AlphaFold(nn.Module):
)
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
])
>
0
)
return
{
"template_angle_embedding"
:
template_embeds
[
"angle"
],
ret
=
{}
if
(
self
.
config
.
template
.
embed_angles
):
ret
[
"template_angle_embedding"
]
=
template_embeds
[
"angle"
]
ret
.
update
({
"template_pair_embedding"
:
t
,
"torsion_angles_mask"
:
template_embeds
[
"torsion_mask"
],
}
})
return
ret
def
iteration
(
self
,
feats
,
m_1_prev
,
z_prev
,
x_prev
):
# Primary output dictionary
...
...
@@ -197,7 +195,7 @@ class AlphaFold(nn.Module):
)
# Inject information from previous recycling iterations
if
(
self
.
config
.
n
o_
cycle
s
>
1
):
if
(
self
.
config
.
n
um_re
cycle
>
0
):
# Initialize the recycling embeddings, if needs be
if
(
None
in
[
m_1_prev
,
z_prev
,
x_prev
]):
# [*, N, C_m]
...
...
@@ -241,7 +239,7 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings
if
(
self
.
config
.
template
.
enabled
):
template_feats
=
{
k
:
v
for
k
,
v
in
feats
.
items
()
if
"template_"
in
k
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
}
template_embeds
=
self
.
embed_templates
(
template_feats
,
...
...
@@ -261,7 +259,7 @@ class AlphaFold(nn.Module):
)
# [*, S, N]
torsion_angles_mask
=
template_
embeds
[
"
torsion_angles_mask"
]
torsion_angles_mask
=
feats
[
"
template_torsion_angles_mask"
]
msa_mask
=
torch
.
cat
(
[
feats
[
"msa_mask"
],
torsion_angles_mask
[...,
2
]],
axis
=-
2
)
...
...
@@ -374,7 +372,8 @@ class AlphaFold(nn.Module):
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
"template_all_atom_positions"
([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
...
...
@@ -392,13 +391,13 @@ class AlphaFold(nn.Module):
self
.
_disable_activation_checkpointing
()
# Main recycling loop
for
cycle_no
in
range
(
self
.
config
.
n
o_
cycle
s
):
for
cycle_no
in
range
(
self
.
config
.
n
um_re
cycle
+
1
):
# Select the features for the current recycling cycle
fetch_cur_batch
=
lambda
t
:
t
[...,
cycle_no
]
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter
=
(
cycle_no
==
(
self
.
config
.
n
o_cycles
-
1
)
)
is_final_iter
=
(
cycle_no
==
self
.
config
.
n
um_recycle
)
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
# Sidestep AMP bug discussed in pytorch issue #65766
if
(
is_final_iter
):
...
...
openfold/utils/exponential_moving_average.py
View file @
d48c052c
...
...
@@ -29,14 +29,15 @@ class ExponentialMovingAverage:
self
.
decay
=
decay
def
_update_state_dict_
(
self
,
update
,
state_dict
):
for
k
,
v
in
update
.
items
():
stored
=
state_dict
[
k
]
if
(
not
isinstance
(
v
,
torch
.
Tensor
)):
self
.
_update_state_dict_
(
v
,
stored
)
else
:
diff
=
stored
-
v
diff
*=
(
1
-
self
.
decay
)
stored
-=
diff
with
torch
.
no_grad
():
for
k
,
v
in
update
.
items
():
stored
=
state_dict
[
k
]
if
(
not
isinstance
(
v
,
torch
.
Tensor
)):
self
.
_update_state_dict_
(
v
,
stored
)
else
:
diff
=
stored
-
v
diff
*=
(
1
-
self
.
decay
)
stored
-=
diff
def
update
(
self
,
model
:
torch
.
nn
.
Module
)
->
None
:
"""
...
...
openfold/utils/feats.py
View file @
d48c052c
...
...
@@ -49,32 +49,6 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
return
pseudo_beta
def
get_chi_atom_indices
():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in rc.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices
=
[]
for
residue_name
in
rc
.
restypes
:
residue_name
=
rc
.
restype_1to3
[
residue_name
]
residue_chi_angles
=
rc
.
chi_angles_atoms
[
residue_name
]
atom_indices
=
[]
for
chi_angle
in
residue_chi_angles
:
atom_indices
.
append
(
[
rc
.
atom_order
[
atom
]
for
atom
in
chi_angle
])
for
_
in
range
(
4
-
len
(
atom_indices
)):
atom_indices
.
append
([
0
,
0
,
0
,
0
])
# For chi angles not defined on the AA.
chi_atom_indices
.
append
(
atom_indices
)
chi_atom_indices
.
append
([[
0
,
0
,
0
,
0
]]
*
4
)
# For UNKNOWN residue.
return
chi_atom_indices
def
atom14_to_atom37
(
atom14
,
batch
):
atom37_data
=
batched_gather
(
atom14
,
...
...
@@ -88,320 +62,13 @@ def atom14_to_atom37(atom14, batch):
return
atom37_data
def
atom37_to_torsion_angles
(
aatype
:
torch
.
Tensor
,
all_atom_positions
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
eps
:
float
=
1e-8
,
**
kwargs
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""
Convert coordinates to torsion angles.
This function is extremely sensitive to floating point imprecisions
and should be run with double precision whenever possible.
Args:
aatype:
[*, N_res] residue indices
all_atom_positions:
[*, N_res, 37, 3] atom positions (in atom37
format)
all_atom_mask:
[*, N_res, 37] atom position mask
Returns:
Dictionary of the following features:
"torsion_angles_sin_cos" ([*, N_res, 7, 2])
Torsion angles
"alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
Alternate torsion angles (accounting for 180-degree symmetry)
"torsion_angles_mask" ([*, N_res, 7])
Torsion angles mask
"""
aatype
=
torch
.
clamp
(
aatype
,
max
=
20
)
pad
=
all_atom_positions
.
new_zeros
(
[
*
all_atom_positions
.
shape
[:
-
3
],
1
,
37
,
3
]
)
prev_all_atom_positions
=
torch
.
cat
(
[
pad
,
all_atom_positions
[...,
:
-
1
,
:,
:]],
dim
=-
3
)
pad
=
all_atom_mask
.
new_zeros
([
*
all_atom_mask
.
shape
[:
-
2
],
1
,
37
])
prev_all_atom_mask
=
torch
.
cat
([
pad
,
all_atom_mask
[...,
:
-
1
,
:]],
dim
=-
2
)
pre_omega_atom_pos
=
torch
.
cat
(
[
prev_all_atom_positions
[...,
1
:
3
,
:],
all_atom_positions
[...,
:
2
,
:]
],
dim
=-
2
)
phi_atom_pos
=
torch
.
cat
(
[
prev_all_atom_positions
[...,
2
:
3
,
:],
all_atom_positions
[...,
:
3
,
:]
],
dim
=-
2
)
psi_atom_pos
=
torch
.
cat
(
[
all_atom_positions
[...,
:
3
,
:],
all_atom_positions
[...,
4
:
5
,
:]
],
dim
=-
2
)
pre_omega_mask
=
(
torch
.
prod
(
prev_all_atom_mask
[...,
1
:
3
],
dim
=-
1
)
*
torch
.
prod
(
all_atom_mask
[...,
:
2
],
dim
=-
1
)
)
phi_mask
=
(
prev_all_atom_mask
[...,
2
]
*
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
)
psi_mask
=
(
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
*
all_atom_mask
[...,
4
]
)
chi_atom_indices
=
torch
.
as_tensor
(
get_chi_atom_indices
(),
device
=
aatype
.
device
)
atom_indices
=
chi_atom_indices
[...,
aatype
,
:,
:]
chis_atom_pos
=
batched_gather
(
all_atom_positions
,
atom_indices
,
-
2
,
len
(
atom_indices
.
shape
[:
-
2
])
)
chi_angles_mask
=
list
(
rc
.
chi_angles_mask
)
chi_angles_mask
.
append
([
0.
,
0.
,
0.
,
0.
])
chi_angles_mask
=
all_atom_mask
.
new_tensor
(
chi_angles_mask
)
chis_mask
=
chi_angles_mask
[
aatype
,
:]
chi_angle_atoms_mask
=
batched_gather
(
all_atom_mask
,
atom_indices
,
dim
=-
1
,
no_batch_dims
=
len
(
atom_indices
.
shape
[:
-
2
])
)
chi_angle_atoms_mask
=
torch
.
prod
(
chi_angle_atoms_mask
,
dim
=-
1
,
dtype
=
chi_angle_atoms_mask
.
dtype
)
chis_mask
=
chis_mask
*
chi_angle_atoms_mask
torsions_atom_pos
=
torch
.
cat
(
[
pre_omega_atom_pos
[...,
None
,
:,
:],
phi_atom_pos
[...,
None
,
:,
:],
psi_atom_pos
[...,
None
,
:,
:],
chis_atom_pos
,
],
dim
=-
3
)
torsion_angles_mask
=
torch
.
cat
(
[
pre_omega_mask
[...,
None
],
phi_mask
[...,
None
],
psi_mask
[...,
None
],
chis_mask
,
],
dim
=-
1
)
torsion_frames
=
T
.
from_3_points
(
torsions_atom_pos
[...,
1
,
:],
torsions_atom_pos
[...,
2
,
:],
torsions_atom_pos
[...,
0
,
:],
eps
=
eps
,
)
fourth_atom_rel_pos
=
torsion_frames
.
invert
().
apply
(
torsions_atom_pos
[...,
3
,
:]
)
torsion_angles_sin_cos
=
torch
.
stack
(
[
fourth_atom_rel_pos
[...,
2
],
fourth_atom_rel_pos
[...,
1
]],
dim
=-
1
)
denom
=
torch
.
sqrt
(
torch
.
sum
(
torch
.
square
(
torsion_angles_sin_cos
),
dim
=-
1
,
dtype
=
torsion_angles_sin_cos
.
dtype
,
keepdims
=
True
)
+
eps
)
torsion_angles_sin_cos
=
torsion_angles_sin_cos
/
denom
torsion_angles_sin_cos
=
torsion_angles_sin_cos
*
all_atom_mask
.
new_tensor
(
[
1.
,
1.
,
-
1.
,
1.
,
1.
,
1.
,
1.
],
)[((
None
,)
*
len
(
torsion_angles_sin_cos
.
shape
[:
-
2
]))
+
(
slice
(
None
),
None
)]
chi_is_ambiguous
=
torsion_angles_sin_cos
.
new_tensor
(
rc
.
chi_pi_periodic
,
)[
aatype
,
...]
mirror_torsion_angles
=
torch
.
cat
(
[
all_atom_mask
.
new_ones
(
*
aatype
.
shape
,
3
),
1.
-
2.
*
chi_is_ambiguous
],
dim
=-
1
)
def
build_template_angle_feat
(
template_feats
):
template_aatype
=
template_feats
[
"template_aatype"
]
torsion_angles_sin_cos
=
template_feats
[
"template_torsion_angles_sin_cos"
]
alt_torsion_angles_sin_cos
=
(
torsion_angles_sin_cos
*
mirror_torsion_angles
[...,
None
]
)
return
{
"torsion_angles_sin_cos"
:
torsion_angles_sin_cos
,
"alt_torsion_angles_sin_cos"
:
alt_torsion_angles_sin_cos
,
"torsion_angles_mask"
:
torsion_angles_mask
,
}
def
atom37_to_frames
(
aatype
:
torch
.
Tensor
,
all_atom_positions
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
eps
:
float
,
**
kwargs
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
batch_dims
=
len
(
aatype
.
shape
[:
-
1
])
restype_rigidgroup_base_atom_names
=
np
.
full
([
21
,
8
,
3
],
''
,
dtype
=
object
)
restype_rigidgroup_base_atom_names
[:,
0
,
:]
=
[
'C'
,
'CA'
,
'N'
]
restype_rigidgroup_base_atom_names
[:,
3
,
:]
=
[
'CA'
,
'C'
,
'O'
]
for
restype
,
restype_letter
in
enumerate
(
rc
.
restypes
):
resname
=
rc
.
restype_1to3
[
restype_letter
]
for
chi_idx
in
range
(
4
):
if
(
rc
.
chi_angles_mask
[
restype
][
chi_idx
]):
names
=
rc
.
chi_angles_atoms
[
resname
][
chi_idx
]
restype_rigidgroup_base_atom_names
[
restype
,
chi_idx
+
4
,
:
]
=
names
[
1
:]
restype_rigidgroup_mask
=
all_atom_mask
.
new_zeros
(
(
*
aatype
.
shape
[:
-
1
],
21
,
8
),
)
restype_rigidgroup_mask
[...,
0
]
=
1
restype_rigidgroup_mask
[...,
3
]
=
1
restype_rigidgroup_mask
[...,
:
20
,
4
:]
=
(
all_atom_mask
.
new_tensor
(
rc
.
chi_angles_mask
)
)
lookuptable
=
rc
.
atom_order
.
copy
()
lookuptable
[
''
]
=
0
lookup
=
np
.
vectorize
(
lambda
x
:
lookuptable
[
x
])
restype_rigidgroup_base_atom37_idx
=
lookup
(
restype_rigidgroup_base_atom_names
,
)
restype_rigidgroup_base_atom37_idx
=
aatype
.
new_tensor
(
restype_rigidgroup_base_atom37_idx
,
)
restype_rigidgroup_base_atom37_idx
=
(
restype_rigidgroup_base_atom37_idx
.
view
(
*
((
1
,)
*
batch_dims
),
*
restype_rigidgroup_base_atom37_idx
.
shape
)
)
residx_rigidgroup_base_atom37_idx
=
batched_gather
(
restype_rigidgroup_base_atom37_idx
,
aatype
,
dim
=-
3
,
no_batch_dims
=
batch_dims
,
template_feats
[
"template_alt_torsion_angles_sin_cos"
]
)
base_atom_pos
=
batched_gather
(
all_atom_positions
,
residx_rigidgroup_base_atom37_idx
,
dim
=-
2
,
no_batch_dims
=
len
(
all_atom_positions
.
shape
[:
-
2
]),
)
gt_frames
=
T
.
from_3_points
(
p_neg_x_axis
=
base_atom_pos
[...,
0
,
:],
origin
=
base_atom_pos
[...,
1
,
:],
p_xy_plane
=
base_atom_pos
[...,
2
,
:],
eps
=
eps
,
)
group_exists
=
batched_gather
(
restype_rigidgroup_mask
,
aatype
,
dim
=-
2
,
no_batch_dims
=
batch_dims
,
)
gt_atoms_exist
=
batched_gather
(
all_atom_mask
,
residx_rigidgroup_base_atom37_idx
,
dim
=-
1
,
no_batch_dims
=
len
(
all_atom_mask
.
shape
[:
-
1
])
)
gt_exists
=
torch
.
min
(
gt_atoms_exist
,
dim
=-
1
)[
0
]
*
group_exists
rots
=
torch
.
eye
(
3
,
dtype
=
all_atom_mask
.
dtype
,
device
=
aatype
.
device
)
rots
=
torch
.
tile
(
rots
,
(
*
((
1
,)
*
batch_dims
),
8
,
1
,
1
))
rots
[...,
0
,
0
,
0
]
=
-
1
rots
[...,
0
,
2
,
2
]
=
-
1
gt_frames
=
gt_frames
.
compose
(
T
(
rots
,
None
))
restype_rigidgroup_is_ambiguous
=
all_atom_mask
.
new_zeros
(
*
((
1
,)
*
batch_dims
),
21
,
8
)
restype_rigidgroup_rots
=
torch
.
eye
(
3
,
dtype
=
all_atom_mask
.
dtype
,
device
=
aatype
.
device
)
restype_rigidgroup_rots
=
torch
.
tile
(
restype_rigidgroup_rots
,
(
*
((
1
,)
*
batch_dims
),
21
,
8
,
1
,
1
),
)
for
resname
,
_
in
rc
.
residue_atom_renaming_swaps
.
items
():
restype
=
rc
.
restype_order
[
rc
.
restype_3to1
[
resname
]
]
chi_idx
=
int
(
sum
(
rc
.
chi_angles_mask
[
restype
])
-
1
)
restype_rigidgroup_is_ambiguous
[...,
restype
,
chi_idx
+
4
]
=
1
restype_rigidgroup_rots
[...,
restype
,
chi_idx
+
4
,
1
,
1
]
=
-
1
restype_rigidgroup_rots
[...,
restype
,
chi_idx
+
4
,
2
,
2
]
=
-
1
residx_rigidgroup_is_ambiguous
=
batched_gather
(
restype_rigidgroup_is_ambiguous
,
aatype
,
dim
=-
2
,
no_batch_dims
=
batch_dims
,
)
residx_rigidgroup_ambiguity_rot
=
batched_gather
(
restype_rigidgroup_rots
,
aatype
,
dim
=-
4
,
no_batch_dims
=
batch_dims
,
)
alt_gt_frames
=
gt_frames
.
compose
(
T
(
residx_rigidgroup_ambiguity_rot
,
None
))
gt_frames_tensor
=
gt_frames
.
to_4x4
()
alt_gt_frames_tensor
=
alt_gt_frames
.
to_4x4
()
return
{
'rigidgroups_gt_frames'
:
gt_frames_tensor
,
'rigidgroups_gt_exists'
:
gt_exists
,
'rigidgroups_group_exists'
:
group_exists
,
'rigidgroups_group_is_ambiguous'
:
residx_rigidgroup_is_ambiguous
,
'rigidgroups_alt_gt_frames'
:
alt_gt_frames_tensor
,
}
def
build_template_angle_feat
(
angle_feats
,
template_aatype
):
torsion_angles_sin_cos
=
angle_feats
[
"torsion_angles_sin_cos"
]
alt_torsion_angles_sin_cos
=
angle_feats
[
"alt_torsion_angles_sin_cos"
]
torsion_angles_mask
=
angle_feats
[
"torsion_angles_mask"
]
torsion_angles_mask
=
template_feats
[
"template_torsion_angles_mask"
]
template_angle_feat
=
torch
.
cat
(
[
nn
.
functional
.
one_hot
(
template_aatype
,
22
),
...
...
@@ -465,7 +132,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
eps
+
torch
.
sum
(
affine_vec
**
2
,
dim
=-
1
)
)
t_aa_masks
=
batch
[
"template_all_atom_mask
s
"
]
t_aa_masks
=
batch
[
"template_all_atom_mask"
]
template_mask
=
(
t_aa_masks
[...,
n
]
*
t_aa_masks
[...,
ca
]
*
t_aa_masks
[...,
c
]
)
...
...
@@ -534,53 +201,6 @@ def build_msa_feat(batch):
return
batch
def build_ambiguity_feats(
    batch: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """
    Compute features required by compute_renamed_ground_truth (Alg. 26)
    Args:
        batch:
            str/tensor dictionary containing:
                * atom14_gt_positions: [*, N, 14, 3] ground truth pos.
                * atom14_gt_exists: [*, N, 14] atom mask
                * aatype: [*, N] residue indices
    Returns:
        str/tensor dictionary containing:
            * atom14_atom_is_ambiguous: [*, N, 14] mask of ambiguous atoms
            * atom14_alt_gt_positions: [*, N, 14, 3] renamed positions
            * atom14_alt_gt_exists: [*, N, 14] mask for the renamed positions
    """
    # FIX: the return annotation previously said "-> None" although the
    # function returns a feature dict; corrected to Dict[str, torch.Tensor].
    ambiguous_atoms = batch["atom14_gt_positions"].new_tensor(
        rc.restype_atom14_ambiguous_atoms
    )
    atom14_atom_is_ambiguous = ambiguous_atoms[batch["aatype"], ...]

    # Swap pairs of ambiguous positions
    swap_idx = rc.restype_atom14_ambiguous_atoms_swap_idx
    swap_mat = np.eye(swap_idx.shape[-1])[swap_idx]  # one-hot swap_idx
    swap_mat = batch["atom14_gt_positions"].new_tensor(swap_mat)
    swap_mat = swap_mat[batch["aatype"], ...]

    atom14_alt_gt_positions = torch.sum(
        batch["atom14_gt_positions"][..., None, :] * swap_mat[..., None],
        dim=-3,
    )
    atom14_alt_gt_exists = torch.sum(
        batch["atom14_gt_exists"][..., None] * swap_mat,
        dim=-2,
    )

    return {
        "atom14_atom_is_ambiguous": atom14_atom_is_ambiguous,
        "atom14_alt_gt_positions": atom14_alt_gt_positions,
        "atom14_alt_gt_exists": atom14_alt_gt_exists,
    }
def
torsion_angles_to_frames
(
t
:
T
,
alpha
:
torch
.
Tensor
,
...
...
openfold/utils/loss.py
View file @
d48c052c
...
...
@@ -18,6 +18,7 @@ import ml_collections
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
torch.distributions.bernoulli
import
Bernoulli
from
typing
import
Dict
,
Optional
,
Tuple
from
openfold.np
import
residue_constants
...
...
@@ -117,7 +118,9 @@ def compute_fape(
return
normed_error
# DISCREPANCY: figure out if loss clamping happens in 90% of each bach or in 90% of batches
# DISCREPANCY: From the way this function is written, it's possible that
# DeepMind clamped 90% of individual residue losses, not 90% of all batches.
# We defer to the text, which seems to imply the latter.
def
backbone_loss
(
backbone_affine_tensor
:
torch
.
Tensor
,
backbone_affine_mask
:
torch
.
Tensor
,
...
...
@@ -130,7 +133,7 @@ def backbone_loss(
)
->
torch
.
Tensor
:
pred_aff
=
T
.
from_tensor
(
traj
)
gt_aff
=
T
.
from_tensor
(
backbone_affine_tensor
)
fape_loss
=
compute_fape
(
pred_aff
,
gt_aff
[...,
None
,
:],
...
...
@@ -142,7 +145,6 @@ def backbone_loss(
length_scale
=
loss_unit_distance
,
eps
=
eps
,
)
if
(
use_clamped_fape
is
not
None
):
unclamped_fape_loss
=
compute_fape
(
pred_aff
,
...
...
@@ -157,12 +159,12 @@ def backbone_loss(
)
fape_loss
=
(
fape_loss
*
use_clamped_fape
+
fape_loss
*
use_clamped_fape
+
unclamped_fape_loss
*
(
1
-
use_clamped_fape
)
)
# Take the mean over the layer dimension
fape_loss
=
torch
.
mean
(
fape_loss
,
dim
=
0
)
fape_loss
=
torch
.
mean
(
fape_loss
,
dim
=
-
1
)
return
fape_loss
...
...
@@ -1453,7 +1455,7 @@ class AlphaFoldLoss(nn.Module):
super
(
AlphaFoldLoss
,
self
).
__init__
()
self
.
config
=
config
def
forward
(
self
,
out
,
batch
):
def
forward
(
self
,
out
,
batch
):
if
(
"violation"
not
in
out
.
keys
()
and
self
.
config
.
violation
.
weight
):
out
[
"violation"
]
=
find_structural_violations
(
batch
,
...
...
@@ -1461,41 +1463,12 @@ class AlphaFoldLoss(nn.Module):
**
self
.
config
.
violation
,
)
if
(
"atom14_atom_is_ambiguous"
not
in
batch
.
keys
()):
batch
.
update
(
feats
.
build_ambiguity_feats
(
batch
))
if
(
"renamed_atom14_gt_positions"
not
in
out
.
keys
()):
batch
.
update
(
compute_renamed_ground_truth
(
batch
,
out
[
"sm"
][
"positions"
][
-
1
],
))
if
(
"backbone_affine_tensor"
not
in
batch
.
keys
()):
batch
.
update
(
feats
.
atom37_to_frames
(
eps
=
self
.
config
.
eps
,
**
batch
))
# TODO: Verify that this is correct
batch
[
"backbone_affine_tensor"
]
=
(
batch
[
"rigidgroups_gt_frames"
][...,
0
,
:,
:]
)
batch
[
"backbone_affine_mask"
]
=
(
batch
[
"rigidgroups_gt_exists"
][...,
0
]
)
if
(
"chi_angles_sin_cos"
not
in
batch
.
keys
()):
with
torch
.
no_grad
():
batch
.
update
(
feats
.
atom37_to_torsion_angles
(
aatype
=
batch
[
"aatype"
],
all_atom_positions
=
batch
[
"all_atom_positions"
].
double
(),
all_atom_mask
=
batch
[
"all_atom_mask"
].
double
(),
eps
=
self
.
config
.
eps
,
))
# TODO: Verify that this is correct
batch
[
"chi_angles_sin_cos"
]
=
(
batch
[
"torsion_angles_sin_cos"
][...,
3
:,
:]
).
to
(
batch
[
"all_atom_mask"
].
dtype
)
batch
[
"chi_mask"
]
=
batch
[
"torsion_angles_mask"
][...,
3
:].
to
(
batch
[
"all_atom_mask"
].
dtype
)
loss_fns
=
{
"distogram"
:
lambda
:
distogram_loss
(
...
...
run_pretrained_openfold.py
View file @
d48c052c
...
...
@@ -15,17 +15,17 @@
import
argparse
from
datetime
import
date
import
pickle
import
logging
import
os
# A hack to get OpenMM and PyTorch to peacefully coexist
os
.
environ
[
"OPENMM_DEFAULT_PLATFORM"
]
=
"OpenCL"
import
pickle
import
random
import
sys
from
openfold.features
import
templates
,
feature_pipeline
from
openfold.features.np
import
data_pipeline
from
openfold.features
import
templates
,
feature_pipeline
,
data_pipeline
import
time
...
...
@@ -43,28 +43,29 @@ from openfold.utils.tensor_utils import (
tensor_tree_map
,
)
MAX_TEMPLATE_HITS
=
20
from
scripts.utils
import
add_data_args
def
main
(
args
):
config
=
model_config
(
args
.
model_name
)
model
=
AlphaFold
(
config
.
model
)
model
=
model
.
eval
()
import_jax_weights_
(
model
,
args
.
param_path
)
model
=
model
.
to
(
args
.
device
)
model
=
model
.
to
(
args
.
model_
device
)
# FEATURE COLLECTION AND PROCESSING
use_small_bfd
=
args
.
preset
==
"reduced_dbs"
num_ensemble
=
1
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
max_hits
=
MAX_TEMPLATE_HITS
,
max_hits
=
args
.
max_template_hits
,
kalign_binary_path
=
args
.
kalign_binary_path
,
release_dates_path
=
None
,
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
)
use_small_bfd
=
(
args
.
bfd_database_path
is
None
)
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
...
...
@@ -76,6 +77,7 @@ def main(args):
small_bfd_database_path
=
args
.
small_bfd_database_path
,
pdb70_database_path
=
args
.
pdb70_database_path
,
use_small_bfd
=
use_small_bfd
,
no_cpus
=
args
.
cpus
,
)
data_processor
=
data_pipeline
.
DataPipeline
(
...
...
@@ -87,7 +89,7 @@ def main(args):
random_seed
=
args
.
data_random_seed
if
random_seed
is
None
:
random_seed
=
random
.
randrange
(
sys
.
maxsize
)
config
.
data
.
eval
.
num_ensemble
=
num_ensemble
config
.
data
.
predict
.
num_ensemble
=
num_ensemble
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
)
if
not
os
.
path
.
exists
(
output_dir_base
):
os
.
makedirs
(
output_dir_base
)
...
...
@@ -95,7 +97,7 @@ def main(args):
if
not
os
.
path
.
exists
(
alignment_dir
):
os
.
makedirs
(
alignment_dir
)
print
(
"Generating features..."
)
logging
.
info
(
"Generating features..."
)
alignment_runner
.
run
(
args
.
fasta_path
,
alignment_dir
)
...
...
@@ -105,42 +107,20 @@ def main(args):
)
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
random_seed
feature_dict
,
mode
=
'predict'
,
)
for
k
,
v
in
processed_feature_dict
.
items
():
print
(
k
)
print
(
v
.
shape
)
print
(
"Executing model..."
)
logging
.
info
(
"Executing model..."
)
batch
=
processed_feature_dict
with
torch
.
no_grad
():
batch
=
{
k
:
torch
.
as_tensor
(
v
,
device
=
args
.
device
)
k
:
torch
.
as_tensor
(
v
,
device
=
args
.
model_
device
)
for
k
,
v
in
batch
.
items
()
}
longs
=
[
"aatype"
,
"template_aatype"
,
"extra_msa"
,
"residx_atom37_to_atom14"
,
"residx_atom14_to_atom37"
,
"true_msa"
,
"residue_index"
,
]
for
l
in
longs
:
batch
[
l
]
=
batch
[
l
].
long
()
# Move the recycling dimension to the end
move_dim
=
lambda
t
:
t
.
permute
(
*
range
(
len
(
t
.
shape
))[
1
:],
0
)
batch
=
tensor_tree_map
(
move_dim
,
batch
)
make_contig
=
lambda
t
:
t
.
contiguous
()
batch
=
tensor_tree_map
(
make_contig
,
batch
)
t
=
time
.
time
()
out
=
model
(
batch
)
print
(
f
"Inference time:
{
time
.
time
()
-
t
}
"
)
logging
.
info
(
f
"Inference time:
{
time
.
time
()
-
t
}
"
)
# Toss out the recycling dimensions --- we don't need them anymore
batch
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
batch
)
...
...
@@ -158,9 +138,7 @@ def main(args):
result
=
out
,
b_factors
=
plddt_b_factors
)
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"7"
amber_relaxer
=
relax
.
AmberRelaxation
(
**
config
.
relax
)
...
...
@@ -168,7 +146,7 @@ def main(args):
# Relax the prediction.
t
=
time
.
time
()
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
print
(
f
"Relaxation time:
{
time
.
time
()
-
t
}
"
)
logging
.
info
(
f
"Relaxation time:
{
time
.
time
()
-
t
}
"
)
# Save the relaxed PDB.
relaxed_output_path
=
os
.
path
.
join
(
...
...
@@ -183,53 +161,14 @@ if __name__ == "__main__":
parser
.
add_argument
(
"fasta_path"
,
type
=
str
,
)
parser
.
add_argument
(
'uniref90_database_path'
,
type
=
str
,
)
parser
.
add_argument
(
'mgnify_database_path'
,
type
=
str
,
)
parser
.
add_argument
(
'pdb70_database_path'
,
type
=
str
,
)
parser
.
add_argument
(
'template_mmcif_dir'
,
type
=
str
,
)
parser
.
add_argument
(
'--uniclust30_database_path'
,
type
=
str
,
)
parser
.
add_argument
(
'--bfd_database_path'
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
'--small_bfd_database_path'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--jackhmmer_binary_path'
,
type
=
str
,
default
=
'/usr/bin/jackhmmer'
)
parser
.
add_argument
(
'--hhblits_binary_path'
,
type
=
str
,
default
=
'/usr/bin/hhblits'
)
parser
.
add_argument
(
'--hhsearch_binary_path'
,
type
=
str
,
default
=
'/usr/bin/hhsearch'
)
parser
.
add_argument
(
'--kalign_binary_path'
,
type
=
str
,
default
=
'/usr/bin/kalign'
)
parser
.
add_argument
(
'--max_template_date'
,
type
=
str
,
default
=
date
.
today
().
strftime
(
"%Y-%m-%d"
),
)
parser
.
add_argument
(
'--obsolete_pdbs_path'
,
type
=
str
,
default
=
None
)
add_data_args
(
parser
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
os
.
getcwd
(),
help
=
"""Name of the directory in which to output the prediction"""
,
required
=
True
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"cpu"
,
"--
model_
device"
,
type
=
str
,
default
=
"cpu"
,
help
=
"""Name of the device on which to run the model. Any valid torch
device name is accepted (e.g. "cpu", "cuda:0")"""
)
...
...
@@ -244,6 +183,10 @@ if __name__ == "__main__":
automatically according to the model name from
openfold/resources/params"""
)
parser
.
add_argument
(
"--cpus"
,
type
=
int
,
default
=
4
,
help
=
"""Number of CPUs to use to run alignment tools"""
)
parser
.
add_argument
(
'--preset'
,
type
=
str
,
default
=
'full_dbs'
,
choices
=
(
'reduced_dbs'
,
'full_dbs'
)
...
...
scripts/build_deepspeed_config.py
View file @
d48c052c
...
...
@@ -15,6 +15,7 @@
import
argparse
import
json
parser
=
argparse
.
ArgumentParser
(
description
=
'''Outputs a DeepSpeed
configuration file to
stdout'''
)
...
...
scripts/precompute_alignments.py
0 → 100644
View file @
d48c052c
import
argparse
import
logging
import
os
import
tempfile
import
openfold.features.mmcif_parsing
as
mmcif_parsing
from
openfold.features.data_pipeline
import
AlignmentRunner
from
scripts.utils
import
add_data_args
def main(args):
    """Precompute MSAs/template hits for every mmCIF/FASTA file in a directory.

    For each ``*.cif`` file, one alignment directory is produced per chain
    (named ``<file_id>_<chain_id>``); for each ``*.fasta`` file, a single
    directory named after the file. Files already processed (existing output
    directory) are skipped. Other file types are ignored.
    """
    # Build the alignment tool runner shared by every input sequence.
    alignment_runner = AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        hhsearch_binary_path=args.hhsearch_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        small_bfd_database_path=args.small_bfd_database_path,
        pdb70_database_path=args.pdb70_database_path,
        # No full BFD configured implies the reduced ("small") BFD pipeline.
        use_small_bfd=args.bfd_database_path is None,
        no_cpus=args.cpus,
    )

    for f in os.listdir(args.input_dir):
        path = os.path.join(args.input_dir, f)
        is_mmcif = f.endswith('.cif')
        is_fasta = f.endswith('.fasta')
        file_id = os.path.splitext(f)[0]
        seqs = {}  # maps output-directory name -> amino-acid sequence
        if is_mmcif:
            with open(path, 'r') as fp:
                mmcif_str = fp.read()
            mmcif = mmcif_parsing.parse(
                file_id=file_id, mmcif_string=mmcif_str
            )
            if mmcif.mmcif_object is None:
                logging.warning(f'Failed to parse {f}...')
                if args.raise_errors:
                    raise list(mmcif.errors.values())[0]
                else:
                    continue
            mmcif = mmcif.mmcif_object
            # One alignment job per chain in the mmCIF file.
            for k, v in mmcif.chain_to_seqres.items():
                chain_id = '_'.join([file_id, k])
                seqs[chain_id] = v
        elif is_fasta:
            with open(path, 'r') as fp:
                fasta_str = fp.read()
            # BUG FIX: `parsers` was referenced without being imported,
            # raising NameError on any FASTA input. Imported locally here to
            # keep this fix self-contained.
            from openfold.features import parsers  # TODO(review): confirm module path
            input_seqs, _ = parsers.parse_fasta(fasta_str)
            if len(input_seqs) != 1:
                msg = f'More than one input_sequence found in {f}'
                if args.raise_errors:
                    raise ValueError(msg)
                else:
                    logging.warning(msg)
            # Best-effort mode: proceed with the first sequence only.
            input_sequence = input_seqs[0]
            seqs[file_id] = input_sequence
        else:
            continue

        for name, seq in seqs.items():
            alignment_dir = os.path.join(args.output_dir, name)
            if os.path.isdir(alignment_dir):
                logging.info(f'{f} has already been processed. Skipping...')
                continue
            os.makedirs(alignment_dir)

            # mmCIF chains have no FASTA on disk; write the sequence to a
            # temporary single-record FASTA for the alignment tools.
            if not is_fasta:
                fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
                with os.fdopen(fd, 'w') as fp:
                    fp.write(f'>query\n{seq}')

            # BUG FIX: the original passed the bare filename `f` for FASTA
            # inputs; the runner needs the full path inside args.input_dir.
            alignment_runner.run(path if is_fasta else fasta_path, alignment_dir)

            if not is_fasta:
                os.remove(fasta_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_dir", type=str,
        help="Path to directory containing mmCIF and/or FASTA files"
    )
    parser.add_argument(
        "output_dir", type=str,
        help="Directory in which to output alignments"
    )
    # Shared database/binary-path arguments used by the data scripts.
    add_data_args(parser)
    # BUG FIX: the original used `type=bool`, which parses ANY non-empty
    # string (including "--raise_errors False") as True. A store_true flag
    # gives the intended off-by-default toggle semantics.
    parser.add_argument(
        "--raise_errors", action="store_true", default=False,
        help="Whether to crash on parsing errors"
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="Number of CPUs to use"
    )

    args = parser.parse_args()

    main(args)
scripts/utils.py
0 → 100644
View file @
d48c052c
import
argparse
from
datetime
import
date
def add_data_args(parser: argparse.ArgumentParser):
    """Register the shared data-pipeline CLI arguments on *parser*.

    Adds, in order: the required (positional) genetic database / template
    directory paths, the optional database paths, the alignment-tool binary
    locations, and the template-search settings used by the data scripts.
    """
    # Required database/directory paths (positional — order is part of the
    # CLI contract and must not change).
    for required in (
        'uniref90_database_path',
        'mgnify_database_path',
        'pdb70_database_path',
        'template_mmcif_dir',
        'uniclust30_database_path',
    ):
        parser.add_argument(required, type=str)

    # Optional databases; absent (None) unless supplied.
    parser.add_argument('--bfd_database_path', type=str, default=None)
    parser.add_argument('--small_bfd_database_path', type=str, default=None)

    # External alignment-tool binaries, defaulting to /usr/bin installs.
    for tool in ('jackhmmer', 'hhblits', 'hhsearch', 'kalign'):
        parser.add_argument(
            f'--{tool}_binary_path', type=str, default=f'/usr/bin/{tool}'
        )

    # Template search settings.
    parser.add_argument(
        '--max_template_date',
        type=str,
        default=date.today().strftime("%Y-%m-%d"),
    )
    parser.add_argument('--max_template_hits', type=int, default=20)
    parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
tests/compare_utils.py
View file @
d48c052c
import
os
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"4,"
import
importlib
import
pkgutil
import
sys
...
...
tests/test_feats.py
View file @
d48c052c
...
...
@@ -16,6 +16,7 @@ import torch
import
numpy
as
np
import
unittest
import
openfold.features.data_transforms
as
data_transforms
from
openfold.np.residue_constants
import
(
restype_rigid_group_default_frame
,
restype_atom14_to_rigid_group
,
...
...
@@ -168,7 +169,7 @@ class TestFeats(unittest.TestCase):
to_tensor
=
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
)).
cuda
()
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
out_repro
=
feat
s
.
atom37_to_frames
(
eps
=
1e-8
,
**
batch
)
out_repro
=
data_transform
s
.
atom37_to_frames
(
batch
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
for
k
,
v
in
out_gt
.
items
():
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment