Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
07e64267
Commit
07e64267
authored
Oct 16, 2021
by
Gustaf Ahdritz
Browse files
Standardize code style
parent
de07730f
Changes
60
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2923 additions
and
2530 deletions
+2923
-2530
openfold/config.py
openfold/config.py
+374
-360
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+144
-135
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+442
-374
openfold/data/feature_pipeline.py
openfold/data/feature_pipeline.py
+31
-30
openfold/data/input_pipeline.py
openfold/data/input_pipeline.py
+59
-49
openfold/data/mmcif_parsing.py
openfold/data/mmcif_parsing.py
+360
-320
openfold/data/parsers.py
openfold/data/parsers.py
+81
-57
openfold/data/templates.py
openfold/data/templates.py
+356
-216
openfold/data/tools/hhblits.py
openfold/data/tools/hhblits.py
+146
-125
openfold/data/tools/hhsearch.py
openfold/data/tools/hhsearch.py
+72
-59
openfold/data/tools/jackhmmer.py
openfold/data/tools/jackhmmer.py
+200
-169
openfold/data/tools/kalign.py
openfold/data/tools/kalign.py
+87
-76
openfold/data/tools/utils.py
openfold/data/tools/utils.py
+12
-12
openfold/model/__init__.py
openfold/model/__init__.py
+6
-5
openfold/model/dropout.py
openfold/model/dropout.py
+27
-24
openfold/model/embedders.py
openfold/model/embedders.py
+134
-133
openfold/model/evoformer.py
openfold/model/evoformer.py
+125
-114
openfold/model/heads.py
openfold/model/heads.py
+66
-61
openfold/model/model.py
openfold/model/model.py
+117
-115
openfold/model/msa.py
openfold/model/msa.py
+84
-96
No files found.
openfold/config.py
View file @
07e64267
...
@@ -4,50 +4,50 @@ import ml_collections as mlc
...
@@ -4,50 +4,50 @@ import ml_collections as mlc
def
set_inf
(
c
,
inf
):
def
set_inf
(
c
,
inf
):
for
k
,
v
in
c
.
items
():
for
k
,
v
in
c
.
items
():
if
(
isinstance
(
v
,
mlc
.
ConfigDict
)
)
:
if
isinstance
(
v
,
mlc
.
ConfigDict
):
set_inf
(
v
,
inf
)
set_inf
(
v
,
inf
)
elif
(
k
==
'
inf
'
)
:
elif
k
==
"
inf
"
:
c
[
k
]
=
inf
c
[
k
]
=
inf
def
model_config
(
name
,
train
=
False
,
low_prec
=
False
):
def
model_config
(
name
,
train
=
False
,
low_prec
=
False
):
c
=
copy
.
deepcopy
(
config
)
c
=
copy
.
deepcopy
(
config
)
if
(
name
==
'
model_1
'
)
:
if
name
==
"
model_1
"
:
pass
pass
elif
(
name
==
'
model_2
'
)
:
elif
name
==
"
model_2
"
:
pass
pass
elif
(
name
==
'
model_3
'
)
:
elif
name
==
"
model_3
"
:
c
.
model
.
template
.
enabled
=
False
c
.
model
.
template
.
enabled
=
False
elif
(
name
==
'
model_4
'
)
:
elif
name
==
"
model_4
"
:
c
.
model
.
template
.
enabled
=
False
c
.
model
.
template
.
enabled
=
False
elif
(
name
==
'
model_5
'
)
:
elif
name
==
"
model_5
"
:
c
.
model
.
template
.
enabled
=
False
c
.
model
.
template
.
enabled
=
False
elif
(
name
==
'
model_1_ptm
'
)
:
elif
name
==
"
model_1_ptm
"
:
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
elif
(
name
==
'
model_2_ptm
'
)
:
elif
name
==
"
model_2_ptm
"
:
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
elif
(
name
==
'
model_3_ptm
'
)
:
elif
name
==
"
model_3_ptm
"
:
c
.
model
.
template
.
enabled
=
False
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
elif
(
name
==
'
model_4_ptm
'
)
:
elif
name
==
"
model_4_ptm
"
:
c
.
model
.
template
.
enabled
=
False
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
elif
(
name
==
'
model_5_ptm
'
)
:
elif
name
==
"
model_5_ptm
"
:
c
.
model
.
template
.
enabled
=
False
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
else
:
else
:
raise
ValueError
(
'
Invalid model name
'
)
raise
ValueError
(
"
Invalid model name
"
)
if
(
train
)
:
if
train
:
c
.
globals
.
blocks_per_ckpt
=
1
c
.
globals
.
blocks_per_ckpt
=
1
c
.
globals
.
chunk_size
=
None
c
.
globals
.
chunk_size
=
None
if
(
low_prec
)
:
if
low_prec
:
c
.
globals
.
eps
=
1e-4
c
.
globals
.
eps
=
1e-4
# If we want exact numerical parity with the original, inf can't be
# If we want exact numerical parity with the original, inf can't be
# a global constant
# a global constant
...
@@ -69,370 +69,384 @@ num_recycle = mlc.FieldReference(3, field_type=int)
...
@@ -69,370 +69,384 @@ num_recycle = mlc.FieldReference(3, field_type=int)
templates_enabled
=
mlc
.
FieldReference
(
True
,
field_type
=
bool
)
templates_enabled
=
mlc
.
FieldReference
(
True
,
field_type
=
bool
)
embed_template_torsion_angles
=
mlc
.
FieldReference
(
True
,
field_type
=
bool
)
embed_template_torsion_angles
=
mlc
.
FieldReference
(
True
,
field_type
=
bool
)
NUM_RES
=
'
num residues placeholder
'
NUM_RES
=
"
num residues placeholder
"
NUM_MSA_SEQ
=
'
msa placeholder
'
NUM_MSA_SEQ
=
"
msa placeholder
"
NUM_EXTRA_SEQ
=
'
extra msa placeholder
'
NUM_EXTRA_SEQ
=
"
extra msa placeholder
"
NUM_TEMPLATES
=
'
num templates placeholder
'
NUM_TEMPLATES
=
"
num templates placeholder
"
config
=
mlc
.
ConfigDict
({
config
=
mlc
.
ConfigDict
(
'data'
:
{
{
'common'
:
{
"data"
:
{
'batch_modes'
:
[(
'clamped'
,
0.9
),
(
'unclamped'
,
0.1
)],
"common"
:
{
'feat'
:
{
"batch_modes"
:
[(
"clamped"
,
0.9
),
(
"unclamped"
,
0.1
)],
'aatype'
:
[
NUM_RES
],
"feat"
:
{
'all_atom_mask'
:
[
NUM_RES
,
None
],
"aatype"
:
[
NUM_RES
],
'all_atom_positions'
:
[
NUM_RES
,
None
,
None
],
"all_atom_mask"
:
[
NUM_RES
,
None
],
'alt_chi_angles'
:
[
NUM_RES
,
None
],
"all_atom_positions"
:
[
NUM_RES
,
None
,
None
],
'atom14_alt_gt_exists'
:
[
NUM_RES
,
None
],
"alt_chi_angles"
:
[
NUM_RES
,
None
],
'atom14_alt_gt_positions'
:
[
NUM_RES
,
None
,
None
],
"atom14_alt_gt_exists"
:
[
NUM_RES
,
None
],
'atom14_atom_exists'
:
[
NUM_RES
,
None
],
"atom14_alt_gt_positions"
:
[
NUM_RES
,
None
,
None
],
'atom14_atom_is_ambiguous'
:
[
NUM_RES
,
None
],
"atom14_atom_exists"
:
[
NUM_RES
,
None
],
'atom14_gt_exists'
:
[
NUM_RES
,
None
],
"atom14_atom_is_ambiguous"
:
[
NUM_RES
,
None
],
'atom14_gt_positions'
:
[
NUM_RES
,
None
,
None
],
"atom14_gt_exists"
:
[
NUM_RES
,
None
],
'atom37_atom_exists'
:
[
NUM_RES
,
None
],
"atom14_gt_positions"
:
[
NUM_RES
,
None
,
None
],
'backbone_affine_mask'
:
[
NUM_RES
],
"atom37_atom_exists"
:
[
NUM_RES
,
None
],
'backbone_affine_tensor'
:
[
NUM_RES
,
None
,
None
],
"backbone_affine_mask"
:
[
NUM_RES
],
'bert_mask'
:
[
NUM_MSA_SEQ
,
NUM_RES
],
"backbone_affine_tensor"
:
[
NUM_RES
,
None
,
None
],
'chi_angles'
:
[
NUM_RES
,
None
],
"bert_mask"
:
[
NUM_MSA_SEQ
,
NUM_RES
],
'chi_mask'
:
[
NUM_RES
,
None
],
"chi_angles"
:
[
NUM_RES
,
None
],
'extra_deletion_value'
:
[
NUM_EXTRA_SEQ
,
NUM_RES
],
"chi_mask"
:
[
NUM_RES
,
None
],
'extra_has_deletion'
:
[
NUM_EXTRA_SEQ
,
NUM_RES
],
"extra_deletion_value"
:
[
NUM_EXTRA_SEQ
,
NUM_RES
],
'extra_msa'
:
[
NUM_EXTRA_SEQ
,
NUM_RES
],
"extra_has_deletion"
:
[
NUM_EXTRA_SEQ
,
NUM_RES
],
'extra_msa_mask'
:
[
NUM_EXTRA_SEQ
,
NUM_RES
],
"extra_msa"
:
[
NUM_EXTRA_SEQ
,
NUM_RES
],
'extra_msa_row_mask'
:
[
NUM_EXTRA_SEQ
],
"extra_msa_mask"
:
[
NUM_EXTRA_SEQ
,
NUM_RES
],
'is_distillation'
:
[],
"extra_msa_row_mask"
:
[
NUM_EXTRA_SEQ
],
'msa_feat'
:
[
NUM_MSA_SEQ
,
NUM_RES
,
None
],
"is_distillation"
:
[],
'msa_mask'
:
[
NUM_MSA_SEQ
,
NUM_RES
],
"msa_feat"
:
[
NUM_MSA_SEQ
,
NUM_RES
,
None
],
'msa_row_mask'
:
[
NUM_MSA_SEQ
],
"msa_mask"
:
[
NUM_MSA_SEQ
,
NUM_RES
],
'pseudo_beta'
:
[
NUM_RES
,
None
],
"msa_row_mask"
:
[
NUM_MSA_SEQ
],
'pseudo_beta_mask'
:
[
NUM_RES
],
"pseudo_beta"
:
[
NUM_RES
,
None
],
'residue_index'
:
[
NUM_RES
],
"pseudo_beta_mask"
:
[
NUM_RES
],
'residx_atom14_to_atom37'
:
[
NUM_RES
,
None
],
"residue_index"
:
[
NUM_RES
],
'residx_atom37_to_atom14'
:
[
NUM_RES
,
None
],
"residx_atom14_to_atom37"
:
[
NUM_RES
,
None
],
'resolution'
:
[],
"residx_atom37_to_atom14"
:
[
NUM_RES
,
None
],
'rigidgroups_alt_gt_frames'
:
[
NUM_RES
,
None
,
None
,
None
],
"resolution"
:
[],
'rigidgroups_group_exists'
:
[
NUM_RES
,
None
],
"rigidgroups_alt_gt_frames"
:
[
NUM_RES
,
None
,
None
,
None
],
'rigidgroups_group_is_ambiguous'
:
[
NUM_RES
,
None
],
"rigidgroups_group_exists"
:
[
NUM_RES
,
None
],
'rigidgroups_gt_exists'
:
[
NUM_RES
,
None
],
"rigidgroups_group_is_ambiguous"
:
[
NUM_RES
,
None
],
'rigidgroups_gt_frames'
:
[
NUM_RES
,
None
,
None
,
None
],
"rigidgroups_gt_exists"
:
[
NUM_RES
,
None
],
'seq_length'
:
[],
"rigidgroups_gt_frames"
:
[
NUM_RES
,
None
,
None
,
None
],
'seq_mask'
:
[
NUM_RES
],
"seq_length"
:
[],
'target_feat'
:
[
NUM_RES
,
None
],
"seq_mask"
:
[
NUM_RES
],
'template_aatype'
:
[
NUM_TEMPLATES
,
NUM_RES
],
"target_feat"
:
[
NUM_RES
,
None
],
'template_all_atom_mask'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
"template_aatype"
:
[
NUM_TEMPLATES
,
NUM_RES
],
'template_all_atom_positions'
:
"template_all_atom_mask"
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
"template_all_atom_positions"
:
[
'template_alt_torsion_angles_sin_cos'
:
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
,
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
],
'template_backbone_affine_mask'
:
[
NUM_TEMPLATES
,
NUM_RES
],
"template_alt_torsion_angles_sin_cos"
:
[
'template_backbone_affine_tensor'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
,
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
],
'template_mask'
:
[
NUM_TEMPLATES
],
"template_backbone_affine_mask"
:
[
NUM_TEMPLATES
,
NUM_RES
],
'template_pseudo_beta'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
"template_backbone_affine_tensor"
:
[
'template_pseudo_beta_mask'
:
[
NUM_TEMPLATES
,
NUM_RES
],
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
,
'template_sum_probs'
:
[
NUM_TEMPLATES
,
None
],
],
'template_torsion_angles_mask'
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
"template_mask"
:
[
NUM_TEMPLATES
],
'template_torsion_angles_sin_cos'
:
"template_pseudo_beta"
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
],
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
],
"template_pseudo_beta_mask"
:
[
NUM_TEMPLATES
,
NUM_RES
],
'true_msa'
:
[
NUM_MSA_SEQ
,
NUM_RES
],
"template_sum_probs"
:
[
NUM_TEMPLATES
,
None
],
'use_clamped_fape'
:
[],
"template_torsion_angles_mask"
:
[
},
NUM_TEMPLATES
,
NUM_RES
,
None
,
'masked_msa'
:
{
'profile_prob'
:
0.1
,
'same_prob'
:
0.1
,
'uniform_prob'
:
0.1
},
'max_extra_msa'
:
1024
,
'msa_cluster_features'
:
True
,
'num_recycle'
:
num_recycle
,
'reduce_msa_clusters_by_max_templates'
:
False
,
'resample_msa_in_recycling'
:
True
,
'template_features'
:
[
'template_all_atom_positions'
,
'template_sum_probs'
,
'template_aatype'
,
'template_all_atom_mask'
,
],
],
'unsupervised_features'
:
[
"template_torsion_angles_sin_cos"
:
[
'aatype'
,
'residue_index'
,
'msa'
,
'num_alignments'
,
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
,
'seq_length'
,
'between_segment_residues'
,
'deletion_matrix'
],
],
'use_templates'
:
templates_enabled
,
"true_msa"
:
[
NUM_MSA_SEQ
,
NUM_RES
],
'use_template_torsion_angles'
:
embed_template_torsion_angles
,
"use_clamped_fape"
:
[],
'supervised_features'
:
[
},
'all_atom_mask'
,
'all_atom_positions'
,
'resolution'
,
"masked_msa"
:
{
'use_clamped_fape'
,
"profile_prob"
:
0.1
,
"same_prob"
:
0.1
,
"uniform_prob"
:
0.1
,
},
"max_extra_msa"
:
1024
,
"msa_cluster_features"
:
True
,
"num_recycle"
:
num_recycle
,
"reduce_msa_clusters_by_max_templates"
:
False
,
"resample_msa_in_recycling"
:
True
,
"template_features"
:
[
"template_all_atom_positions"
,
"template_sum_probs"
,
"template_aatype"
,
"template_all_atom_mask"
,
],
"unsupervised_features"
:
[
"aatype"
,
"residue_index"
,
"msa"
,
"num_alignments"
,
"seq_length"
,
"between_segment_residues"
,
"deletion_matrix"
,
],
"use_templates"
:
templates_enabled
,
"use_template_torsion_angles"
:
embed_template_torsion_angles
,
"supervised_features"
:
[
"all_atom_mask"
,
"all_atom_positions"
,
"resolution"
,
"use_clamped_fape"
,
],
],
},
},
'predict'
:
{
"predict"
:
{
'fixed_size'
:
True
,
"fixed_size"
:
True
,
'subsample_templates'
:
False
,
# We want top templates.
"subsample_templates"
:
False
,
# We want top templates.
'masked_msa_replace_fraction'
:
0.15
,
"masked_msa_replace_fraction"
:
0.15
,
'max_msa_clusters'
:
512
,
"max_msa_clusters"
:
512
,
'max_templates'
:
4
,
"max_templates"
:
4
,
'num_ensemble'
:
1
,
"num_ensemble"
:
1
,
'crop'
:
False
,
"crop"
:
False
,
'crop_size'
:
None
,
"crop_size"
:
None
,
'supervised'
:
False
,
"supervised"
:
False
,
},
},
'eval'
:
{
"eval"
:
{
'fixed_size'
:
True
,
"fixed_size"
:
True
,
'subsample_templates'
:
False
,
# We want top templates.
"subsample_templates"
:
False
,
# We want top templates.
'masked_msa_replace_fraction'
:
0.15
,
"masked_msa_replace_fraction"
:
0.15
,
'max_msa_clusters'
:
512
,
"max_msa_clusters"
:
512
,
'max_templates'
:
4
,
"max_templates"
:
4
,
'num_ensemble'
:
1
,
"num_ensemble"
:
1
,
'crop'
:
False
,
"crop"
:
False
,
'crop_size'
:
None
,
"crop_size"
:
None
,
'supervised'
:
True
,
"supervised"
:
True
,
},
},
'train'
:
{
"train"
:
{
'fixed_size'
:
True
,
"fixed_size"
:
True
,
'subsample_templates'
:
True
,
"subsample_templates"
:
True
,
'masked_msa_replace_fraction'
:
0.15
,
"masked_msa_replace_fraction"
:
0.15
,
'max_msa_clusters'
:
512
,
"max_msa_clusters"
:
512
,
'max_templates'
:
4
,
"max_templates"
:
4
,
'num_ensemble'
:
1
,
"num_ensemble"
:
1
,
'crop'
:
True
,
"crop"
:
True
,
'crop_size'
:
256
,
"crop_size"
:
256
,
'supervised'
:
True
,
"supervised"
:
True
,
},
},
'data_module'
:
{
"data_module"
:
{
'use_small_bfd'
:
False
,
"use_small_bfd"
:
False
,
'data_loaders'
:
{
"data_loaders"
:
{
'batch_size'
:
1
,
"batch_size"
:
1
,
'num_workers'
:
1
,
"num_workers"
:
1
,
},
},
},
}
},
},
# Recurring FieldReferences that can be changed globally here
# Recurring FieldReferences that can be changed globally here
'
globals
'
:
{
"
globals
"
:
{
'
blocks_per_ckpt
'
:
blocks_per_ckpt
,
"
blocks_per_ckpt
"
:
blocks_per_ckpt
,
'
chunk_size
'
:
chunk_size
,
"
chunk_size
"
:
chunk_size
,
'
c_z
'
:
c_z
,
"
c_z
"
:
c_z
,
'
c_m
'
:
c_m
,
"
c_m
"
:
c_m
,
'
c_t
'
:
c_t
,
"
c_t
"
:
c_t
,
'
c_e
'
:
c_e
,
"
c_e
"
:
c_e
,
'
c_s
'
:
c_s
,
"
c_s
"
:
c_s
,
'
eps
'
:
eps
,
"
eps
"
:
eps
,
},
},
'
model
'
:
{
"
model
"
:
{
'
num_recycle
'
:
num_recycle
,
"
num_recycle
"
:
num_recycle
,
'
_mask_trans
'
:
False
,
"
_mask_trans
"
:
False
,
'
input_embedder
'
:
{
"
input_embedder
"
:
{
'
tf_dim
'
:
22
,
"
tf_dim
"
:
22
,
'
msa_dim
'
:
49
,
"
msa_dim
"
:
49
,
'
c_z
'
:
c_z
,
"
c_z
"
:
c_z
,
'
c_m
'
:
c_m
,
"
c_m
"
:
c_m
,
'
relpos_k
'
:
32
,
"
relpos_k
"
:
32
,
},
},
'
recycling_embedder
'
:
{
"
recycling_embedder
"
:
{
'
c_z
'
:
c_z
,
"
c_z
"
:
c_z
,
'
c_m
'
:
c_m
,
"
c_m
"
:
c_m
,
'
min_bin
'
:
3.25
,
"
min_bin
"
:
3.25
,
'
max_bin
'
:
20.75
,
"
max_bin
"
:
20.75
,
'
no_bins
'
:
15
,
"
no_bins
"
:
15
,
'
inf
'
:
1e8
,
"
inf
"
:
1e8
,
},
},
'
template
'
:
{
"
template
"
:
{
'
distogram
'
:
{
"
distogram
"
:
{
'
min_bin
'
:
3.25
,
"
min_bin
"
:
3.25
,
'
max_bin
'
:
50.75
,
"
max_bin
"
:
50.75
,
'
no_bins
'
:
39
,
"
no_bins
"
:
39
,
},
},
'
template_angle_embedder
'
:
{
"
template_angle_embedder
"
:
{
# DISCREPANCY: c_in is supposed to be 51.
# DISCREPANCY: c_in is supposed to be 51.
'
c_in
'
:
57
,
"
c_in
"
:
57
,
'
c_out
'
:
c_m
,
"
c_out
"
:
c_m
,
},
},
'
template_pair_embedder
'
:
{
"
template_pair_embedder
"
:
{
'
c_in
'
:
88
,
"
c_in
"
:
88
,
'
c_out
'
:
c_t
,
"
c_out
"
:
c_t
,
},
},
'
template_pair_stack
'
:
{
"
template_pair_stack
"
:
{
'
c_t
'
:
c_t
,
"
c_t
"
:
c_t
,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
# as 64. In the code, it's 16.
'
c_hidden_tri_att
'
:
16
,
"
c_hidden_tri_att
"
:
16
,
'
c_hidden_tri_mul
'
:
64
,
"
c_hidden_tri_mul
"
:
64
,
'
no_blocks
'
:
2
,
"
no_blocks
"
:
2
,
'
no_heads
'
:
4
,
"
no_heads
"
:
4
,
'
pair_transition_n
'
:
2
,
"
pair_transition_n
"
:
2
,
'
dropout_rate
'
:
0.25
,
"
dropout_rate
"
:
0.25
,
'
blocks_per_ckpt
'
:
blocks_per_ckpt
,
"
blocks_per_ckpt
"
:
blocks_per_ckpt
,
'
chunk_size
'
:
chunk_size
,
"
chunk_size
"
:
chunk_size
,
'
inf
'
:
1e5
,
#
1e9,
"
inf
"
:
1e5
,
#
1e9,
},
},
'
template_pointwise_attention
'
:
{
"
template_pointwise_attention
"
:
{
'
c_t
'
:
c_t
,
"
c_t
"
:
c_t
,
'
c_z
'
:
c_z
,
"
c_z
"
:
c_z
,
# DISCREPANCY: c_hidden here is given in the supplement as 64.
# DISCREPANCY: c_hidden here is given in the supplement as 64.
# It's actually 16.
# It's actually 16.
'c_hidden'
:
16
,
"c_hidden"
:
16
,
'no_heads'
:
4
,
"no_heads"
:
4
,
'chunk_size'
:
chunk_size
,
"chunk_size"
:
chunk_size
,
'inf'
:
1e5
,
#1e9,
"inf"
:
1e5
,
# 1e9,
},
},
'inf'
:
1e5
,
#1e9,
"inf"
:
1e5
,
# 1e9,
'eps'
:
eps
,
#1e-6,
"eps"
:
eps
,
# 1e-6,
'enabled'
:
templates_enabled
,
"enabled"
:
templates_enabled
,
'embed_angles'
:
embed_template_torsion_angles
,
"embed_angles"
:
embed_template_torsion_angles
,
},
},
'extra_msa'
:
{
"extra_msa"
:
{
'extra_msa_embedder'
:
{
"extra_msa_embedder"
:
{
'c_in'
:
25
,
"c_in"
:
25
,
'c_out'
:
c_e
,
"c_out"
:
c_e
,
},
},
'extra_msa_stack'
:
{
"extra_msa_stack"
:
{
'c_m'
:
c_e
,
"c_m"
:
c_e
,
'c_z'
:
c_z
,
"c_z"
:
c_z
,
'c_hidden_msa_att'
:
8
,
"c_hidden_msa_att"
:
8
,
'c_hidden_opm'
:
32
,
"c_hidden_opm"
:
32
,
'c_hidden_mul'
:
128
,
"c_hidden_mul"
:
128
,
'c_hidden_pair_att'
:
32
,
"c_hidden_pair_att"
:
32
,
'no_heads_msa'
:
8
,
"no_heads_msa"
:
8
,
'no_heads_pair'
:
4
,
"no_heads_pair"
:
4
,
'no_blocks'
:
4
,
"no_blocks"
:
4
,
'transition_n'
:
4
,
"transition_n"
:
4
,
'msa_dropout'
:
0.15
,
"msa_dropout"
:
0.15
,
'pair_dropout'
:
0.25
,
"pair_dropout"
:
0.25
,
'blocks_per_ckpt'
:
blocks_per_ckpt
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
'chunk_size'
:
chunk_size
,
"chunk_size"
:
chunk_size
,
'inf'
:
1e5
,
#1e9,
"inf"
:
1e5
,
# 1e9,
'eps'
:
eps
,
#1e-10,
"eps"
:
eps
,
# 1e-10,
},
},
'enabled'
:
True
,
"enabled"
:
True
,
},
},
'evoformer_stack'
:
{
"evoformer_stack"
:
{
'c_m'
:
c_m
,
"c_m"
:
c_m
,
'c_z'
:
c_z
,
"c_z"
:
c_z
,
'c_hidden_msa_att'
:
32
,
"c_hidden_msa_att"
:
32
,
'c_hidden_opm'
:
32
,
"c_hidden_opm"
:
32
,
'c_hidden_mul'
:
128
,
"c_hidden_mul"
:
128
,
'c_hidden_pair_att'
:
32
,
"c_hidden_pair_att"
:
32
,
'c_s'
:
c_s
,
"c_s"
:
c_s
,
'no_heads_msa'
:
8
,
"no_heads_msa"
:
8
,
'no_heads_pair'
:
4
,
"no_heads_pair"
:
4
,
'no_blocks'
:
48
,
"no_blocks"
:
48
,
'transition_n'
:
4
,
"transition_n"
:
4
,
'msa_dropout'
:
0.15
,
"msa_dropout"
:
0.15
,
'pair_dropout'
:
0.25
,
"pair_dropout"
:
0.25
,
'blocks_per_ckpt'
:
blocks_per_ckpt
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
'chunk_size'
:
chunk_size
,
"chunk_size"
:
chunk_size
,
'inf'
:
1e5
,
#1e9,
"inf"
:
1e5
,
# 1e9,
'eps'
:
eps
,
#1e-10,
"eps"
:
eps
,
# 1e-10,
},
},
'structure_module'
:
{
"structure_module"
:
{
'c_s'
:
c_s
,
"c_s"
:
c_s
,
'c_z'
:
c_z
,
"c_z"
:
c_z
,
'c_ipa'
:
16
,
"c_ipa"
:
16
,
'c_resnet'
:
128
,
"c_resnet"
:
128
,
'no_heads_ipa'
:
12
,
"no_heads_ipa"
:
12
,
'no_qk_points'
:
4
,
"no_qk_points"
:
4
,
'no_v_points'
:
8
,
"no_v_points"
:
8
,
'dropout_rate'
:
0.1
,
"dropout_rate"
:
0.1
,
'no_blocks'
:
8
,
"no_blocks"
:
8
,
'no_transition_layers'
:
1
,
"no_transition_layers"
:
1
,
'no_resnet_blocks'
:
2
,
"no_resnet_blocks"
:
2
,
'no_angles'
:
7
,
"no_angles"
:
7
,
'trans_scale_factor'
:
10
,
"trans_scale_factor"
:
10
,
'epsilon'
:
eps
,
#1e-12,
"epsilon"
:
eps
,
# 1e-12,
'inf'
:
1e5
,
"inf"
:
1e5
,
},
},
'heads'
:
{
"heads"
:
{
'lddt'
:
{
"lddt"
:
{
'no_bins'
:
50
,
"no_bins"
:
50
,
'c_in'
:
c_s
,
"c_in"
:
c_s
,
'c_hidden'
:
128
,
"c_hidden"
:
128
,
},
},
'distogram'
:
{
"distogram"
:
{
'c_z'
:
c_z
,
"c_z"
:
c_z
,
'no_bins'
:
aux_distogram_bins
,
"no_bins"
:
aux_distogram_bins
,
},
},
'tm'
:
{
"tm"
:
{
'c_z'
:
c_z
,
"c_z"
:
c_z
,
'no_bins'
:
aux_distogram_bins
,
"no_bins"
:
aux_distogram_bins
,
'enabled'
:
False
,
"enabled"
:
False
,
},
},
'masked_msa'
:
{
"masked_msa"
:
{
'c_m'
:
c_m
,
"c_m"
:
c_m
,
'c_out'
:
23
,
"c_out"
:
23
,
},
},
'experimentally_resolved'
:
{
"experimentally_resolved"
:
{
'c_s'
:
c_s
,
"c_s"
:
c_s
,
'c_out'
:
37
,
"c_out"
:
37
,
},
},
},
},
},
},
'relax'
:
{
"relax"
:
{
'max_iterations'
:
0
,
# no max
"max_iterations"
:
0
,
# no max
'tolerance'
:
2.39
,
"tolerance"
:
2.39
,
'stiffness'
:
10.0
,
"stiffness"
:
10.0
,
'max_outer_iterations'
:
20
,
"max_outer_iterations"
:
20
,
'exclude_residues'
:
[],
"exclude_residues"
:
[],
},
},
'loss'
:
{
"loss"
:
{
'distogram'
:
{
"distogram"
:
{
'min_bin'
:
2.3125
,
"min_bin"
:
2.3125
,
'max_bin'
:
21.6875
,
"max_bin"
:
21.6875
,
'no_bins'
:
64
,
"no_bins"
:
64
,
'eps'
:
eps
,
#1e-6,
"eps"
:
eps
,
# 1e-6,
'weight'
:
0.3
,
"weight"
:
0.3
,
},
},
'experimentally_resolved'
:
{
"experimentally_resolved"
:
{
'eps'
:
eps
,
#1e-8,
"eps"
:
eps
,
# 1e-8,
'min_resolution'
:
0.1
,
"min_resolution"
:
0.1
,
'max_resolution'
:
3.0
,
"max_resolution"
:
3.0
,
'weight'
:
0.
,
"weight"
:
0.0
,
},
},
'fape'
:
{
"fape"
:
{
'backbone'
:
{
"backbone"
:
{
'clamp_distance'
:
10.
,
"clamp_distance"
:
10.0
,
'loss_unit_distance'
:
10.
,
"loss_unit_distance"
:
10.0
,
'weight'
:
0.5
,
"weight"
:
0.5
,
},
},
'sidechain'
:
{
"sidechain"
:
{
'clamp_distance'
:
10.
,
"clamp_distance"
:
10.0
,
'length_scale'
:
10.
,
"length_scale"
:
10.0
,
'weight'
:
0.5
,
"weight"
:
0.5
,
},
},
'eps'
:
1e-4
,
"eps"
:
1e-4
,
'weight'
:
1.0
,
"weight"
:
1.0
,
},
},
'lddt'
:
{
"lddt"
:
{
'min_resolution'
:
0.1
,
"min_resolution"
:
0.1
,
'max_resolution'
:
3.0
,
"max_resolution"
:
3.0
,
'cutoff'
:
15.
,
"cutoff"
:
15.0
,
'no_bins'
:
50
,
"no_bins"
:
50
,
'eps'
:
eps
,
#1e-10,
"eps"
:
eps
,
# 1e-10,
'weight'
:
0.01
,
"weight"
:
0.01
,
},
},
'masked_msa'
:
{
"masked_msa"
:
{
'eps'
:
eps
,
#1e-8,
"eps"
:
eps
,
# 1e-8,
'weight'
:
2.0
,
"weight"
:
2.0
,
},
},
'supervised_chi'
:
{
"supervised_chi"
:
{
'chi_weight'
:
0.5
,
"chi_weight"
:
0.5
,
'angle_norm_weight'
:
0.01
,
"angle_norm_weight"
:
0.01
,
'eps'
:
eps
,
#1e-6,
"eps"
:
eps
,
# 1e-6,
'weight'
:
1.0
,
"weight"
:
1.0
,
},
},
'violation'
:
{
"violation"
:
{
'violation_tolerance_factor'
:
12.0
,
"violation_tolerance_factor"
:
12.0
,
'clash_overlap_tolerance'
:
1.5
,
"clash_overlap_tolerance"
:
1.5
,
'eps'
:
eps
,
#1e-6,
"eps"
:
eps
,
# 1e-6,
'weight'
:
0.
,
"weight"
:
0.0
,
},
},
'tm'
:
{
"tm"
:
{
'max_bin'
:
31
,
"max_bin"
:
31
,
'no_bins'
:
64
,
"no_bins"
:
64
,
'min_resolution'
:
0.1
,
"min_resolution"
:
0.1
,
'max_resolution'
:
3.0
,
"max_resolution"
:
3.0
,
'eps'
:
eps
,
#1e-8,
"eps"
:
eps
,
# 1e-8,
'weight'
:
0.
,
"weight"
:
0.0
,
},
},
'eps'
:
eps
,
"eps"
:
eps
,
},
},
'ema'
:
{
"ema"
:
{
"decay"
:
0.999
},
'decay'
:
0.999
}
},
)
})
openfold/data/data_pipeline.py
View file @
07e64267
...
@@ -27,45 +27,45 @@ from openfold.np import residue_constants
...
@@ -27,45 +27,45 @@ from openfold.np import residue_constants
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
def
make_sequence_features
(
def
make_sequence_features
(
sequence
:
str
,
sequence
:
str
,
description
:
str
,
num_res
:
int
description
:
str
,
num_res
:
int
)
->
FeatureDict
:
)
->
FeatureDict
:
"""Construct a feature dict of sequence features."""
"""Construct a feature dict of sequence features."""
features
=
{}
features
=
{}
features
[
'
aatype
'
]
=
residue_constants
.
sequence_to_onehot
(
features
[
"
aatype
"
]
=
residue_constants
.
sequence_to_onehot
(
sequence
=
sequence
,
sequence
=
sequence
,
mapping
=
residue_constants
.
restype_order_with_x
,
mapping
=
residue_constants
.
restype_order_with_x
,
map_unknown_to_x
=
True
map_unknown_to_x
=
True
,
)
)
features
[
'
between_segment_residues
'
]
=
np
.
zeros
((
num_res
,),
dtype
=
np
.
int32
)
features
[
"
between_segment_residues
"
]
=
np
.
zeros
((
num_res
,),
dtype
=
np
.
int32
)
features
[
'
domain_name
'
]
=
np
.
array
(
features
[
"
domain_name
"
]
=
np
.
array
(
[
description
.
encode
(
'
utf-8
'
)],
dtype
=
np
.
object_
[
description
.
encode
(
"
utf-8
"
)],
dtype
=
np
.
object_
)
)
features
[
'
residue_index
'
]
=
np
.
array
(
range
(
num_res
),
dtype
=
np
.
int32
)
features
[
"
residue_index
"
]
=
np
.
array
(
range
(
num_res
),
dtype
=
np
.
int32
)
features
[
'
seq_length
'
]
=
np
.
array
([
num_res
]
*
num_res
,
dtype
=
np
.
int32
)
features
[
"
seq_length
"
]
=
np
.
array
([
num_res
]
*
num_res
,
dtype
=
np
.
int32
)
features
[
'
sequence
'
]
=
np
.
array
(
features
[
"
sequence
"
]
=
np
.
array
(
[
sequence
.
encode
(
'
utf-8
'
)],
dtype
=
np
.
object_
[
sequence
.
encode
(
"
utf-8
"
)],
dtype
=
np
.
object_
)
)
return
features
return
features
def
make_mmcif_features
(
def
make_mmcif_features
(
mmcif_object
:
mmcif_parsing
.
MmcifObject
,
mmcif_object
:
mmcif_parsing
.
MmcifObject
,
chain_id
:
str
chain_id
:
str
)
->
FeatureDict
:
)
->
FeatureDict
:
input_sequence
=
mmcif_object
.
chain_to_seqres
[
chain_id
]
input_sequence
=
mmcif_object
.
chain_to_seqres
[
chain_id
]
description
=
'_'
.
join
([
mmcif_object
.
file_id
,
chain_id
])
description
=
"_"
.
join
([
mmcif_object
.
file_id
,
chain_id
])
num_res
=
len
(
input_sequence
)
num_res
=
len
(
input_sequence
)
mmcif_feats
=
{}
mmcif_feats
=
{}
mmcif_feats
.
update
(
make_sequence_features
(
mmcif_feats
.
update
(
make_sequence_features
(
sequence
=
input_sequence
,
sequence
=
input_sequence
,
description
=
description
,
description
=
description
,
num_res
=
num_res
,
num_res
=
num_res
,
))
)
)
all_atom_positions
,
all_atom_mask
=
mmcif_parsing
.
get_atom_coords
(
all_atom_positions
,
all_atom_mask
=
mmcif_parsing
.
get_atom_coords
(
mmcif_object
=
mmcif_object
,
chain_id
=
chain_id
mmcif_object
=
mmcif_object
,
chain_id
=
chain_id
...
@@ -78,7 +78,7 @@ def make_mmcif_features(
...
@@ -78,7 +78,7 @@ def make_mmcif_features(
)
)
mmcif_feats
[
"release_date"
]
=
np
.
array
(
mmcif_feats
[
"release_date"
]
=
np
.
array
(
[
mmcif_object
.
header
[
"release_date"
].
encode
(
'
utf-8
'
)],
dtype
=
np
.
object_
[
mmcif_object
.
header
[
"release_date"
].
encode
(
"
utf-8
"
)],
dtype
=
np
.
object_
)
)
return
mmcif_feats
return
mmcif_feats
...
@@ -86,17 +86,20 @@ def make_mmcif_features(
...
@@ -86,17 +86,20 @@ def make_mmcif_features(
def
make_msa_features
(
def
make_msa_features
(
msas
:
Sequence
[
Sequence
[
str
]],
msas
:
Sequence
[
Sequence
[
str
]],
deletion_matrices
:
Sequence
[
parsers
.
DeletionMatrix
])
->
FeatureDict
:
deletion_matrices
:
Sequence
[
parsers
.
DeletionMatrix
],
)
->
FeatureDict
:
"""Constructs a feature dict of MSA features."""
"""Constructs a feature dict of MSA features."""
if
not
msas
:
if
not
msas
:
raise
ValueError
(
'
At least one MSA must be provided.
'
)
raise
ValueError
(
"
At least one MSA must be provided.
"
)
int_msa
=
[]
int_msa
=
[]
deletion_matrix
=
[]
deletion_matrix
=
[]
seen_sequences
=
set
()
seen_sequences
=
set
()
for
msa_index
,
msa
in
enumerate
(
msas
):
for
msa_index
,
msa
in
enumerate
(
msas
):
if
not
msa
:
if
not
msa
:
raise
ValueError
(
f
'MSA
{
msa_index
}
must contain at least one sequence.'
)
raise
ValueError
(
f
"MSA
{
msa_index
}
must contain at least one sequence."
)
for
sequence_index
,
sequence
in
enumerate
(
msa
):
for
sequence_index
,
sequence
in
enumerate
(
msa
):
if
sequence
in
seen_sequences
:
if
sequence
in
seen_sequences
:
continue
continue
...
@@ -109,17 +112,19 @@ def make_msa_features(
...
@@ -109,17 +112,19 @@ def make_msa_features(
num_res
=
len
(
msas
[
0
][
0
])
num_res
=
len
(
msas
[
0
][
0
])
num_alignments
=
len
(
int_msa
)
num_alignments
=
len
(
int_msa
)
features
=
{}
features
=
{}
features
[
'
deletion_matrix_int
'
]
=
np
.
array
(
deletion_matrix
,
dtype
=
np
.
int32
)
features
[
"
deletion_matrix_int
"
]
=
np
.
array
(
deletion_matrix
,
dtype
=
np
.
int32
)
features
[
'
msa
'
]
=
np
.
array
(
int_msa
,
dtype
=
np
.
int32
)
features
[
"
msa
"
]
=
np
.
array
(
int_msa
,
dtype
=
np
.
int32
)
features
[
'
num_alignments
'
]
=
np
.
array
(
features
[
"
num_alignments
"
]
=
np
.
array
(
[
num_alignments
]
*
num_res
,
dtype
=
np
.
int32
[
num_alignments
]
*
num_res
,
dtype
=
np
.
int32
)
)
return
features
return
features
class
AlignmentRunner
:
class
AlignmentRunner
:
""" Runs alignment tools and saves the results """
"""Runs alignment tools and saves the results"""
def
__init__
(
self
,
def
__init__
(
self
,
jackhmmer_binary_path
:
str
,
jackhmmer_binary_path
:
str
,
hhblits_binary_path
:
str
,
hhblits_binary_path
:
str
,
hhsearch_binary_path
:
str
,
hhsearch_binary_path
:
str
,
...
@@ -161,105 +166,109 @@ class AlignmentRunner:
...
@@ -161,105 +166,109 @@ class AlignmentRunner:
)
)
self
.
hhsearch_pdb70_runner
=
hhsearch
.
HHSearch
(
self
.
hhsearch_pdb70_runner
=
hhsearch
.
HHSearch
(
binary_path
=
hhsearch_binary_path
,
binary_path
=
hhsearch_binary_path
,
databases
=
[
pdb70_database_path
]
databases
=
[
pdb70_database_path
]
)
)
self
.
uniref_max_hits
=
uniref_max_hits
self
.
uniref_max_hits
=
uniref_max_hits
self
.
mgnify_max_hits
=
mgnify_max_hits
self
.
mgnify_max_hits
=
mgnify_max_hits
def
run
(
self
,
def
run
(
self
,
fasta_path
:
str
,
fasta_path
:
str
,
output_dir
:
str
,
output_dir
:
str
,
):
):
"""Runs alignment tools on a sequence"""
"""Runs alignment tools on a sequence"""
jackhmmer_uniref90_result
=
self
.
jackhmmer_uniref90_runner
.
query
(
fasta_path
)[
0
]
jackhmmer_uniref90_result
=
self
.
jackhmmer_uniref90_runner
.
query
(
fasta_path
)[
0
]
uniref90_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
uniref90_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
jackhmmer_uniref90_result
[
'
sto
'
],
max_sequences
=
self
.
uniref_max_hits
jackhmmer_uniref90_result
[
"
sto
"
],
max_sequences
=
self
.
uniref_max_hits
)
)
uniref90_out_path
=
os
.
path
.
join
(
output_dir
,
'
uniref90_hits.a3m
'
)
uniref90_out_path
=
os
.
path
.
join
(
output_dir
,
"
uniref90_hits.a3m
"
)
with
open
(
uniref90_out_path
,
'w'
)
as
f
:
with
open
(
uniref90_out_path
,
"w"
)
as
f
:
f
.
write
(
uniref90_msa_as_a3m
)
f
.
write
(
uniref90_msa_as_a3m
)
jackhmmer_mgnify_result
=
self
.
jackhmmer_mgnify_runner
.
query
(
fasta_path
)[
0
]
jackhmmer_mgnify_result
=
self
.
jackhmmer_mgnify_runner
.
query
(
fasta_path
)[
0
]
mgnify_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
mgnify_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
jackhmmer_mgnify_result
[
'
sto
'
],
max_sequences
=
self
.
mgnify_max_hits
jackhmmer_mgnify_result
[
"
sto
"
],
max_sequences
=
self
.
mgnify_max_hits
)
)
mgnify_out_path
=
os
.
path
.
join
(
output_dir
,
'
mgnify_hits.a3m
'
)
mgnify_out_path
=
os
.
path
.
join
(
output_dir
,
"
mgnify_hits.a3m
"
)
with
open
(
mgnify_out_path
,
'w'
)
as
f
:
with
open
(
mgnify_out_path
,
"w"
)
as
f
:
f
.
write
(
mgnify_msa_as_a3m
)
f
.
write
(
mgnify_msa_as_a3m
)
hhsearch_result
=
self
.
hhsearch_pdb70_runner
.
query
(
uniref90_msa_as_a3m
)
hhsearch_result
=
self
.
hhsearch_pdb70_runner
.
query
(
uniref90_msa_as_a3m
)
pdb70_out_path
=
os
.
path
.
join
(
output_dir
,
'
pdb70_hits.hhr
'
)
pdb70_out_path
=
os
.
path
.
join
(
output_dir
,
"
pdb70_hits.hhr
"
)
with
open
(
pdb70_out_path
,
'w'
)
as
f
:
with
open
(
pdb70_out_path
,
"w"
)
as
f
:
f
.
write
(
hhsearch_result
)
f
.
write
(
hhsearch_result
)
if
self
.
_use_small_bfd
:
if
self
.
_use_small_bfd
:
jackhmmer_small_bfd_result
=
self
.
jackhmmer_small_bfd_runner
.
query
(
fasta_path
)[
0
]
jackhmmer_small_bfd_result
=
self
.
jackhmmer_small_bfd_runner
.
query
(
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
'small_bfd_hits.sto'
)
fasta_path
with
open
(
bfd_out_path
,
'w'
)
as
f
:
)[
0
]
f
.
write
(
jackhmmer_small_bfd_result
[
'sto'
])
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
"small_bfd_hits.sto"
)
with
open
(
bfd_out_path
,
"w"
)
as
f
:
f
.
write
(
jackhmmer_small_bfd_result
[
"sto"
])
else
:
else
:
hhblits_bfd_uniclust_result
=
self
.
hhblits_bfd_uniclust_runner
.
query
(
fasta_path
)
hhblits_bfd_uniclust_result
=
(
if
(
output_dir
is
not
None
):
self
.
hhblits_bfd_uniclust_runner
.
query
(
fasta_path
)
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
'bfd_uniclust_hits.a3m'
)
)
with
open
(
bfd_out_path
,
'w'
)
as
f
:
if
output_dir
is
not
None
:
f
.
write
(
hhblits_bfd_uniclust_result
[
'a3m'
])
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
"bfd_uniclust_hits.a3m"
)
with
open
(
bfd_out_path
,
"w"
)
as
f
:
f
.
write
(
hhblits_bfd_uniclust_result
[
"a3m"
])
class
DataPipeline
:
class
DataPipeline
:
"""Assembles input features."""
"""Assembles input features."""
def
__init__
(
self
,
def
__init__
(
self
,
template_featurizer
:
templates
.
TemplateHitFeaturizer
,
template_featurizer
:
templates
.
TemplateHitFeaturizer
,
use_small_bfd
:
bool
,
use_small_bfd
:
bool
,
):
):
self
.
template_featurizer
=
template_featurizer
self
.
template_featurizer
=
template_featurizer
self
.
use_small_bfd
=
use_small_bfd
self
.
use_small_bfd
=
use_small_bfd
def
_parse_alignment_output
(
self
,
def
_parse_alignment_output
(
self
,
alignment_dir
:
str
,
alignment_dir
:
str
,
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
uniref90_out_path
=
os
.
path
.
join
(
alignment_dir
,
'uniref90_hits.a3m'
)
uniref90_out_path
=
os
.
path
.
join
(
alignment_dir
,
"uniref90_hits.a3m"
)
with
open
(
uniref90_out_path
,
'r'
)
as
f
:
with
open
(
uniref90_out_path
,
"r"
)
as
f
:
uniref90_msa
,
uniref90_deletion_matrix
=
parsers
.
parse_a3m
(
uniref90_msa
,
uniref90_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
())
f
.
read
()
)
mgnify_out_path
=
os
.
path
.
join
(
alignment_dir
,
'mgnify_hits.a3m'
)
mgnify_out_path
=
os
.
path
.
join
(
alignment_dir
,
"mgnify_hits.a3m"
)
with
open
(
mgnify_out_path
,
'r'
)
as
f
:
with
open
(
mgnify_out_path
,
"r"
)
as
f
:
mgnify_msa
,
mgnify_deletion_matrix
=
parsers
.
parse_a3m
(
mgnify_msa
,
mgnify_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
())
f
.
read
()
)
pdb70_out_path
=
os
.
path
.
join
(
alignment_dir
,
'pdb70_hits.hhr'
)
pdb70_out_path
=
os
.
path
.
join
(
alignment_dir
,
"pdb70_hits.hhr"
)
with
open
(
pdb70_out_path
,
'r'
)
as
f
:
with
open
(
pdb70_out_path
,
"r"
)
as
f
:
hhsearch_hits
=
parsers
.
parse_hhr
(
hhsearch_hits
=
parsers
.
parse_hhr
(
f
.
read
())
f
.
read
()
)
if
(
self
.
use_small_bfd
)
:
if
self
.
use_small_bfd
:
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
'
small_bfd_hits.sto
'
)
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
"
small_bfd_hits.sto
"
)
with
open
(
bfd_out_path
,
'r'
)
as
f
:
with
open
(
bfd_out_path
,
"r"
)
as
f
:
bfd_msa
,
bfd_deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
bfd_msa
,
bfd_deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
f
.
read
()
f
.
read
()
)
)
else
:
else
:
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
'bfd_uniclust_hits.a3m'
)
bfd_out_path
=
os
.
path
.
join
(
alignment_dir
,
"bfd_uniclust_hits.a3m"
)
with
open
(
bfd_out_path
,
'r'
)
as
f
:
with
open
(
bfd_out_path
,
"r"
)
as
f
:
bfd_msa
,
bfd_deletion_matrix
=
parsers
.
parse_a3m
(
bfd_msa
,
bfd_deletion_matrix
=
parsers
.
parse_a3m
(
f
.
read
())
f
.
read
()
)
return
{
return
{
'
uniref90_msa
'
:
uniref90_msa
,
"
uniref90_msa
"
:
uniref90_msa
,
'
uniref90_deletion_matrix
'
:
uniref90_deletion_matrix
,
"
uniref90_deletion_matrix
"
:
uniref90_deletion_matrix
,
'
mgnify_msa
'
:
mgnify_msa
,
"
mgnify_msa
"
:
mgnify_msa
,
'
mgnify_deletion_matrix
'
:
mgnify_deletion_matrix
,
"
mgnify_deletion_matrix
"
:
mgnify_deletion_matrix
,
'
hhsearch_hits
'
:
hhsearch_hits
,
"
hhsearch_hits
"
:
hhsearch_hits
,
'
bfd_msa
'
:
bfd_msa
,
"
bfd_msa
"
:
bfd_msa
,
'
bfd_deletion_matrix
'
:
bfd_deletion_matrix
,
"
bfd_deletion_matrix
"
:
bfd_deletion_matrix
,
}
}
def
process_fasta
(
self
,
def
process_fasta
(
self
,
fasta_path
:
str
,
fasta_path
:
str
,
alignment_dir
:
str
,
alignment_dir
:
str
,
)
->
FeatureDict
:
)
->
FeatureDict
:
...
@@ -269,7 +278,8 @@ class DataPipeline:
...
@@ -269,7 +278,8 @@ class DataPipeline:
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
fasta_str
)
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
fasta_str
)
if
len
(
input_seqs
)
!=
1
:
if
len
(
input_seqs
)
!=
1
:
raise
ValueError
(
raise
ValueError
(
f
'More than one input sequence found in
{
fasta_path
}
.'
)
f
"More than one input sequence found in
{
fasta_path
}
."
)
input_sequence
=
input_seqs
[
0
]
input_sequence
=
input_seqs
[
0
]
input_description
=
input_descs
[
0
]
input_description
=
input_descs
[
0
]
num_res
=
len
(
input_sequence
)
num_res
=
len
(
input_sequence
)
...
@@ -280,30 +290,31 @@ class DataPipeline:
...
@@ -280,30 +290,31 @@ class DataPipeline:
query_sequence
=
input_sequence
,
query_sequence
=
input_sequence
,
query_pdb_code
=
None
,
query_pdb_code
=
None
,
query_release_date
=
None
,
query_release_date
=
None
,
hits
=
alignments
[
'
hhsearch_hits
'
]
hits
=
alignments
[
"
hhsearch_hits
"
],
)
)
sequence_features
=
make_sequence_features
(
sequence_features
=
make_sequence_features
(
sequence
=
input_sequence
,
sequence
=
input_sequence
,
description
=
input_description
,
description
=
input_description
,
num_res
=
num_res
num_res
=
num_res
,
)
)
msa_features
=
make_msa_features
(
msa_features
=
make_msa_features
(
msas
=
(
msas
=
(
alignments
[
'
uniref90_msa
'
],
alignments
[
"
uniref90_msa
"
],
alignments
[
'
bfd_msa
'
],
alignments
[
"
bfd_msa
"
],
alignments
[
'
mgnify_msa
'
]
alignments
[
"
mgnify_msa
"
],
),
),
deletion_matrices
=
(
deletion_matrices
=
(
alignments
[
'
uniref90_deletion_matrix
'
],
alignments
[
"
uniref90_deletion_matrix
"
],
alignments
[
'
bfd_deletion_matrix
'
],
alignments
[
"
bfd_deletion_matrix
"
],
alignments
[
'
mgnify_deletion_matrix
'
]
alignments
[
"
mgnify_deletion_matrix
"
],
)
)
,
)
)
return
{
**
sequence_features
,
**
msa_features
,
**
templates_result
.
data
}
return
{
**
sequence_features
,
**
msa_features
,
**
templates_result
.
data
}
def
process_mmcif
(
self
,
def
process_mmcif
(
self
,
mmcif
:
mmcif_parsing
.
MmcifObject
,
# parsing is expensive, so no path
mmcif
:
mmcif_parsing
.
MmcifObject
,
# parsing is expensive, so no path
alignment_dir
:
str
,
alignment_dir
:
str
,
chain_id
:
Optional
[
str
]
=
None
,
chain_id
:
Optional
[
str
]
=
None
,
...
@@ -314,13 +325,11 @@ class DataPipeline:
...
@@ -314,13 +325,11 @@ class DataPipeline:
If chain_id is None, it is assumed that there is only one chain
If chain_id is None, it is assumed that there is only one chain
in the object. Otherwise, a ValueError is thrown.
in the object. Otherwise, a ValueError is thrown.
"""
"""
if
(
chain_id
is
None
)
:
if
chain_id
is
None
:
chains
=
mmcif
.
structure
.
get_chains
()
chains
=
mmcif
.
structure
.
get_chains
()
chain
=
next
(
chains
,
None
)
chain
=
next
(
chains
,
None
)
if
(
chain
is
None
):
if
chain
is
None
:
raise
ValueError
(
raise
ValueError
(
"No chains in mmCIF file"
)
'No chains in mmCIF file'
)
chain_id
=
chain
.
id
chain_id
=
chain
.
id
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
...
@@ -332,20 +341,20 @@ class DataPipeline:
...
@@ -332,20 +341,20 @@ class DataPipeline:
query_sequence
=
input_sequence
,
query_sequence
=
input_sequence
,
query_pdb_code
=
None
,
query_pdb_code
=
None
,
query_release_date
=
to_date
(
mmcif
.
header
[
"release_date"
]),
query_release_date
=
to_date
(
mmcif
.
header
[
"release_date"
]),
hits
=
alignments
[
'
hhsearch_hits
'
]
hits
=
alignments
[
"
hhsearch_hits
"
],
)
)
msa_features
=
make_msa_features
(
msa_features
=
make_msa_features
(
msas
=
(
msas
=
(
alignments
[
'uniref90_msa'
],
alignments
[
"uniref90_msa"
],
alignments
[
'bfd_msa'
],
alignments
[
"bfd_msa"
],
alignments
[
'mgnify_msa'
]
alignments
[
"mgnify_msa"
],
),
deletion_matrices
=
(
alignments
[
"uniref90_deletion_matrix"
],
alignments
[
"bfd_deletion_matrix"
],
alignments
[
"mgnify_deletion_matrix"
],
),
),
deletion_matrices
=
(
alignments
[
'uniref90_deletion_matrix'
],
alignments
[
'bfd_deletion_matrix'
],
alignments
[
'mgnify_deletion_matrix'
]
)
)
)
return
{
**
mmcif_feats
,
**
templates_result
.
data
,
**
msa_features
}
return
{
**
mmcif_feats
,
**
templates_result
.
data
,
**
msa_features
}
openfold/data/data_transforms.py
View file @
07e64267
...
@@ -23,13 +23,23 @@ import torch
...
@@ -23,13 +23,23 @@ import torch
from
openfold.config
import
NUM_RES
,
NUM_EXTRA_SEQ
,
NUM_TEMPLATES
,
NUM_MSA_SEQ
from
openfold.config
import
NUM_RES
,
NUM_EXTRA_SEQ
,
NUM_TEMPLATES
,
NUM_MSA_SEQ
from
openfold.tools
import
residue_constants
as
rc
from
openfold.tools
import
residue_constants
as
rc
from
openfold.utils.affine_utils
import
T
from
openfold.utils.affine_utils
import
T
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
,
batched_gather
from
openfold.utils.tensor_utils
import
(
tree_map
,
tensor_tree_map
,
batched_gather
,
)
MSA_FEATURE_NAMES
=
[
MSA_FEATURE_NAMES
=
[
'msa'
,
'deletion_matrix'
,
'msa_mask'
,
'msa_row_mask'
,
'bert_mask'
,
'true_msa'
"msa"
,
"deletion_matrix"
,
"msa_mask"
,
"msa_row_mask"
,
"bert_mask"
,
"true_msa"
,
]
]
def
cast_to_64bit_ints
(
protein
):
def
cast_to_64bit_ints
(
protein
):
# We keep all ints as int64
# We keep all ints as int64
for
k
,
v
in
protein
.
items
():
for
k
,
v
in
protein
.
items
():
...
@@ -37,21 +47,27 @@ def cast_to_64bit_ints(protein):
...
@@ -37,21 +47,27 @@ def cast_to_64bit_ints(protein):
protein
[
k
]
=
v
.
type
(
torch
.
int64
)
protein
[
k
]
=
v
.
type
(
torch
.
int64
)
return
protein
return
protein
def
make_one_hot
(
x
,
num_classes
):
def
make_one_hot
(
x
,
num_classes
):
x_one_hot
=
torch
.
zeros
(
*
x
.
shape
,
num_classes
)
x_one_hot
=
torch
.
zeros
(
*
x
.
shape
,
num_classes
)
x_one_hot
.
scatter_
(
-
1
,
x
.
unsqueeze
(
-
1
),
1
)
x_one_hot
.
scatter_
(
-
1
,
x
.
unsqueeze
(
-
1
),
1
)
return
x_one_hot
return
x_one_hot
def
make_seq_mask
(
protein
):
def
make_seq_mask
(
protein
):
protein
[
'seq_mask'
]
=
torch
.
ones
(
protein
[
'aatype'
].
shape
,
dtype
=
torch
.
float32
)
protein
[
"seq_mask"
]
=
torch
.
ones
(
protein
[
"aatype"
].
shape
,
dtype
=
torch
.
float32
)
return
protein
return
protein
def
make_template_mask
(
protein
):
def
make_template_mask
(
protein
):
protein
[
'
template_mask
'
]
=
torch
.
ones
(
protein
[
"
template_mask
"
]
=
torch
.
ones
(
protein
[
'
template_aatype
'
].
shape
[
0
],
dtype
=
torch
.
float32
protein
[
"
template_aatype
"
].
shape
[
0
],
dtype
=
torch
.
float32
)
)
return
protein
return
protein
def
curry1
(
f
):
def
curry1
(
f
):
"""Supply all arguments but the first."""
"""Supply all arguments but the first."""
...
@@ -60,137 +76,167 @@ def curry1(f):
...
@@ -60,137 +76,167 @@ def curry1(f):
return
fc
return
fc
@
curry1
@
curry1
def
add_distillation_flag
(
protein
,
distillation
):
def
add_distillation_flag
(
protein
,
distillation
):
protein
[
'
is_distillation
'
]
=
torch
.
tensor
(
protein
[
"
is_distillation
"
]
=
torch
.
tensor
(
float
(
distillation
),
dtype
=
torch
.
float32
float
(
distillation
),
dtype
=
torch
.
float32
)
)
return
protein
return
protein
def
make_all_atom_aatype
(
protein
):
def
make_all_atom_aatype
(
protein
):
protein
[
'
all_atom_aatype
'
]
=
protein
[
'
aatype
'
]
protein
[
"
all_atom_aatype
"
]
=
protein
[
"
aatype
"
]
return
protein
return
protein
def
fix_templates_aatype
(
protein
):
def
fix_templates_aatype
(
protein
):
# Map one-hot to indices
# Map one-hot to indices
num_templates
=
protein
[
'template_aatype'
].
shape
[
0
]
num_templates
=
protein
[
"template_aatype"
].
shape
[
0
]
protein
[
'template_aatype'
]
=
torch
.
argmax
(
protein
[
'template_aatype'
],
dim
=-
1
)
protein
[
"template_aatype"
]
=
torch
.
argmax
(
protein
[
"template_aatype"
],
dim
=-
1
)
# Map hhsearch-aatype to our aatype.
# Map hhsearch-aatype to our aatype.
new_order_list
=
rc
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order_list
=
rc
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order
=
torch
.
tensor
(
new_order
=
torch
.
tensor
(
new_order_list
,
dtype
=
torch
.
int64
).
expand
(
n
ew_order_list
,
dtype
=
torch
.
int64
n
um_templates
,
-
1
)
.
expand
(
num_templates
,
-
1
)
)
protein
[
'
template_aatype
'
]
=
torch
.
gather
(
protein
[
"
template_aatype
"
]
=
torch
.
gather
(
new_order
,
1
,
index
=
protein
[
'
template_aatype
'
]
new_order
,
1
,
index
=
protein
[
"
template_aatype
"
]
)
)
return
protein
return
protein
def
correct_msa_restypes
(
protein
):
def
correct_msa_restypes
(
protein
):
"""Correct MSA restype to have the same order as rc."""
"""Correct MSA restype to have the same order as rc."""
new_order_list
=
rc
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order_list
=
rc
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order
=
torch
.
tensor
(
new_order
=
torch
.
tensor
(
[
new_order_list
]
*
protein
[
'
msa
'
].
shape
[
1
],
dtype
=
protein
[
'
msa
'
].
dtype
[
new_order_list
]
*
protein
[
"
msa
"
].
shape
[
1
],
dtype
=
protein
[
"
msa
"
].
dtype
).
transpose
(
0
,
1
)
).
transpose
(
0
,
1
)
protein
[
'
msa
'
]
=
torch
.
gather
(
new_order
,
0
,
protein
[
'
msa
'
])
protein
[
"
msa
"
]
=
torch
.
gather
(
new_order
,
0
,
protein
[
"
msa
"
])
perm_matrix
=
np
.
zeros
((
22
,
22
),
dtype
=
np
.
float32
)
perm_matrix
=
np
.
zeros
((
22
,
22
),
dtype
=
np
.
float32
)
perm_matrix
[
range
(
len
(
new_order_list
)),
new_order_list
]
=
1.
perm_matrix
[
range
(
len
(
new_order_list
)),
new_order_list
]
=
1.
0
for
k
in
protein
:
for
k
in
protein
:
if
'
profile
'
in
k
:
if
"
profile
"
in
k
:
num_dim
=
protein
[
k
].
shape
.
as_list
()[
-
1
]
num_dim
=
protein
[
k
].
shape
.
as_list
()[
-
1
]
assert
num_dim
in
[
20
,
21
,
22
],
(
assert
num_dim
in
[
'num_dim for %s out of expected range: %s'
%
(
k
,
num_dim
))
20
,
21
,
22
,
],
"num_dim for %s out of expected range: %s"
%
(
k
,
num_dim
)
protein
[
k
]
=
torch
.
dot
(
protein
[
k
],
perm_matrix
[:
num_dim
,
:
num_dim
])
protein
[
k
]
=
torch
.
dot
(
protein
[
k
],
perm_matrix
[:
num_dim
,
:
num_dim
])
return
protein
return
protein
def
squeeze_features
(
protein
):
def
squeeze_features
(
protein
):
"""Remove singleton and repeated dimensions in protein features."""
"""Remove singleton and repeated dimensions in protein features."""
protein
[
'
aatype
'
]
=
torch
.
argmax
(
protein
[
'
aatype
'
],
dim
=-
1
)
protein
[
"
aatype
"
]
=
torch
.
argmax
(
protein
[
"
aatype
"
],
dim
=-
1
)
for
k
in
[
for
k
in
[
'domain_name'
,
'msa'
,
'num_alignments'
,
'seq_length'
,
'sequence'
,
"domain_name"
,
'superfamily'
,
'deletion_matrix'
,
'resolution'
,
"msa"
,
'between_segment_residues'
,
'residue_index'
,
'template_all_atom_mask'
]:
"num_alignments"
,
"seq_length"
,
"sequence"
,
"superfamily"
,
"deletion_matrix"
,
"resolution"
,
"between_segment_residues"
,
"residue_index"
,
"template_all_atom_mask"
,
]:
if
k
in
protein
:
if
k
in
protein
:
final_dim
=
protein
[
k
].
shape
[
-
1
]
final_dim
=
protein
[
k
].
shape
[
-
1
]
if
isinstance
(
final_dim
,
int
)
and
final_dim
==
1
:
if
isinstance
(
final_dim
,
int
)
and
final_dim
==
1
:
protein
[
k
]
=
torch
.
squeeze
(
protein
[
k
],
dim
=-
1
)
protein
[
k
]
=
torch
.
squeeze
(
protein
[
k
],
dim
=-
1
)
for
k
in
[
'
seq_length
'
,
'
num_alignments
'
]:
for
k
in
[
"
seq_length
"
,
"
num_alignments
"
]:
if
k
in
protein
:
if
k
in
protein
:
protein
[
k
]
=
protein
[
k
][
0
]
protein
[
k
]
=
protein
[
k
][
0
]
return
protein
return
protein
@
curry1
@
curry1
def
randomly_replace_msa_with_unknown
(
protein
,
replace_proportion
):
def
randomly_replace_msa_with_unknown
(
protein
,
replace_proportion
):
"""Replace a portion of the MSA with 'X'."""
"""Replace a portion of the MSA with 'X'."""
msa_mask
=
(
torch
.
rand
(
protein
[
'
msa
'
].
shape
)
<
replace_proportion
)
msa_mask
=
torch
.
rand
(
protein
[
"
msa
"
].
shape
)
<
replace_proportion
x_idx
=
20
x_idx
=
20
gap_idx
=
21
gap_idx
=
21
msa_mask
=
torch
.
logical_and
(
msa_mask
,
protein
[
'msa'
]
!=
gap_idx
)
msa_mask
=
torch
.
logical_and
(
msa_mask
,
protein
[
"msa"
]
!=
gap_idx
)
protein
[
'msa'
]
=
torch
.
where
(
msa_mask
,
torch
.
ones_like
(
protein
[
'msa'
])
*
x_idx
,
protein
[
"msa"
]
=
torch
.
where
(
protein
[
'msa'
])
msa_mask
,
torch
.
ones_like
(
protein
[
"msa"
])
*
x_idx
,
protein
[
"msa"
]
aatype_mask
=
(
torch
.
rand
(
protein
[
'aatype'
].
shape
)
<
replace_proportion
)
)
aatype_mask
=
torch
.
rand
(
protein
[
"aatype"
].
shape
)
<
replace_proportion
protein
[
'aatype'
]
=
torch
.
where
(
protein
[
"aatype"
]
=
torch
.
where
(
aatype_mask
,
torch
.
ones_like
(
protein
[
'aatype'
])
*
x_idx
,
aatype_mask
,
protein
[
'aatype'
]
torch
.
ones_like
(
protein
[
"aatype"
])
*
x_idx
,
protein
[
"aatype"
],
)
)
return
protein
return
protein
@
curry1
@
curry1
def
sample_msa
(
protein
,
max_seq
,
keep_extra
):
def
sample_msa
(
protein
,
max_seq
,
keep_extra
):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
"""
num_seq
=
protein
[
"msa"
].
shape
[
0
]
num_seq
=
protein
[
'msa'
].
shape
[
0
]
shuffled
=
torch
.
randperm
(
num_seq
-
1
)
+
1
shuffled
=
torch
.
randperm
(
num_seq
-
1
)
+
1
index_order
=
torch
.
cat
((
torch
.
tensor
([
0
]),
shuffled
),
dim
=
0
)
index_order
=
torch
.
cat
((
torch
.
tensor
([
0
]),
shuffled
),
dim
=
0
)
num_sel
=
min
(
max_seq
,
num_seq
)
num_sel
=
min
(
max_seq
,
num_seq
)
sel_seq
,
not_sel_seq
=
torch
.
split
(
index_order
,
[
num_sel
,
num_seq
-
num_sel
])
sel_seq
,
not_sel_seq
=
torch
.
split
(
index_order
,
[
num_sel
,
num_seq
-
num_sel
]
)
for
k
in
MSA_FEATURE_NAMES
:
for
k
in
MSA_FEATURE_NAMES
:
if
k
in
protein
:
if
k
in
protein
:
if
keep_extra
:
if
keep_extra
:
protein
[
'extra_'
+
k
]
=
torch
.
index_select
(
protein
[
k
],
0
,
not_sel_seq
)
protein
[
"extra_"
+
k
]
=
torch
.
index_select
(
protein
[
k
],
0
,
not_sel_seq
)
protein
[
k
]
=
torch
.
index_select
(
protein
[
k
],
0
,
sel_seq
)
protein
[
k
]
=
torch
.
index_select
(
protein
[
k
],
0
,
sel_seq
)
return
protein
return
protein
@
curry1
@
curry1
def
crop_extra_msa
(
protein
,
max_extra_msa
):
def
crop_extra_msa
(
protein
,
max_extra_msa
):
num_seq
=
protein
[
'
extra_msa
'
].
shape
[
0
]
num_seq
=
protein
[
"
extra_msa
"
].
shape
[
0
]
num_sel
=
min
(
max_extra_msa
,
num_seq
)
num_sel
=
min
(
max_extra_msa
,
num_seq
)
select_indices
=
torch
.
randperm
(
num_seq
)[:
num_sel
]
select_indices
=
torch
.
randperm
(
num_seq
)[:
num_sel
]
for
k
in
MSA_FEATURE_NAMES
:
for
k
in
MSA_FEATURE_NAMES
:
if
'extra_'
+
k
in
protein
:
if
"extra_"
+
k
in
protein
:
protein
[
'extra_'
+
k
]
=
torch
.
index_select
(
protein
[
'extra_'
+
k
],
0
,
select_indices
)
protein
[
"extra_"
+
k
]
=
torch
.
index_select
(
protein
[
"extra_"
+
k
],
0
,
select_indices
)
return
protein
return
protein
def
delete_extra_msa
(
protein
):
def
delete_extra_msa
(
protein
):
for
k
in
MSA_FEATURE_NAMES
:
for
k
in
MSA_FEATURE_NAMES
:
if
'
extra_
'
+
k
in
protein
:
if
"
extra_
"
+
k
in
protein
:
del
protein
[
'
extra_
'
+
k
]
del
protein
[
"
extra_
"
+
k
]
return
protein
return
protein
# Not used in inference
# Not used in inference
@
curry1
@
curry1
def
block_delete_msa
(
protein
,
config
):
def
block_delete_msa
(
protein
,
config
):
num_seq
=
protein
[
'
msa
'
].
shape
[
0
]
num_seq
=
protein
[
"
msa
"
].
shape
[
0
]
block_num_seq
=
torch
.
floor
(
block_num_seq
=
torch
.
floor
(
torch
.
tensor
(
torch
.
tensor
(
num_seq
,
dtype
=
torch
.
float32
)
num_seq
,
dtype
=
torch
.
float32
*
config
.
msa_fraction_per_block
)
*
config
.
msa_fraction_per_block
).
to
(
torch
.
int32
)
).
to
(
torch
.
int32
)
if
config
.
randomize_num_blocks
:
if
config
.
randomize_num_blocks
:
nb
=
torch
.
distributions
.
uniform
.
Uniform
(
0
,
config
.
num_blocks
+
1
).
sample
()
nb
=
torch
.
distributions
.
uniform
.
Uniform
(
0
,
config
.
num_blocks
+
1
).
sample
()
else
:
else
:
nb
=
config
.
num_blocks
nb
=
config
.
num_blocks
del_block_starts
=
torch
.
distributions
.
Uniform
(
0
,
num_seq
).
sample
(
nb
)
del_block_starts
=
torch
.
distributions
.
Uniform
(
0
,
num_seq
).
sample
(
nb
)
del_blocks
=
del_block_starts
[:,
None
]
+
torch
.
range
(
block_num_seq
)
del_blocks
=
del_block_starts
[:,
None
]
+
torch
.
range
(
block_num_seq
)
del_blocks
=
torch
.
clip
(
del_blocks
,
0
,
num_seq
-
1
)
del_blocks
=
torch
.
clip
(
del_blocks
,
0
,
num_seq
-
1
)
del_indices
=
torch
.
unique
(
torch
.
sort
(
torch
.
reshape
(
del_blocks
,
[
-
1
])))[
0
]
del_indices
=
torch
.
unique
(
torch
.
sort
(
torch
.
reshape
(
del_blocks
,
[
-
1
])))[
0
]
# Make sure we keep the original sequence
# Make sure we keep the original sequence
...
@@ -206,19 +252,19 @@ def block_delete_msa(protein, config):
...
@@ -206,19 +252,19 @@ def block_delete_msa(protein, config):
return
protein
return
protein
@
curry1
@
curry1
def
nearest_neighbor_clusters
(
protein
,
gap_agreement_weight
=
0.
):
def
nearest_neighbor_clusters
(
protein
,
gap_agreement_weight
=
0.0
):
weights
=
torch
.
cat
([
weights
=
torch
.
cat
(
torch
.
ones
(
21
),
[
torch
.
ones
(
21
),
gap_agreement_weight
*
torch
.
ones
(
1
),
torch
.
zeros
(
1
)],
gap_agreement_weight
*
torch
.
ones
(
1
),
0
,
torch
.
zeros
(
1
)
)
],
0
)
# Make agreement score as weighted Hamming distance
# Make agreement score as weighted Hamming distance
msa_one_hot
=
make_one_hot
(
protein
[
'
msa
'
],
23
)
msa_one_hot
=
make_one_hot
(
protein
[
"
msa
"
],
23
)
sample_one_hot
=
(
protein
[
'
msa_mask
'
][:,:,
None
]
*
msa_one_hot
)
sample_one_hot
=
protein
[
"
msa_mask
"
][:,
:,
None
]
*
msa_one_hot
extra_msa_one_hot
=
make_one_hot
(
protein
[
'
extra_msa
'
],
23
)
extra_msa_one_hot
=
make_one_hot
(
protein
[
"
extra_msa
"
],
23
)
extra_one_hot
=
(
protein
[
'
extra_msa_mask
'
][:,:,
None
]
*
extra_msa_one_hot
)
extra_one_hot
=
protein
[
"
extra_msa_mask
"
][:,
:,
None
]
*
extra_msa_one_hot
num_seq
,
num_res
,
_
=
sample_one_hot
.
shape
num_seq
,
num_res
,
_
=
sample_one_hot
.
shape
extra_num_seq
,
_
,
_
=
extra_one_hot
.
shape
extra_num_seq
,
_
,
_
=
extra_one_hot
.
shape
...
@@ -226,17 +272,20 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
...
@@ -226,17 +272,20 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# in an optimized fashion to avoid possible memory or computation blowup.
# in an optimized fashion to avoid possible memory or computation blowup.
agreement
=
torch
.
matmul
(
agreement
=
torch
.
matmul
(
torch
.
reshape
(
extra_one_hot
,
[
extra_num_seq
,
num_res
*
23
]),
torch
.
reshape
(
extra_one_hot
,
[
extra_num_seq
,
num_res
*
23
]),
torch
.
reshape
(
torch
.
reshape
(
sample_one_hot
*
weights
,
[
num_seq
,
num_res
*
23
]
sample_one_hot
*
weights
,
[
num_seq
,
num_res
*
23
]
).
transpose
(
0
,
1
),
).
transpose
(
0
,
1
),
)
)
# Assign each sequence in the extra sequences to the closest MSA sample
# Assign each sequence in the extra sequences to the closest MSA sample
protein
[
'extra_cluster_assignment'
]
=
torch
.
argmax
(
agreement
,
dim
=
1
).
to
(
torch
.
int64
)
protein
[
"extra_cluster_assignment"
]
=
torch
.
argmax
(
agreement
,
dim
=
1
).
to
(
torch
.
int64
)
return
protein
return
protein
def
unsorted_segment_sum
(
data
,
segment_ids
,
num_segments
):
def
unsorted_segment_sum
(
data
,
segment_ids
,
num_segments
):
"""
"""
Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.
Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.
...
@@ -264,123 +313,145 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
...
@@ -264,123 +313,145 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
tensor
=
tensor
.
type
(
data
.
dtype
)
tensor
=
tensor
.
type
(
data
.
dtype
)
return
tensor
return
tensor
@
curry1
@
curry1
def
summarize_clusters
(
protein
):
def
summarize_clusters
(
protein
):
"""Produce profile and deletion_matrix_mean within each cluster."""
"""Produce profile and deletion_matrix_mean within each cluster."""
num_seq
=
protein
[
'msa'
].
shape
[
0
]
num_seq
=
protein
[
"msa"
].
shape
[
0
]
def
csum
(
x
):
def
csum
(
x
):
return
unsorted_segment_sum
(
return
unsorted_segment_sum
(
x
,
protein
[
'
extra_cluster_assignment
'
],
num_seq
x
,
protein
[
"
extra_cluster_assignment
"
],
num_seq
)
)
mask
=
protein
[
'
extra_msa_mask
'
]
mask
=
protein
[
"
extra_msa_mask
"
]
mask_counts
=
1e-6
+
protein
[
'
msa_mask
'
]
+
csum
(
mask
)
# Include center
mask_counts
=
1e-6
+
protein
[
"
msa_mask
"
]
+
csum
(
mask
)
# Include center
msa_sum
=
csum
(
mask
[:,
:,
None
]
*
make_one_hot
(
protein
[
'
extra_msa
'
],
23
))
msa_sum
=
csum
(
mask
[:,
:,
None
]
*
make_one_hot
(
protein
[
"
extra_msa
"
],
23
))
msa_sum
+=
make_one_hot
(
protein
[
'
msa
'
],
23
)
# Original sequence
msa_sum
+=
make_one_hot
(
protein
[
"
msa
"
],
23
)
# Original sequence
protein
[
'
cluster_profile
'
]
=
msa_sum
/
mask_counts
[:,
:,
None
]
protein
[
"
cluster_profile
"
]
=
msa_sum
/
mask_counts
[:,
:,
None
]
del
msa_sum
del
msa_sum
del_sum
=
csum
(
mask
*
protein
[
'
extra_deletion_matrix
'
])
del_sum
=
csum
(
mask
*
protein
[
"
extra_deletion_matrix
"
])
del_sum
+=
protein
[
'
deletion_matrix
'
]
# Original sequence
del_sum
+=
protein
[
"
deletion_matrix
"
]
# Original sequence
protein
[
'
cluster_deletion_mean
'
]
=
del_sum
/
mask_counts
protein
[
"
cluster_deletion_mean
"
]
=
del_sum
/
mask_counts
del
del_sum
del
del_sum
return
protein
return
protein
def
make_msa_mask
(
protein
):
def
make_msa_mask
(
protein
):
"""Mask features are all ones, but will later be zero-padded."""
"""Mask features are all ones, but will later be zero-padded."""
protein
[
'msa_mask'
]
=
torch
.
ones
(
protein
[
'msa'
].
shape
,
dtype
=
torch
.
float32
)
protein
[
"msa_mask"
]
=
torch
.
ones
(
protein
[
"msa"
].
shape
,
dtype
=
torch
.
float32
)
protein
[
'msa_row_mask'
]
=
torch
.
ones
(
protein
[
'msa'
].
shape
[
0
],
dtype
=
torch
.
float32
)
protein
[
"msa_row_mask"
]
=
torch
.
ones
(
protein
[
"msa"
].
shape
[
0
],
dtype
=
torch
.
float32
)
return
protein
return
protein
def
pseudo_beta_fn
(
aatype
,
all_atom_positions
,
all_atom_mask
):
def
pseudo_beta_fn
(
aatype
,
all_atom_positions
,
all_atom_mask
):
"""Create pseudo beta features."""
"""Create pseudo beta features."""
is_gly
=
torch
.
eq
(
aatype
,
rc
.
restype_order
[
'G'
])
is_gly
=
torch
.
eq
(
aatype
,
rc
.
restype_order
[
"G"
])
ca_idx
=
rc
.
atom_order
[
'
CA
'
]
ca_idx
=
rc
.
atom_order
[
"
CA
"
]
cb_idx
=
rc
.
atom_order
[
'
CB
'
]
cb_idx
=
rc
.
atom_order
[
"
CB
"
]
pseudo_beta
=
torch
.
where
(
pseudo_beta
=
torch
.
where
(
torch
.
tile
(
is_gly
[...,
None
],
[
1
]
*
len
(
is_gly
.
shape
)
+
[
3
]),
torch
.
tile
(
is_gly
[...,
None
],
[
1
]
*
len
(
is_gly
.
shape
)
+
[
3
]),
all_atom_positions
[...,
ca_idx
,
:],
all_atom_positions
[...,
ca_idx
,
:],
all_atom_positions
[...,
cb_idx
,
:])
all_atom_positions
[...,
cb_idx
,
:],
)
if
all_atom_mask
is
not
None
:
if
all_atom_mask
is
not
None
:
pseudo_beta_mask
=
torch
.
where
(
pseudo_beta_mask
=
torch
.
where
(
is_gly
,
all_atom_mask
[...,
ca_idx
],
all_atom_mask
[...,
cb_idx
])
is_gly
,
all_atom_mask
[...,
ca_idx
],
all_atom_mask
[...,
cb_idx
]
)
return
pseudo_beta
,
pseudo_beta_mask
return
pseudo_beta
,
pseudo_beta_mask
else
:
else
:
return
pseudo_beta
return
pseudo_beta
@
curry1
@
curry1
def
make_pseudo_beta
(
protein
,
prefix
=
''
):
def
make_pseudo_beta
(
protein
,
prefix
=
""
):
"""Create pseudo-beta (alpha for glycine) position and mask."""
"""Create pseudo-beta (alpha for glycine) position and mask."""
assert
prefix
in
[
''
,
'template_'
]
assert
prefix
in
[
""
,
"template_"
]
protein
[
prefix
+
'pseudo_beta'
],
protein
[
prefix
+
'pseudo_beta_mask'
]
=
(
(
pseudo_beta_fn
(
protein
[
prefix
+
"pseudo_beta"
],
protein
[
'template_aatype'
if
prefix
else
'aatype'
],
protein
[
prefix
+
"pseudo_beta_mask"
],
protein
[
prefix
+
'all_atom_positions'
],
)
=
pseudo_beta_fn
(
protein
[
'template_all_atom_mask'
if
prefix
else
'all_atom_mask'
]))
protein
[
"template_aatype"
if
prefix
else
"aatype"
],
protein
[
prefix
+
"all_atom_positions"
],
protein
[
"template_all_atom_mask"
if
prefix
else
"all_atom_mask"
],
)
return
protein
return
protein
@
curry1
@
curry1
def
add_constant_field
(
protein
,
key
,
value
):
def
add_constant_field
(
protein
,
key
,
value
):
protein
[
key
]
=
torch
.
tensor
(
value
)
protein
[
key
]
=
torch
.
tensor
(
value
)
return
protein
return
protein
def
shaped_categorical
(
probs
,
epsilon
=
1e-10
):
def
shaped_categorical
(
probs
,
epsilon
=
1e-10
):
ds
=
probs
.
shape
ds
=
probs
.
shape
num_classes
=
ds
[
-
1
]
num_classes
=
ds
[
-
1
]
distribution
=
torch
.
distributions
.
categorical
.
Categorical
(
distribution
=
torch
.
distributions
.
categorical
.
Categorical
(
torch
.
reshape
(
probs
+
epsilon
,[
-
1
,
num_classes
])
torch
.
reshape
(
probs
+
epsilon
,
[
-
1
,
num_classes
])
)
)
counts
=
distribution
.
sample
()
counts
=
distribution
.
sample
()
return
torch
.
reshape
(
counts
,
ds
[:
-
1
])
return
torch
.
reshape
(
counts
,
ds
[:
-
1
])
def
make_hhblits_profile
(
protein
):
def
make_hhblits_profile
(
protein
):
"""Compute the HHblits MSA profile if not already present."""
"""Compute the HHblits MSA profile if not already present."""
if
'
hhblits_profile
'
in
protein
:
if
"
hhblits_profile
"
in
protein
:
return
protein
return
protein
# Compute the profile for every residue (over all MSA sequences).
# Compute the profile for every residue (over all MSA sequences).
msa_one_hot
=
make_one_hot
(
protein
[
'
msa
'
],
22
)
msa_one_hot
=
make_one_hot
(
protein
[
"
msa
"
],
22
)
protein
[
'
hhblits_profile
'
]
=
torch
.
mean
(
msa_one_hot
,
dim
=
0
)
protein
[
"
hhblits_profile
"
]
=
torch
.
mean
(
msa_one_hot
,
dim
=
0
)
return
protein
return
protein
@
curry1
@
curry1
def
make_masked_msa
(
protein
,
config
,
replace_fraction
):
def
make_masked_msa
(
protein
,
config
,
replace_fraction
):
"""Create data for BERT on raw MSA."""
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
# Add a random amino acid uniformly.
random_aa
=
torch
.
tensor
([
0.05
]
*
20
+
[
0.
,
0.
],
dtype
=
torch
.
float32
)
random_aa
=
torch
.
tensor
([
0.05
]
*
20
+
[
0.
0
,
0.
0
],
dtype
=
torch
.
float32
)
categorical_probs
=
(
categorical_probs
=
(
config
.
uniform_prob
*
random_aa
+
config
.
uniform_prob
*
random_aa
config
.
profile_prob
*
protein
[
'hhblits_profile'
]
+
+
config
.
profile_prob
*
protein
[
"hhblits_profile"
]
config
.
same_prob
*
make_one_hot
(
protein
[
'msa'
],
22
))
+
config
.
same_prob
*
make_one_hot
(
protein
[
"msa"
],
22
)
)
# Put all remaining probability on [MASK] which is a new column
# Put all remaining probability on [MASK] which is a new column
pad_shapes
=
list
(
reduce
(
add
,
[(
0
,
0
)
for
_
in
range
(
len
(
categorical_probs
.
shape
))]))
pad_shapes
=
list
(
reduce
(
add
,
[(
0
,
0
)
for
_
in
range
(
len
(
categorical_probs
.
shape
))])
)
pad_shapes
[
1
]
=
1
pad_shapes
[
1
]
=
1
mask_prob
=
1.
-
config
.
profile_prob
-
config
.
same_prob
-
config
.
uniform_prob
mask_prob
=
(
assert
mask_prob
>=
0.
1.0
-
config
.
profile_prob
-
config
.
same_prob
-
config
.
uniform_prob
)
assert
mask_prob
>=
0.0
categorical_probs
=
torch
.
nn
.
functional
.
pad
(
categorical_probs
=
torch
.
nn
.
functional
.
pad
(
categorical_probs
,
pad_shapes
,
value
=
mask_prob
categorical_probs
,
pad_shapes
,
value
=
mask_prob
)
)
sh
=
protein
[
'
msa
'
].
shape
sh
=
protein
[
"
msa
"
].
shape
mask_position
=
torch
.
rand
(
sh
)
<
replace_fraction
mask_position
=
torch
.
rand
(
sh
)
<
replace_fraction
bert_msa
=
shaped_categorical
(
categorical_probs
)
bert_msa
=
shaped_categorical
(
categorical_probs
)
bert_msa
=
torch
.
where
(
mask_position
,
bert_msa
,
protein
[
'
msa
'
])
bert_msa
=
torch
.
where
(
mask_position
,
bert_msa
,
protein
[
"
msa
"
])
# Mix real and masked MSA
# Mix real and masked MSA
protein
[
'
bert_mask
'
]
=
mask_position
.
to
(
torch
.
float32
)
protein
[
"
bert_mask
"
]
=
mask_position
.
to
(
torch
.
float32
)
protein
[
'
true_msa
'
]
=
protein
[
'
msa
'
]
protein
[
"
true_msa
"
]
=
protein
[
"
msa
"
]
protein
[
'
msa
'
]
=
bert_msa
protein
[
"
msa
"
]
=
bert_msa
return
protein
return
protein
@
curry1
@
curry1
def
make_fixed_size
(
def
make_fixed_size
(
protein
,
protein
,
...
@@ -388,7 +459,7 @@ def make_fixed_size(
...
@@ -388,7 +459,7 @@ def make_fixed_size(
msa_cluster_size
,
msa_cluster_size
,
extra_msa_size
,
extra_msa_size
,
num_res
=
0
,
num_res
=
0
,
num_templates
=
0
num_templates
=
0
,
):
):
"""Guess at the MSA and sequence dimension to make fixed size."""
"""Guess at the MSA and sequence dimension to make fixed size."""
...
@@ -401,14 +472,12 @@ def make_fixed_size(
...
@@ -401,14 +472,12 @@ def make_fixed_size(
for
k
,
v
in
protein
.
items
():
for
k
,
v
in
protein
.
items
():
# Don't transfer this to the accelerator.
# Don't transfer this to the accelerator.
if
k
==
'
extra_cluster_assignment
'
:
if
k
==
"
extra_cluster_assignment
"
:
continue
continue
shape
=
list
(
v
.
shape
)
shape
=
list
(
v
.
shape
)
schema
=
shape_schema
[
k
]
schema
=
shape_schema
[
k
]
msg
=
"Rank mismatch between shape and shape schema for"
msg
=
"Rank mismatch between shape and shape schema for"
assert
len
(
shape
)
==
len
(
schema
),
(
assert
len
(
shape
)
==
len
(
schema
),
f
"
{
msg
}
{
k
}
:
{
shape
}
vs
{
schema
}
"
f
'
{
msg
}
{
k
}
:
{
shape
}
vs
{
schema
}
'
)
pad_size
=
[
pad_size
=
[
pad_size_map
.
get
(
s2
,
None
)
or
s1
for
(
s1
,
s2
)
in
zip
(
shape
,
schema
)
pad_size_map
.
get
(
s2
,
None
)
or
s1
for
(
s1
,
s2
)
in
zip
(
shape
,
schema
)
]
]
...
@@ -422,24 +491,27 @@ def make_fixed_size(
...
@@ -422,24 +491,27 @@ def make_fixed_size(
return
protein
return
protein
@
curry1
@
curry1
def
make_msa_feat
(
protein
):
def
make_msa_feat
(
protein
):
"""Create and concatenate MSA features."""
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping for
# Whether there is a domain break. Always zero for chains, but keeping for
# compatibility with domain datasets.
# compatibility with domain datasets.
has_break
=
torch
.
clip
(
has_break
=
torch
.
clip
(
protein
[
'
between_segment_residues
'
].
to
(
torch
.
float32
),
0
,
1
protein
[
"
between_segment_residues
"
].
to
(
torch
.
float32
),
0
,
1
)
)
aatype_1hot
=
make_one_hot
(
protein
[
'
aatype
'
],
21
)
aatype_1hot
=
make_one_hot
(
protein
[
"
aatype
"
],
21
)
target_feat
=
[
target_feat
=
[
torch
.
unsqueeze
(
has_break
,
dim
=-
1
),
torch
.
unsqueeze
(
has_break
,
dim
=-
1
),
aatype_1hot
,
# Everyone gets the original sequence.
aatype_1hot
,
# Everyone gets the original sequence.
]
]
msa_1hot
=
make_one_hot
(
protein
[
'msa'
],
23
)
msa_1hot
=
make_one_hot
(
protein
[
"msa"
],
23
)
has_deletion
=
torch
.
clip
(
protein
[
'deletion_matrix'
],
0.
,
1.
)
has_deletion
=
torch
.
clip
(
protein
[
"deletion_matrix"
],
0.0
,
1.0
)
deletion_value
=
torch
.
atan
(
protein
[
'deletion_matrix'
]
/
3.
)
*
(
2.
/
np
.
pi
)
deletion_value
=
torch
.
atan
(
protein
[
"deletion_matrix"
]
/
3.0
)
*
(
2.0
/
np
.
pi
)
msa_feat
=
[
msa_feat
=
[
msa_1hot
,
msa_1hot
,
...
@@ -447,24 +519,27 @@ def make_msa_feat(protein):
...
@@ -447,24 +519,27 @@ def make_msa_feat(protein):
torch
.
unsqueeze
(
deletion_value
,
dim
=-
1
),
torch
.
unsqueeze
(
deletion_value
,
dim
=-
1
),
]
]
if
'cluster_profile'
in
protein
:
if
"cluster_profile"
in
protein
:
deletion_mean_value
=
(
deletion_mean_value
=
torch
.
atan
(
torch
.
atan
(
protein
[
'cluster_deletion_mean'
]
/
3.
)
*
(
2.
/
np
.
pi
)
protein
[
"cluster_deletion_mean"
]
/
3.0
)
)
*
(
2.0
/
np
.
pi
)
msa_feat
.
extend
([
protein
[
'cluster_profile'
],
msa_feat
.
extend
(
[
protein
[
"cluster_profile"
],
torch
.
unsqueeze
(
deletion_mean_value
,
dim
=-
1
),
torch
.
unsqueeze
(
deletion_mean_value
,
dim
=-
1
),
])
]
)
if
'
extra_deletion_matrix
'
in
protein
:
if
"
extra_deletion_matrix
"
in
protein
:
protein
[
'
extra_has_deletion
'
]
=
torch
.
clip
(
protein
[
"
extra_has_deletion
"
]
=
torch
.
clip
(
protein
[
'
extra_deletion_matrix
'
],
0.
,
1.
protein
[
"
extra_deletion_matrix
"
],
0.
0
,
1.
0
)
)
protein
[
'
extra_deletion_value
'
]
=
torch
.
atan
(
protein
[
"
extra_deletion_value
"
]
=
torch
.
atan
(
protein
[
'
extra_deletion_matrix
'
]
/
3.
protein
[
"
extra_deletion_matrix
"
]
/
3.
0
)
*
(
2.
/
np
.
pi
)
)
*
(
2.
0
/
np
.
pi
)
protein
[
'
msa_feat
'
]
=
torch
.
cat
(
msa_feat
,
dim
=-
1
)
protein
[
"
msa_feat
"
]
=
torch
.
cat
(
msa_feat
,
dim
=-
1
)
protein
[
'
target_feat
'
]
=
torch
.
cat
(
target_feat
,
dim
=-
1
)
protein
[
"
target_feat
"
]
=
torch
.
cat
(
target_feat
,
dim
=-
1
)
return
protein
return
protein
...
@@ -476,7 +551,7 @@ def select_feat(protein, feature_list):
...
@@ -476,7 +551,7 @@ def select_feat(protein, feature_list):
@curry1
def crop_templates(protein, max_templates):
    """Truncate every template feature to at most ``max_templates`` entries.

    Only keys prefixed with ``"template_"`` are sliced (along their leading
    axis); all other features pass through untouched. The feature dict is
    modified in place and returned for chaining.
    """
    template_keys = [key for key in protein if key.startswith("template_")]
    for key in template_keys:
        protein[key] = protein[key][:max_templates]
    return protein
...
@@ -488,57 +563,58 @@ def make_atom14_masks(protein):
...
@@ -488,57 +563,58 @@ def make_atom14_masks(protein):
restype_atom14_mask
=
[]
restype_atom14_mask
=
[]
for
rt
in
rc
.
restypes
:
for
rt
in
rc
.
restypes
:
atom_names
=
rc
.
restype_name_to_atom14_names
[
atom_names
=
rc
.
restype_name_to_atom14_names
[
rc
.
restype_1to3
[
rt
]]
rc
.
restype_1to3
[
rt
]
restype_atom14_to_atom37
.
append
(
]
[(
rc
.
atom_order
[
name
]
if
name
else
0
)
for
name
in
atom_names
]
restype_atom14_to_atom37
.
append
([
)
(
rc
.
atom_order
[
name
]
if
name
else
0
)
for
name
in
atom_names
])
atom_name_to_idx14
=
{
name
:
i
for
i
,
name
in
enumerate
(
atom_names
)}
atom_name_to_idx14
=
{
name
:
i
for
i
,
name
in
enumerate
(
atom_names
)}
restype_atom37_to_atom14
.
append
([
restype_atom37_to_atom14
.
append
(
[
(
atom_name_to_idx14
[
name
]
if
name
in
atom_name_to_idx14
else
0
)
(
atom_name_to_idx14
[
name
]
if
name
in
atom_name_to_idx14
else
0
)
for
name
in
rc
.
atom_types
for
name
in
rc
.
atom_types
])
]
)
restype_atom14_mask
.
append
([(
1.
if
name
else
0.
)
for
name
in
atom_names
])
restype_atom14_mask
.
append
(
[(
1.0
if
name
else
0.0
)
for
name
in
atom_names
]
)
# Add dummy mapping for restype 'UNK'
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37
.
append
([
0
]
*
14
)
restype_atom14_to_atom37
.
append
([
0
]
*
14
)
restype_atom37_to_atom14
.
append
([
0
]
*
37
)
restype_atom37_to_atom14
.
append
([
0
]
*
37
)
restype_atom14_mask
.
append
([
0.
]
*
14
)
restype_atom14_mask
.
append
([
0.
0
]
*
14
)
restype_atom14_to_atom37
=
torch
.
tensor
(
restype_atom14_to_atom37
=
torch
.
tensor
(
restype_atom14_to_atom37
,
restype_atom14_to_atom37
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
protein
[
'
aatype
'
].
device
,
device
=
protein
[
"
aatype
"
].
device
,
)
)
restype_atom37_to_atom14
=
torch
.
tensor
(
restype_atom37_to_atom14
=
torch
.
tensor
(
restype_atom37_to_atom14
,
restype_atom37_to_atom14
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
protein
[
'
aatype
'
].
device
,
device
=
protein
[
"
aatype
"
].
device
,
)
)
restype_atom14_mask
=
torch
.
tensor
(
restype_atom14_mask
=
torch
.
tensor
(
restype_atom14_mask
,
restype_atom14_mask
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
protein
[
'
aatype
'
].
device
,
device
=
protein
[
"
aatype
"
].
device
,
)
)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37
=
restype_atom14_to_atom37
[
protein
[
'
aatype
'
]]
residx_atom14_to_atom37
=
restype_atom14_to_atom37
[
protein
[
"
aatype
"
]]
residx_atom14_mask
=
restype_atom14_mask
[
protein
[
'
aatype
'
]]
residx_atom14_mask
=
restype_atom14_mask
[
protein
[
"
aatype
"
]]
protein
[
'
atom14_atom_exists
'
]
=
residx_atom14_mask
protein
[
"
atom14_atom_exists
"
]
=
residx_atom14_mask
protein
[
'
residx_atom14_to_atom37
'
]
=
residx_atom14_to_atom37
.
long
()
protein
[
"
residx_atom14_to_atom37
"
]
=
residx_atom14_to_atom37
.
long
()
# create the gather indices for mapping back
# create the gather indices for mapping back
residx_atom37_to_atom14
=
restype_atom37_to_atom14
[
protein
[
'
aatype
'
]]
residx_atom37_to_atom14
=
restype_atom37_to_atom14
[
protein
[
"
aatype
"
]]
protein
[
'
residx_atom37_to_atom14
'
]
=
residx_atom37_to_atom14
.
long
()
protein
[
"
residx_atom37_to_atom14
"
]
=
residx_atom37_to_atom14
.
long
()
# create the corresponding mask
# create the corresponding mask
restype_atom37_mask
=
torch
.
zeros
(
restype_atom37_mask
=
torch
.
zeros
(
[
21
,
37
],
dtype
=
torch
.
float32
,
device
=
protein
[
'
aatype
'
].
device
[
21
,
37
],
dtype
=
torch
.
float32
,
device
=
protein
[
"
aatype
"
].
device
)
)
for
restype
,
restype_letter
in
enumerate
(
rc
.
restypes
):
for
restype
,
restype_letter
in
enumerate
(
rc
.
restypes
):
restype_name
=
rc
.
restype_1to3
[
restype_letter
]
restype_name
=
rc
.
restype_1to3
[
restype_letter
]
...
@@ -547,8 +623,8 @@ def make_atom14_masks(protein):
...
@@ -547,8 +623,8 @@ def make_atom14_masks(protein):
atom_type
=
rc
.
atom_order
[
atom_name
]
atom_type
=
rc
.
atom_order
[
atom_name
]
restype_atom37_mask
[
restype
,
atom_type
]
=
1
restype_atom37_mask
[
restype
,
atom_type
]
=
1
residx_atom37_mask
=
restype_atom37_mask
[
protein
[
'
aatype
'
]]
residx_atom37_mask
=
restype_atom37_mask
[
protein
[
"
aatype
"
]]
protein
[
'
atom37_atom_exists
'
]
=
residx_atom37_mask
protein
[
"
atom37_atom_exists
"
]
=
residx_atom37_mask
return
protein
return
protein
...
@@ -570,7 +646,7 @@ def make_atom14_positions(protein):
...
@@ -570,7 +646,7 @@ def make_atom14_positions(protein):
protein
[
"all_atom_mask"
],
protein
[
"all_atom_mask"
],
residx_atom14_to_atom37
,
residx_atom14_to_atom37
,
dim
=-
1
,
dim
=-
1
,
no_batch_dims
=
len
(
protein
[
"all_atom_mask"
].
shape
[:
-
1
])
no_batch_dims
=
len
(
protein
[
"all_atom_mask"
].
shape
[:
-
1
])
,
)
)
# Gather the ground truth positions.
# Gather the ground truth positions.
...
@@ -579,7 +655,7 @@ def make_atom14_positions(protein):
...
@@ -579,7 +655,7 @@ def make_atom14_positions(protein):
protein
[
"all_atom_positions"
],
protein
[
"all_atom_positions"
],
residx_atom14_to_atom37
,
residx_atom14_to_atom37
,
dim
=-
2
,
dim
=-
2
,
no_batch_dims
=
len
(
protein
[
"all_atom_positions"
].
shape
[:
-
2
])
no_batch_dims
=
len
(
protein
[
"all_atom_positions"
].
shape
[:
-
2
])
,
)
)
)
)
...
@@ -589,9 +665,7 @@ def make_atom14_positions(protein):
...
@@ -589,9 +665,7 @@ def make_atom14_positions(protein):
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative ground truth coordinates where the naming is swapped
# alternative ground truth coordinates where the naming is swapped
restype_3
=
[
restype_3
=
[
rc
.
restype_1to3
[
res
]
for
res
in
rc
.
restypes
]
rc
.
restype_1to3
[
res
]
for
res
in
rc
.
restypes
]
restype_3
+=
[
"UNK"
]
restype_3
+=
[
"UNK"
]
# Matrices for renaming ambiguous atoms.
# Matrices for renaming ambiguous atoms.
...
@@ -599,21 +673,26 @@ def make_atom14_positions(protein):
...
@@ -599,21 +673,26 @@ def make_atom14_positions(protein):
res
:
torch
.
eye
(
res
:
torch
.
eye
(
14
,
14
,
dtype
=
protein
[
"all_atom_mask"
].
dtype
,
dtype
=
protein
[
"all_atom_mask"
].
dtype
,
device
=
protein
[
"all_atom_mask"
].
device
device
=
protein
[
"all_atom_mask"
].
device
,
)
for
res
in
restype_3
)
for
res
in
restype_3
}
}
for
resname
,
swap
in
rc
.
residue_atom_renaming_swaps
.
items
():
for
resname
,
swap
in
rc
.
residue_atom_renaming_swaps
.
items
():
correspondences
=
torch
.
arange
(
14
,
device
=
protein
[
"all_atom_mask"
].
device
)
correspondences
=
torch
.
arange
(
14
,
device
=
protein
[
"all_atom_mask"
].
device
)
for
source_atom_swap
,
target_atom_swap
in
swap
.
items
():
for
source_atom_swap
,
target_atom_swap
in
swap
.
items
():
source_index
=
rc
.
restype_name_to_atom14_names
[
source_index
=
rc
.
restype_name_to_atom14_names
[
resname
].
index
(
resname
].
index
(
source_atom_swap
)
source_atom_swap
target_index
=
rc
.
restype_name_to_atom14_names
[
)
resname
].
index
(
target_atom_swap
)
target_index
=
rc
.
restype_name_to_atom14_names
[
resname
].
index
(
target_atom_swap
)
correspondences
[
source_index
]
=
target_index
correspondences
[
source_index
]
=
target_index
correspondences
[
target_index
]
=
source_index
correspondences
[
target_index
]
=
source_index
renaming_matrix
=
protein
[
"all_atom_mask"
].
new_zeros
((
14
,
14
))
renaming_matrix
=
protein
[
"all_atom_mask"
].
new_zeros
((
14
,
14
))
for
index
,
correspondence
in
enumerate
(
correspondences
):
for
index
,
correspondence
in
enumerate
(
correspondences
):
renaming_matrix
[
index
,
correspondence
]
=
1.
renaming_matrix
[
index
,
correspondence
]
=
1.
0
all_matrices
[
resname
]
=
renaming_matrix
all_matrices
[
resname
]
=
renaming_matrix
renaming_matrices
=
torch
.
stack
(
renaming_matrices
=
torch
.
stack
(
[
all_matrices
[
restype
]
for
restype
in
restype_3
]
[
all_matrices
[
restype
]
for
restype
in
restype_3
]
...
@@ -625,9 +704,7 @@ def make_atom14_positions(protein):
...
@@ -625,9 +704,7 @@ def make_atom14_positions(protein):
# Apply it to the ground truth positions. shape (num_res, 14, 3).
# Apply it to the ground truth positions. shape (num_res, 14, 3).
alternative_gt_positions
=
torch
.
einsum
(
alternative_gt_positions
=
torch
.
einsum
(
"...rac,...rab->...rbc"
,
"...rac,...rab->...rbc"
,
residx_atom14_gt_positions
,
renaming_transform
residx_atom14_gt_positions
,
renaming_transform
)
)
protein
[
"atom14_alt_gt_positions"
]
=
alternative_gt_positions
protein
[
"atom14_alt_gt_positions"
]
=
alternative_gt_positions
...
@@ -635,9 +712,7 @@ def make_atom14_positions(protein):
...
@@ -635,9 +712,7 @@ def make_atom14_positions(protein):
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position).
# ground truth position).
alternative_gt_mask
=
torch
.
einsum
(
alternative_gt_mask
=
torch
.
einsum
(
"...ra,...rab->...rb"
,
"...ra,...rab->...rb"
,
residx_atom14_gt_mask
,
renaming_transform
residx_atom14_gt_mask
,
renaming_transform
)
)
protein
[
"atom14_alt_gt_exists"
]
=
alternative_gt_mask
protein
[
"atom14_alt_gt_exists"
]
=
alternative_gt_mask
...
@@ -645,19 +720,20 @@ def make_atom14_positions(protein):
...
@@ -645,19 +720,20 @@ def make_atom14_positions(protein):
restype_atom14_is_ambiguous
=
protein
[
"all_atom_mask"
].
new_zeros
((
21
,
14
))
restype_atom14_is_ambiguous
=
protein
[
"all_atom_mask"
].
new_zeros
((
21
,
14
))
for
resname
,
swap
in
rc
.
residue_atom_renaming_swaps
.
items
():
for
resname
,
swap
in
rc
.
residue_atom_renaming_swaps
.
items
():
for
atom_name1
,
atom_name2
in
swap
.
items
():
for
atom_name1
,
atom_name2
in
swap
.
items
():
restype
=
rc
.
restype_order
[
restype
=
rc
.
restype_order
[
rc
.
restype_3to1
[
resname
]]
rc
.
restype_3to1
[
resname
]]
atom_idx1
=
rc
.
restype_name_to_atom14_names
[
resname
].
index
(
atom_idx1
=
rc
.
restype_name_to_atom14_names
[
resname
].
index
(
atom_name1
)
atom_name1
)
atom_idx2
=
rc
.
restype_name_to_atom14_names
[
resname
].
index
(
atom_idx2
=
rc
.
restype_name_to_atom14_names
[
resname
].
index
(
atom_name2
)
atom_name2
)
restype_atom14_is_ambiguous
[
restype
,
atom_idx1
]
=
1
restype_atom14_is_ambiguous
[
restype
,
atom_idx1
]
=
1
restype_atom14_is_ambiguous
[
restype
,
atom_idx2
]
=
1
restype_atom14_is_ambiguous
[
restype
,
atom_idx2
]
=
1
# From this create an ambiguous_mask for the given sequence.
# From this create an ambiguous_mask for the given sequence.
protein
[
"atom14_atom_is_ambiguous"
]
=
(
protein
[
"atom14_atom_is_ambiguous"
]
=
restype_atom14_is_ambiguous
[
restype_atom14_is_ambiguous
[
protein
[
"aatype"
]
]
protein
[
"aatype"
]
)
]
return
protein
return
protein
...
@@ -669,14 +745,14 @@ def atom37_to_frames(protein):
...
@@ -669,14 +745,14 @@ def atom37_to_frames(protein):
batch_dims
=
len
(
aatype
.
shape
[:
-
1
])
batch_dims
=
len
(
aatype
.
shape
[:
-
1
])
restype_rigidgroup_base_atom_names
=
np
.
full
([
21
,
8
,
3
],
''
,
dtype
=
object
)
restype_rigidgroup_base_atom_names
=
np
.
full
([
21
,
8
,
3
],
""
,
dtype
=
object
)
restype_rigidgroup_base_atom_names
[:,
0
,
:]
=
[
'C'
,
'
CA
'
,
'N'
]
restype_rigidgroup_base_atom_names
[:,
0
,
:]
=
[
"C"
,
"
CA
"
,
"N"
]
restype_rigidgroup_base_atom_names
[:,
3
,
:]
=
[
'
CA
'
,
'C'
,
'O'
]
restype_rigidgroup_base_atom_names
[:,
3
,
:]
=
[
"
CA
"
,
"C"
,
"O"
]
for
restype
,
restype_letter
in
enumerate
(
rc
.
restypes
):
for
restype
,
restype_letter
in
enumerate
(
rc
.
restypes
):
resname
=
rc
.
restype_1to3
[
restype_letter
]
resname
=
rc
.
restype_1to3
[
restype_letter
]
for
chi_idx
in
range
(
4
):
for
chi_idx
in
range
(
4
):
if
(
rc
.
chi_angles_mask
[
restype
][
chi_idx
]
)
:
if
rc
.
chi_angles_mask
[
restype
][
chi_idx
]:
names
=
rc
.
chi_angles_atoms
[
resname
][
chi_idx
]
names
=
rc
.
chi_angles_atoms
[
resname
][
chi_idx
]
restype_rigidgroup_base_atom_names
[
restype_rigidgroup_base_atom_names
[
restype
,
chi_idx
+
4
,
:
restype
,
chi_idx
+
4
,
:
...
@@ -687,12 +763,12 @@ def atom37_to_frames(protein):
...
@@ -687,12 +763,12 @@ def atom37_to_frames(protein):
)
)
restype_rigidgroup_mask
[...,
0
]
=
1
restype_rigidgroup_mask
[...,
0
]
=
1
restype_rigidgroup_mask
[...,
3
]
=
1
restype_rigidgroup_mask
[...,
3
]
=
1
restype_rigidgroup_mask
[...,
:
20
,
4
:]
=
(
restype_rigidgroup_mask
[...,
:
20
,
4
:]
=
all_atom_mask
.
new_tensor
(
all_atom_mask
.
new_tensor
(
rc
.
chi_angles_mask
)
rc
.
chi_angles_mask
)
)
lookuptable
=
rc
.
atom_order
.
copy
()
lookuptable
=
rc
.
atom_order
.
copy
()
lookuptable
[
''
]
=
0
lookuptable
[
""
]
=
0
lookup
=
np
.
vectorize
(
lambda
x
:
lookuptable
[
x
])
lookup
=
np
.
vectorize
(
lambda
x
:
lookuptable
[
x
])
restype_rigidgroup_base_atom37_idx
=
lookup
(
restype_rigidgroup_base_atom37_idx
=
lookup
(
restype_rigidgroup_base_atom_names
,
restype_rigidgroup_base_atom_names
,
...
@@ -702,8 +778,7 @@ def atom37_to_frames(protein):
...
@@ -702,8 +778,7 @@ def atom37_to_frames(protein):
)
)
restype_rigidgroup_base_atom37_idx
=
(
restype_rigidgroup_base_atom37_idx
=
(
restype_rigidgroup_base_atom37_idx
.
view
(
restype_rigidgroup_base_atom37_idx
.
view
(
*
((
1
,)
*
batch_dims
),
*
((
1
,)
*
batch_dims
),
*
restype_rigidgroup_base_atom37_idx
.
shape
*
restype_rigidgroup_base_atom37_idx
.
shape
)
)
)
)
...
@@ -739,13 +814,11 @@ def atom37_to_frames(protein):
...
@@ -739,13 +814,11 @@ def atom37_to_frames(protein):
all_atom_mask
,
all_atom_mask
,
residx_rigidgroup_base_atom37_idx
,
residx_rigidgroup_base_atom37_idx
,
dim
=-
1
,
dim
=-
1
,
no_batch_dims
=
len
(
all_atom_mask
.
shape
[:
-
1
])
no_batch_dims
=
len
(
all_atom_mask
.
shape
[:
-
1
])
,
)
)
gt_exists
=
torch
.
min
(
gt_atoms_exist
,
dim
=-
1
)[
0
]
*
group_exists
gt_exists
=
torch
.
min
(
gt_atoms_exist
,
dim
=-
1
)[
0
]
*
group_exists
rots
=
torch
.
eye
(
rots
=
torch
.
eye
(
3
,
dtype
=
all_atom_mask
.
dtype
,
device
=
aatype
.
device
)
3
,
dtype
=
all_atom_mask
.
dtype
,
device
=
aatype
.
device
)
rots
=
torch
.
tile
(
rots
,
(
*
((
1
,)
*
batch_dims
),
8
,
1
,
1
))
rots
=
torch
.
tile
(
rots
,
(
*
((
1
,)
*
batch_dims
),
8
,
1
,
1
))
rots
[...,
0
,
0
,
0
]
=
-
1
rots
[...,
0
,
0
,
0
]
=
-
1
rots
[...,
0
,
2
,
2
]
=
-
1
rots
[...,
0
,
2
,
2
]
=
-
1
...
@@ -764,9 +837,7 @@ def atom37_to_frames(protein):
...
@@ -764,9 +837,7 @@ def atom37_to_frames(protein):
)
)
for
resname
,
_
in
rc
.
residue_atom_renaming_swaps
.
items
():
for
resname
,
_
in
rc
.
residue_atom_renaming_swaps
.
items
():
restype
=
rc
.
restype_order
[
restype
=
rc
.
restype_order
[
rc
.
restype_3to1
[
resname
]]
rc
.
restype_3to1
[
resname
]
]
chi_idx
=
int
(
sum
(
rc
.
chi_angles_mask
[
restype
])
-
1
)
chi_idx
=
int
(
sum
(
rc
.
chi_angles_mask
[
restype
])
-
1
)
restype_rigidgroup_is_ambiguous
[...,
restype
,
chi_idx
+
4
]
=
1
restype_rigidgroup_is_ambiguous
[...,
restype
,
chi_idx
+
4
]
=
1
restype_rigidgroup_rots
[...,
restype
,
chi_idx
+
4
,
1
,
1
]
=
-
1
restype_rigidgroup_rots
[...,
restype
,
chi_idx
+
4
,
1
,
1
]
=
-
1
...
@@ -791,11 +862,11 @@ def atom37_to_frames(protein):
...
@@ -791,11 +862,11 @@ def atom37_to_frames(protein):
gt_frames_tensor
=
gt_frames
.
to_4x4
()
gt_frames_tensor
=
gt_frames
.
to_4x4
()
alt_gt_frames_tensor
=
alt_gt_frames
.
to_4x4
()
alt_gt_frames_tensor
=
alt_gt_frames
.
to_4x4
()
protein
[
'
rigidgroups_gt_frames
'
]
=
gt_frames_tensor
protein
[
"
rigidgroups_gt_frames
"
]
=
gt_frames_tensor
protein
[
'
rigidgroups_gt_exists
'
]
=
gt_exists
protein
[
"
rigidgroups_gt_exists
"
]
=
gt_exists
protein
[
'
rigidgroups_group_exists
'
]
=
group_exists
protein
[
"
rigidgroups_group_exists
"
]
=
group_exists
protein
[
'
rigidgroups_group_is_ambiguous
'
]
=
residx_rigidgroup_is_ambiguous
protein
[
"
rigidgroups_group_is_ambiguous
"
]
=
residx_rigidgroup_is_ambiguous
protein
[
'
rigidgroups_alt_gt_frames
'
]
=
alt_gt_frames_tensor
protein
[
"
rigidgroups_alt_gt_frames
"
]
=
alt_gt_frames_tensor
return
protein
return
protein
...
@@ -815,10 +886,11 @@ def get_chi_atom_indices():
...
@@ -815,10 +886,11 @@ def get_chi_atom_indices():
residue_chi_angles
=
rc
.
chi_angles_atoms
[
residue_name
]
residue_chi_angles
=
rc
.
chi_angles_atoms
[
residue_name
]
atom_indices
=
[]
atom_indices
=
[]
for
chi_angle
in
residue_chi_angles
:
for
chi_angle
in
residue_chi_angles
:
atom_indices
.
append
(
atom_indices
.
append
([
rc
.
atom_order
[
atom
]
for
atom
in
chi_angle
])
[
rc
.
atom_order
[
atom
]
for
atom
in
chi_angle
])
for
_
in
range
(
4
-
len
(
atom_indices
)):
for
_
in
range
(
4
-
len
(
atom_indices
)):
atom_indices
.
append
([
0
,
0
,
0
,
0
])
# For chi angles not defined on the AA.
atom_indices
.
append
(
[
0
,
0
,
0
,
0
]
)
# For chi angles not defined on the AA.
chi_atom_indices
.
append
(
atom_indices
)
chi_atom_indices
.
append
(
atom_indices
)
chi_atom_indices
.
append
([[
0
,
0
,
0
,
0
]]
*
4
)
# For UNKNOWN residue.
chi_atom_indices
.
append
([[
0
,
0
,
0
,
0
]]
*
4
)
# For UNKNOWN residue.
...
@@ -829,7 +901,7 @@ def get_chi_atom_indices():
...
@@ -829,7 +901,7 @@ def get_chi_atom_indices():
@
curry1
@
curry1
def
atom37_to_torsion_angles
(
def
atom37_to_torsion_angles
(
protein
,
protein
,
prefix
=
''
,
prefix
=
""
,
):
):
"""
"""
Convert coordinates to torsion angles.
Convert coordinates to torsion angles.
...
@@ -873,35 +945,27 @@ def atom37_to_torsion_angles(
...
@@ -873,35 +945,27 @@ def atom37_to_torsion_angles(
prev_all_atom_mask
=
torch
.
cat
([
pad
,
all_atom_mask
[...,
:
-
1
,
:]],
dim
=-
2
)
prev_all_atom_mask
=
torch
.
cat
([
pad
,
all_atom_mask
[...,
:
-
1
,
:]],
dim
=-
2
)
pre_omega_atom_pos
=
torch
.
cat
(
pre_omega_atom_pos
=
torch
.
cat
(
[
[
prev_all_atom_positions
[...,
1
:
3
,
:],
all_atom_positions
[...,
:
2
,
:]],
prev_all_atom_positions
[...,
1
:
3
,
:],
dim
=-
2
,
all_atom_positions
[...,
:
2
,
:]
],
dim
=-
2
)
)
phi_atom_pos
=
torch
.
cat
(
phi_atom_pos
=
torch
.
cat
(
[
[
prev_all_atom_positions
[...,
2
:
3
,
:],
all_atom_positions
[...,
:
3
,
:]],
prev_all_atom_positions
[...,
2
:
3
,
:],
dim
=-
2
,
all_atom_positions
[...,
:
3
,
:]
],
dim
=-
2
)
)
psi_atom_pos
=
torch
.
cat
(
psi_atom_pos
=
torch
.
cat
(
[
[
all_atom_positions
[...,
:
3
,
:],
all_atom_positions
[...,
4
:
5
,
:]],
all_atom_positions
[...,
:
3
,
:],
dim
=-
2
,
all_atom_positions
[...,
4
:
5
,
:]
],
dim
=-
2
)
)
pre_omega_mask
=
(
pre_omega_mask
=
torch
.
prod
(
torch
.
prod
(
prev_all_atom_mask
[...,
1
:
3
],
dim
=-
1
)
*
prev_all_atom_mask
[...,
1
:
3
],
dim
=-
1
torch
.
prod
(
all_atom_mask
[...,
:
2
],
dim
=-
1
)
)
*
torch
.
prod
(
all_atom_mask
[...,
:
2
],
dim
=-
1
)
)
phi_mask
=
prev_all_atom_mask
[...,
2
]
*
torch
.
prod
(
phi_mask
=
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
prev_all_atom_mask
[...,
2
]
*
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
)
)
psi_mask
=
(
psi_mask
=
(
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
*
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
all_atom_mask
[...,
4
]
*
all_atom_mask
[...,
4
]
)
)
chi_atom_indices
=
torch
.
as_tensor
(
chi_atom_indices
=
torch
.
as_tensor
(
...
@@ -914,7 +978,7 @@ def atom37_to_torsion_angles(
...
@@ -914,7 +978,7 @@ def atom37_to_torsion_angles(
)
)
chi_angles_mask
=
list
(
rc
.
chi_angles_mask
)
chi_angles_mask
=
list
(
rc
.
chi_angles_mask
)
chi_angles_mask
.
append
([
0.
,
0.
,
0.
,
0.
])
chi_angles_mask
.
append
([
0.
0
,
0.
0
,
0.
0
,
0.
0
])
chi_angles_mask
=
all_atom_mask
.
new_tensor
(
chi_angles_mask
)
chi_angles_mask
=
all_atom_mask
.
new_tensor
(
chi_angles_mask
)
chis_mask
=
chi_angles_mask
[
aatype
,
:]
chis_mask
=
chi_angles_mask
[
aatype
,
:]
...
@@ -923,7 +987,7 @@ def atom37_to_torsion_angles(
...
@@ -923,7 +987,7 @@ def atom37_to_torsion_angles(
all_atom_mask
,
all_atom_mask
,
atom_indices
,
atom_indices
,
dim
=-
1
,
dim
=-
1
,
no_batch_dims
=
len
(
atom_indices
.
shape
[:
-
2
])
no_batch_dims
=
len
(
atom_indices
.
shape
[:
-
2
])
,
)
)
chi_angle_atoms_mask
=
torch
.
prod
(
chi_angle_atoms_mask
=
torch
.
prod
(
chi_angle_atoms_mask
,
dim
=-
1
,
dtype
=
chi_angle_atoms_mask
.
dtype
chi_angle_atoms_mask
,
dim
=-
1
,
dtype
=
chi_angle_atoms_mask
.
dtype
...
@@ -936,7 +1000,8 @@ def atom37_to_torsion_angles(
...
@@ -936,7 +1000,8 @@ def atom37_to_torsion_angles(
phi_atom_pos
[...,
None
,
:,
:],
phi_atom_pos
[...,
None
,
:,
:],
psi_atom_pos
[...,
None
,
:,
:],
psi_atom_pos
[...,
None
,
:,
:],
chis_atom_pos
,
chis_atom_pos
,
],
dim
=-
3
],
dim
=-
3
,
)
)
torsion_angles_mask
=
torch
.
cat
(
torsion_angles_mask
=
torch
.
cat
(
...
@@ -945,7 +1010,8 @@ def atom37_to_torsion_angles(
...
@@ -945,7 +1010,8 @@ def atom37_to_torsion_angles(
phi_mask
[...,
None
],
phi_mask
[...,
None
],
psi_mask
[...,
None
],
psi_mask
[...,
None
],
chis_mask
,
chis_mask
,
],
dim
=-
1
],
dim
=-
1
,
)
)
torsion_frames
=
T
.
from_3_points
(
torsion_frames
=
T
.
from_3_points
(
...
@@ -968,13 +1034,14 @@ def atom37_to_torsion_angles(
...
@@ -968,13 +1034,14 @@ def atom37_to_torsion_angles(
torch
.
square
(
torsion_angles_sin_cos
),
torch
.
square
(
torsion_angles_sin_cos
),
dim
=-
1
,
dim
=-
1
,
dtype
=
torsion_angles_sin_cos
.
dtype
,
dtype
=
torsion_angles_sin_cos
.
dtype
,
keepdims
=
True
keepdims
=
True
,
)
+
1e-8
)
+
1e-8
)
)
torsion_angles_sin_cos
=
torsion_angles_sin_cos
/
denom
torsion_angles_sin_cos
=
torsion_angles_sin_cos
/
denom
torsion_angles_sin_cos
=
torsion_angles_sin_cos
*
all_atom_mask
.
new_tensor
(
torsion_angles_sin_cos
=
torsion_angles_sin_cos
*
all_atom_mask
.
new_tensor
(
[
1.
,
1.
,
-
1.
,
1.
,
1.
,
1.
,
1.
],
[
1.
0
,
1.
0
,
-
1.
0
,
1.
0
,
1.
0
,
1.
0
,
1.
0
],
)[((
None
,)
*
len
(
torsion_angles_sin_cos
.
shape
[:
-
2
]))
+
(
slice
(
None
),
None
)]
)[((
None
,)
*
len
(
torsion_angles_sin_cos
.
shape
[:
-
2
]))
+
(
slice
(
None
),
None
)]
chi_is_ambiguous
=
torsion_angles_sin_cos
.
new_tensor
(
chi_is_ambiguous
=
torsion_angles_sin_cos
.
new_tensor
(
...
@@ -984,8 +1051,9 @@ def atom37_to_torsion_angles(
...
@@ -984,8 +1051,9 @@ def atom37_to_torsion_angles(
mirror_torsion_angles
=
torch
.
cat
(
mirror_torsion_angles
=
torch
.
cat
(
[
[
all_atom_mask
.
new_ones
(
*
aatype
.
shape
,
3
),
all_atom_mask
.
new_ones
(
*
aatype
.
shape
,
3
),
1.
-
2.
*
chi_is_ambiguous
1.0
-
2.0
*
chi_is_ambiguous
,
],
dim
=-
1
],
dim
=-
1
,
)
)
alt_torsion_angles_sin_cos
=
(
alt_torsion_angles_sin_cos
=
(
...
@@ -1001,12 +1069,10 @@ def atom37_to_torsion_angles(
...
@@ -1001,12 +1069,10 @@ def atom37_to_torsion_angles(
def get_backbone_frames(protein):
    """Expose the backbone rigid-group frame and its mask as features.

    The backbone frame occupies index 0 of the rigid-group axis of
    ``rigidgroups_gt_frames`` / ``rigidgroups_gt_exists``; copy that slice
    into ``backbone_affine_tensor`` / ``backbone_affine_mask``. Mutates
    ``protein`` in place and returns it.
    """
    # TODO: Verify that this is correct
    gt_frames = protein["rigidgroups_gt_frames"]
    gt_exists = protein["rigidgroups_gt_exists"]
    protein["backbone_affine_tensor"] = gt_frames[..., 0, :, :]
    protein["backbone_affine_mask"] = gt_exists[..., 0]
    return protein
...
@@ -1029,32 +1095,37 @@ def random_crop_to_size(
...
@@ -1029,32 +1095,37 @@ def random_crop_to_size(
shape_schema
,
shape_schema
,
subsample_templates
=
False
,
subsample_templates
=
False
,
seed
=
None
,
seed
=
None
,
batch_mode
=
'
clamped
'
batch_mode
=
"
clamped
"
,
):
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
seq_length
=
protein
[
'
seq_length
'
]
seq_length
=
protein
[
"
seq_length
"
]
if
'
template_mask
'
in
protein
:
if
"
template_mask
"
in
protein
:
num_templates
=
protein
[
'
template_mask
'
].
shape
[
-
1
]
num_templates
=
protein
[
"
template_mask
"
].
shape
[
-
1
]
else
:
else
:
num_templates
=
protein
[
'
aatype
'
].
new_zeros
((
1
,))
num_templates
=
protein
[
"
aatype
"
].
new_zeros
((
1
,))
num_res_crop_size
=
min
(
seq_length
,
crop_size
)
num_res_crop_size
=
min
(
seq_length
,
crop_size
)
# We want each ensemble to be cropped the same way
# We want each ensemble to be cropped the same way
g
=
torch
.
Generator
(
device
=
protein
[
'
seq_length
'
].
device
)
g
=
torch
.
Generator
(
device
=
protein
[
"
seq_length
"
].
device
)
if
(
seed
is
not
None
)
:
if
seed
is
not
None
:
g
.
manual_seed
(
seed
)
g
.
manual_seed
(
seed
)
def
_randint
(
lower
,
upper
):
def
_randint
(
lower
,
upper
):
return
int
(
torch
.
randint
(
return
int
(
lower
,
upper
,
(
1
,),
torch
.
randint
(
device
=
protein
[
'seq_length'
].
device
,
generator
=
g
lower
,
)[
0
])
upper
,
(
1
,),
device
=
protein
[
"seq_length"
].
device
,
generator
=
g
,
)[
0
]
)
if
subsample_templates
:
if
subsample_templates
:
templates_crop_start
=
_randint
(
0
,
num_templates
+
1
)
templates_crop_start
=
_randint
(
0
,
num_templates
+
1
)
templates_select_indices
=
torch
.
randperm
(
templates_select_indices
=
torch
.
randperm
(
num_templates
,
device
=
protein
[
'
seq_length
'
].
device
,
generator
=
g
num_templates
,
device
=
protein
[
"
seq_length
"
].
device
,
generator
=
g
)
)
num_templates_crop_size
=
min
(
num_templates_crop_size
=
min
(
num_templates
-
templates_crop_start
,
max_templates
num_templates
-
templates_crop_start
,
max_templates
...
@@ -1064,9 +1135,9 @@ def random_crop_to_size(
...
@@ -1064,9 +1135,9 @@ def random_crop_to_size(
num_templates_crop_size
=
num_templates
num_templates_crop_size
=
num_templates
n
=
seq_length
-
num_res_crop_size
n
=
seq_length
-
num_res_crop_size
if
(
batch_mode
==
'
clamped
'
)
:
if
batch_mode
==
"
clamped
"
:
right_anchor
=
n
+
1
right_anchor
=
n
+
1
elif
(
batch_mode
==
'
unclamped
'
)
:
elif
batch_mode
==
"
unclamped
"
:
x
=
_randint
(
0
,
n
)
x
=
_randint
(
0
,
n
)
right_anchor
=
n
-
x
+
1
right_anchor
=
n
-
x
+
1
else
:
else
:
...
@@ -1075,20 +1146,19 @@ def random_crop_to_size(
...
@@ -1075,20 +1146,19 @@ def random_crop_to_size(
num_res_crop_start
=
_randint
(
0
,
right_anchor
)
num_res_crop_start
=
_randint
(
0
,
right_anchor
)
for
k
,
v
in
protein
.
items
():
for
k
,
v
in
protein
.
items
():
if
(
k
not
in
shape_schema
or
if
k
not
in
shape_schema
or
(
(
'
template
'
not
in
k
and
NUM_RES
not
in
shape_schema
[
k
]
)
"
template
"
not
in
k
and
NUM_RES
not
in
shape_schema
[
k
]
):
):
continue
continue
# randomly permute the templates before cropping them.
# randomly permute the templates before cropping them.
if
k
.
startswith
(
'
template
'
)
and
subsample_templates
:
if
k
.
startswith
(
"
template
"
)
and
subsample_templates
:
v
=
v
[
templates_select_indices
]
v
=
v
[
templates_select_indices
]
slices
=
[]
slices
=
[]
for
i
,
(
dim_size
,
dim
)
in
enumerate
(
zip
(
shape_schema
[
k
],
for
i
,
(
dim_size
,
dim
)
in
enumerate
(
zip
(
shape_schema
[
k
],
v
.
shape
)):
v
.
shape
)):
is_num_res
=
dim_size
==
NUM_RES
is_num_res
=
(
dim_size
==
NUM_RES
)
if
i
==
0
and
k
.
startswith
(
"template"
):
if
i
==
0
and
k
.
startswith
(
'template'
):
crop_size
=
num_templates_crop_size
crop_size
=
num_templates_crop_size
crop_start
=
templates_crop_start
crop_start
=
templates_crop_start
else
:
else
:
...
@@ -1097,7 +1167,5 @@ def random_crop_to_size(
...
@@ -1097,7 +1167,5 @@ def random_crop_to_size(
slices
.
append
(
slice
(
crop_start
,
crop_start
+
crop_size
))
slices
.
append
(
slice
(
crop_start
,
crop_start
+
crop_size
))
protein
[
k
]
=
v
[
slices
]
protein
[
k
]
=
v
[
slices
]
protein
[
'seq_length'
]
=
(
protein
[
"seq_length"
]
=
protein
[
"seq_length"
].
new_tensor
(
num_res_crop_size
)
protein
[
'seq_length'
].
new_tensor
(
num_res_crop_size
)
)
return
protein
return
protein
openfold/data/feature_pipeline.py
View file @
07e64267
...
@@ -26,10 +26,11 @@ from openfold.data import input_pipeline
...
@@ -26,10 +26,11 @@ from openfold.data import input_pipeline
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
TensorDict
=
Dict
[
str
,
torch
.
Tensor
]
TensorDict
=
Dict
[
str
,
torch
.
Tensor
]
def
np_to_tensor_dict
(
def
np_to_tensor_dict
(
np_example
:
Mapping
[
str
,
np
.
ndarray
],
np_example
:
Mapping
[
str
,
np
.
ndarray
],
features
:
Sequence
[
str
],
features
:
Sequence
[
str
],
)
->
TensorDict
:
)
->
TensorDict
:
"""Creates dict of tensors from a dict of NumPy arrays.
"""Creates dict of tensors from a dict of NumPy arrays.
Args:
Args:
...
@@ -54,7 +55,7 @@ def make_data_config(
...
@@ -54,7 +55,7 @@ def make_data_config(
cfg
=
copy
.
deepcopy
(
config
)
cfg
=
copy
.
deepcopy
(
config
)
mode_cfg
=
cfg
[
mode
]
mode_cfg
=
cfg
[
mode
]
with
cfg
.
unlocked
():
with
cfg
.
unlocked
():
if
(
mode_cfg
.
crop_size
is
None
)
:
if
mode_cfg
.
crop_size
is
None
:
mode_cfg
.
crop_size
=
num_res
mode_cfg
.
crop_size
=
num_res
feature_names
=
cfg
.
common
.
unsupervised_features
feature_names
=
cfg
.
common
.
unsupervised_features
...
@@ -62,7 +63,7 @@ def make_data_config(
...
@@ -62,7 +63,7 @@ def make_data_config(
if
cfg
.
common
.
use_templates
:
if
cfg
.
common
.
use_templates
:
feature_names
+=
cfg
.
common
.
template_features
feature_names
+=
cfg
.
common
.
template_features
if
(
cfg
[
mode
].
supervised
)
:
if
cfg
[
mode
].
supervised
:
feature_names
+=
cfg
.
common
.
supervised_features
feature_names
+=
cfg
.
common
.
supervised_features
return
cfg
,
feature_names
return
cfg
,
feature_names
...
@@ -75,47 +76,47 @@ def np_example_to_features(
...
@@ -75,47 +76,47 @@ def np_example_to_features(
batch_mode
:
str
,
batch_mode
:
str
,
):
):
np_example
=
dict
(
np_example
)
np_example
=
dict
(
np_example
)
num_res
=
int
(
np_example
[
'seq_length'
][
0
])
num_res
=
int
(
np_example
[
"seq_length"
][
0
])
cfg
,
feature_names
=
make_data_config
(
cfg
,
feature_names
=
make_data_config
(
config
,
mode
=
mode
,
num_res
=
num_res
)
config
,
mode
=
mode
,
num_res
=
num_res
)
if
'
deletion_matrix_int
'
in
np_example
:
if
"
deletion_matrix_int
"
in
np_example
:
np_example
[
'
deletion_matrix
'
]
=
(
np_example
[
"
deletion_matrix
"
]
=
np_example
.
pop
(
np_example
.
pop
(
'
deletion_matrix_int
'
).
astype
(
np
.
float32
)
"
deletion_matrix_int
"
)
)
.
astype
(
np
.
float32
)
if
batch_mode
==
'clamped'
:
if
batch_mode
==
"clamped"
:
np_example
[
'use_clamped_fape'
]
=
(
np_example
[
"use_clamped_fape"
]
=
np
.
array
(
1.0
).
astype
(
np
.
float32
)
np
.
array
(
1.
).
astype
(
np
.
float32
)
elif
batch_mode
==
"unclamped"
:
)
np_example
[
"use_clamped_fape"
]
=
np
.
array
(
0.0
).
astype
(
np
.
float32
)
elif
batch_mode
==
'unclamped'
:
np_example
[
'use_clamped_fape'
]
=
(
np
.
array
(
0.
).
astype
(
np
.
float32
)
)
tensor_dict
=
np_to_tensor_dict
(
tensor_dict
=
np_to_tensor_dict
(
np_example
=
np_example
,
features
=
feature_names
np_example
=
np_example
,
features
=
feature_names
)
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
features
=
input_pipeline
.
process_tensors_from_config
(
features
=
input_pipeline
.
process_tensors_from_config
(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
batch_mode
=
batch_mode
,
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
batch_mode
=
batch_mode
,
)
)
return
{
k
:
v
for
k
,
v
in
features
.
items
()}
return
{
k
:
v
for
k
,
v
in
features
.
items
()}
class
FeaturePipeline
:
class
FeaturePipeline
:
def
__init__
(
self
,
def
__init__
(
self
,
config
:
ml_collections
.
ConfigDict
,
config
:
ml_collections
.
ConfigDict
,
params
:
Optional
[
Mapping
[
str
,
Mapping
[
str
,
np
.
ndarray
]]]
=
None
):
params
:
Optional
[
Mapping
[
str
,
Mapping
[
str
,
np
.
ndarray
]]]
=
None
,
):
self
.
config
=
config
self
.
config
=
config
self
.
params
=
params
self
.
params
=
params
def
process_features
(
self
,
def
process_features
(
self
,
raw_features
:
FeatureDict
,
raw_features
:
FeatureDict
,
mode
:
str
=
'
train
'
,
mode
:
str
=
"
train
"
,
batch_mode
:
str
=
'
clamped
'
,
batch_mode
:
str
=
"
clamped
"
,
)
->
FeatureDict
:
)
->
FeatureDict
:
return
np_example_to_features
(
return
np_example_to_features
(
np_example
=
raw_features
,
np_example
=
raw_features
,
...
...
openfold/data/input_pipeline.py
View file @
07e64267
...
@@ -33,29 +33,37 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
...
@@ -33,29 +33,37 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms
.
make_hhblits_profile
,
data_transforms
.
make_hhblits_profile
,
]
]
if
common_cfg
.
use_templates
:
if
common_cfg
.
use_templates
:
transforms
.
extend
([
transforms
.
extend
(
[
data_transforms
.
fix_templates_aatype
,
data_transforms
.
fix_templates_aatype
,
data_transforms
.
make_template_mask
,
data_transforms
.
make_template_mask
,
data_transforms
.
make_pseudo_beta
(
'template_'
)
data_transforms
.
make_pseudo_beta
(
"template_"
),
])
]
if
(
common_cfg
.
use_template_torsion_angles
):
)
transforms
.
extend
([
if
common_cfg
.
use_template_torsion_angles
:
data_transforms
.
atom37_to_torsion_angles
(
'template_'
),
transforms
.
extend
(
])
[
data_transforms
.
atom37_to_torsion_angles
(
"template_"
),
transforms
.
extend
([
]
)
transforms
.
extend
(
[
data_transforms
.
make_atom14_masks
,
data_transforms
.
make_atom14_masks
,
])
]
)
if
(
mode_cfg
.
supervised
):
if
mode_cfg
.
supervised
:
transforms
.
extend
([
transforms
.
extend
(
[
data_transforms
.
make_atom14_positions
,
data_transforms
.
make_atom14_positions
,
data_transforms
.
atom37_to_frames
,
data_transforms
.
atom37_to_frames
,
data_transforms
.
atom37_to_torsion_angles
(
''
),
data_transforms
.
atom37_to_torsion_angles
(
""
),
data_transforms
.
make_pseudo_beta
(
''
),
data_transforms
.
make_pseudo_beta
(
""
),
data_transforms
.
get_backbone_frames
,
data_transforms
.
get_backbone_frames
,
data_transforms
.
get_chi_angles
,
data_transforms
.
get_chi_angles
,
])
]
)
return
transforms
return
transforms
...
@@ -76,14 +84,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
...
@@ -76,14 +84,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
data_transforms
.
sample_msa
(
max_msa_clusters
,
keep_extra
=
True
)
data_transforms
.
sample_msa
(
max_msa_clusters
,
keep_extra
=
True
)
)
)
if
'
masked_msa
'
in
common_cfg
:
if
"
masked_msa
"
in
common_cfg
:
# Masked MSA should come *before* MSA clustering so that
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
# the masked locations and secret corrupted locations.
transforms
.
append
(
transforms
.
append
(
data_transforms
.
make_masked_msa
(
data_transforms
.
make_masked_msa
(
common_cfg
.
masked_msa
,
common_cfg
.
masked_msa
,
mode_cfg
.
masked_msa_replace_fraction
mode_cfg
.
masked_msa_replace_fraction
)
)
)
)
...
@@ -103,21 +110,25 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
...
@@ -103,21 +110,25 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
if
mode_cfg
.
fixed_size
:
if
mode_cfg
.
fixed_size
:
transforms
.
append
(
data_transforms
.
select_feat
(
list
(
crop_feats
)))
transforms
.
append
(
data_transforms
.
select_feat
(
list
(
crop_feats
)))
transforms
.
append
(
data_transforms
.
random_crop_to_size
(
transforms
.
append
(
data_transforms
.
random_crop_to_size
(
mode_cfg
.
crop_size
,
mode_cfg
.
crop_size
,
mode_cfg
.
max_templates
,
mode_cfg
.
max_templates
,
crop_feats
,
crop_feats
,
mode_cfg
.
subsample_templates
,
mode_cfg
.
subsample_templates
,
batch_mode
=
batch_mode
,
batch_mode
=
batch_mode
,
seed
=
torch
.
Generator
().
seed
()
seed
=
torch
.
Generator
().
seed
(),
))
)
transforms
.
append
(
data_transforms
.
make_fixed_size
(
)
transforms
.
append
(
data_transforms
.
make_fixed_size
(
crop_feats
,
crop_feats
,
pad_msa_clusters
,
pad_msa_clusters
,
common_cfg
.
max_extra_msa
,
common_cfg
.
max_extra_msa
,
mode_cfg
.
crop_size
,
mode_cfg
.
crop_size
,
mode_cfg
.
max_templates
mode_cfg
.
max_templates
,
))
)
)
else
:
else
:
transforms
.
append
(
transforms
.
append
(
data_transforms
.
crop_templates
(
mode_cfg
.
max_templates
)
data_transforms
.
crop_templates
(
mode_cfg
.
max_templates
)
...
@@ -127,7 +138,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
...
@@ -127,7 +138,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
def
process_tensors_from_config
(
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
,
batch_mode
=
'
clamped
'
tensors
,
common_cfg
,
mode_cfg
,
batch_mode
=
"
clamped
"
):
):
"""Based on the config, apply filters and transformations to the data."""
"""Based on the config, apply filters and transformations to the data."""
...
@@ -136,12 +147,10 @@ def process_tensors_from_config(
...
@@ -136,12 +147,10 @@ def process_tensors_from_config(
d
=
data
.
copy
()
d
=
data
.
copy
()
fns
=
ensembled_transform_fns
(
common_cfg
,
mode_cfg
,
batch_mode
)
fns
=
ensembled_transform_fns
(
common_cfg
,
mode_cfg
,
batch_mode
)
fn
=
compose
(
fns
)
fn
=
compose
(
fns
)
d
[
'
ensemble_index
'
]
=
i
d
[
"
ensemble_index
"
]
=
i
return
fn
(
d
)
return
fn
(
d
)
tensors
=
compose
(
tensors
=
compose
(
nonensembled_transform_fns
(
common_cfg
,
mode_cfg
))(
tensors
)
nonensembled_transform_fns
(
common_cfg
,
mode_cfg
)
)(
tensors
)
tensors_0
=
wrap_ensemble_fn
(
tensors
,
0
)
tensors_0
=
wrap_ensemble_fn
(
tensors
,
0
)
num_ensemble
=
mode_cfg
.
num_ensemble
num_ensemble
=
mode_cfg
.
num_ensemble
...
@@ -150,8 +159,9 @@ def process_tensors_from_config(
...
@@ -150,8 +159,9 @@ def process_tensors_from_config(
num_ensemble
*=
common_cfg
.
num_recycle
+
1
num_ensemble
*=
common_cfg
.
num_recycle
+
1
if
isinstance
(
num_ensemble
,
torch
.
Tensor
)
or
num_ensemble
>
1
:
if
isinstance
(
num_ensemble
,
torch
.
Tensor
)
or
num_ensemble
>
1
:
tensors
=
map_fn
(
lambda
x
:
wrap_ensemble_fn
(
tensors
,
x
),
tensors
=
map_fn
(
torch
.
arange
(
num_ensemble
))
lambda
x
:
wrap_ensemble_fn
(
tensors
,
x
),
torch
.
arange
(
num_ensemble
)
)
else
:
else
:
tensors
=
tree
.
map_structure
(
lambda
x
:
x
[
None
],
tensors_0
)
tensors
=
tree
.
map_structure
(
lambda
x
:
x
[
None
],
tensors_0
)
...
...
openfold/data/mmcif_parsing.py
View file @
07e64267
...
@@ -90,6 +90,7 @@ class MmcifObject:
...
@@ -90,6 +90,7 @@ class MmcifObject:
...}}
...}}
raw_string: The raw string used to construct the MmcifObject.
raw_string: The raw string used to construct the MmcifObject.
"""
"""
file_id
:
str
file_id
:
str
header
:
PdbHeader
header
:
PdbHeader
structure
:
PdbStructure
structure
:
PdbStructure
...
@@ -107,6 +108,7 @@ class ParsingResult:
...
@@ -107,6 +108,7 @@ class ParsingResult:
parsed.
parsed.
errors: A dict mapping (file_id, chain_id) to any exception generated.
errors: A dict mapping (file_id, chain_id) to any exception generated.
"""
"""
mmcif_object
:
Optional
[
MmcifObject
]
mmcif_object
:
Optional
[
MmcifObject
]
errors
:
Mapping
[
Tuple
[
str
,
str
],
Any
]
errors
:
Mapping
[
Tuple
[
str
,
str
],
Any
]
...
@@ -115,8 +117,9 @@ class ParseError(Exception):
...
@@ -115,8 +117,9 @@ class ParseError(Exception):
"""An error indicating that an mmCIF file could not be parsed."""
"""An error indicating that an mmCIF file could not be parsed."""
def
mmcif_loop_to_list
(
prefix
:
str
,
def
mmcif_loop_to_list
(
parsed_info
:
MmCIFDict
)
->
Sequence
[
Mapping
[
str
,
str
]]:
prefix
:
str
,
parsed_info
:
MmCIFDict
)
->
Sequence
[
Mapping
[
str
,
str
]]:
"""Extracts loop associated with a prefix from mmCIF data as a list.
"""Extracts loop associated with a prefix from mmCIF data as a list.
Reference for loop_ in mmCIF:
Reference for loop_ in mmCIF:
...
@@ -140,15 +143,17 @@ def mmcif_loop_to_list(prefix: str,
...
@@ -140,15 +143,17 @@ def mmcif_loop_to_list(prefix: str,
data
.
append
(
value
)
data
.
append
(
value
)
assert
all
([
len
(
xs
)
==
len
(
data
[
0
])
for
xs
in
data
]),
(
assert
all
([
len
(
xs
)
==
len
(
data
[
0
])
for
xs
in
data
]),
(
'mmCIF error: Not all loops are the same length: %s'
%
cols
)
"mmCIF error: Not all loops are the same length: %s"
%
cols
)
return
[
dict
(
zip
(
cols
,
xs
))
for
xs
in
zip
(
*
data
)]
return
[
dict
(
zip
(
cols
,
xs
))
for
xs
in
zip
(
*
data
)]
def
mmcif_loop_to_dict
(
prefix
:
str
,
def
mmcif_loop_to_dict
(
prefix
:
str
,
index
:
str
,
index
:
str
,
parsed_info
:
MmCIFDict
,
parsed_info
:
MmCIFDict
,
)
->
Mapping
[
str
,
Mapping
[
str
,
str
]]:
)
->
Mapping
[
str
,
Mapping
[
str
,
str
]]:
"""Extracts loop associated with a prefix from mmCIF data as a dictionary.
"""Extracts loop associated with a prefix from mmCIF data as a dictionary.
Args:
Args:
...
@@ -167,10 +172,9 @@ def mmcif_loop_to_dict(prefix: str,
...
@@ -167,10 +172,9 @@ def mmcif_loop_to_dict(prefix: str,
return
{
entry
[
index
]:
entry
for
entry
in
entries
}
return
{
entry
[
index
]:
entry
for
entry
in
entries
}
def
parse
(
*
,
def
parse
(
file_id
:
str
,
*
,
file_id
:
str
,
mmcif_string
:
str
,
catch_all_errors
:
bool
=
True
mmcif_string
:
str
,
)
->
ParsingResult
:
catch_all_errors
:
bool
=
True
)
->
ParsingResult
:
"""Entry point, parses an mmcif_string.
"""Entry point, parses an mmcif_string.
Args:
Args:
...
@@ -188,7 +192,7 @@ def parse(*,
...
@@ -188,7 +192,7 @@ def parse(*,
try
:
try
:
parser
=
PDB
.
MMCIFParser
(
QUIET
=
True
)
parser
=
PDB
.
MMCIFParser
(
QUIET
=
True
)
handle
=
io
.
StringIO
(
mmcif_string
)
handle
=
io
.
StringIO
(
mmcif_string
)
full_structure
=
parser
.
get_structure
(
''
,
handle
)
full_structure
=
parser
.
get_structure
(
""
,
handle
)
first_model_structure
=
_get_first_model
(
full_structure
)
first_model_structure
=
_get_first_model
(
full_structure
)
# Extract the _mmcif_dict from the parser, which contains useful fields not
# Extract the _mmcif_dict from the parser, which contains useful fields not
# reflected in the Biopython structure.
# reflected in the Biopython structure.
...
@@ -206,9 +210,12 @@ def parse(*,
...
@@ -206,9 +210,12 @@ def parse(*,
valid_chains
=
_get_protein_chains
(
parsed_info
=
parsed_info
)
valid_chains
=
_get_protein_chains
(
parsed_info
=
parsed_info
)
if
not
valid_chains
:
if
not
valid_chains
:
return
ParsingResult
(
return
ParsingResult
(
None
,
{(
file_id
,
''
):
'No protein chains found in this file.'
})
None
,
{(
file_id
,
""
):
"No protein chains found in this file."
}
seq_start_num
=
{
chain_id
:
min
([
monomer
.
num
for
monomer
in
seq
])
)
for
chain_id
,
seq
in
valid_chains
.
items
()}
seq_start_num
=
{
chain_id
:
min
([
monomer
.
num
for
monomer
in
seq
])
for
chain_id
,
seq
in
valid_chains
.
items
()
}
# Loop over the atoms for which we have coordinates. Populate two mappings:
# Loop over the atoms for which we have coordinates. Populate two mappings:
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
...
@@ -217,34 +224,42 @@ def parse(*,
...
@@ -217,34 +224,42 @@ def parse(*,
mmcif_to_author_chain_id
=
{}
mmcif_to_author_chain_id
=
{}
seq_to_structure_mappings
=
{}
seq_to_structure_mappings
=
{}
for
atom
in
_get_atom_site_list
(
parsed_info
):
for
atom
in
_get_atom_site_list
(
parsed_info
):
if
atom
.
model_num
!=
'1'
:
if
atom
.
model_num
!=
"1"
:
# We only process the first model at the moment.
# We only process the first model at the moment.
continue
continue
mmcif_to_author_chain_id
[
atom
.
mmcif_chain_id
]
=
atom
.
author_chain_id
mmcif_to_author_chain_id
[
atom
.
mmcif_chain_id
]
=
atom
.
author_chain_id
if
atom
.
mmcif_chain_id
in
valid_chains
:
if
atom
.
mmcif_chain_id
in
valid_chains
:
hetflag
=
' '
hetflag
=
" "
if
atom
.
hetatm_atom
==
'
HETATM
'
:
if
atom
.
hetatm_atom
==
"
HETATM
"
:
# Water atoms are assigned a special hetflag of W in Biopython. We
# Water atoms are assigned a special hetflag of W in Biopython. We
# need to do the same, so that this hetflag can be used to fetch
# need to do the same, so that this hetflag can be used to fetch
# a residue from the Biopython structure by id.
# a residue from the Biopython structure by id.
if
atom
.
residue_name
in
(
'
HOH
'
,
'
WAT
'
):
if
atom
.
residue_name
in
(
"
HOH
"
,
"
WAT
"
):
hetflag
=
'W'
hetflag
=
"W"
else
:
else
:
hetflag
=
'
H_
'
+
atom
.
residue_name
hetflag
=
"
H_
"
+
atom
.
residue_name
insertion_code
=
atom
.
insertion_code
insertion_code
=
atom
.
insertion_code
if
not
_is_set
(
atom
.
insertion_code
):
if
not
_is_set
(
atom
.
insertion_code
):
insertion_code
=
' '
insertion_code
=
" "
position
=
ResiduePosition
(
chain_id
=
atom
.
author_chain_id
,
position
=
ResiduePosition
(
chain_id
=
atom
.
author_chain_id
,
residue_number
=
int
(
atom
.
author_seq_num
),
residue_number
=
int
(
atom
.
author_seq_num
),
insertion_code
=
insertion_code
)
insertion_code
=
insertion_code
,
seq_idx
=
int
(
atom
.
mmcif_seq_num
)
-
seq_start_num
[
atom
.
mmcif_chain_id
]
)
current
=
seq_to_structure_mappings
.
get
(
atom
.
author_chain_id
,
{})
seq_idx
=
(
current
[
seq_idx
]
=
ResidueAtPosition
(
position
=
position
,
int
(
atom
.
mmcif_seq_num
)
-
seq_start_num
[
atom
.
mmcif_chain_id
]
)
current
=
seq_to_structure_mappings
.
get
(
atom
.
author_chain_id
,
{}
)
current
[
seq_idx
]
=
ResidueAtPosition
(
position
=
position
,
name
=
atom
.
residue_name
,
name
=
atom
.
residue_name
,
is_missing
=
False
,
is_missing
=
False
,
hetflag
=
hetflag
)
hetflag
=
hetflag
,
)
seq_to_structure_mappings
[
atom
.
author_chain_id
]
=
current
seq_to_structure_mappings
[
atom
.
author_chain_id
]
=
current
# Add missing residue information to seq_to_structure_mappings.
# Add missing residue information to seq_to_structure_mappings.
...
@@ -253,19 +268,21 @@ def parse(*,
...
@@ -253,19 +268,21 @@ def parse(*,
current_mapping
=
seq_to_structure_mappings
[
author_chain
]
current_mapping
=
seq_to_structure_mappings
[
author_chain
]
for
idx
,
monomer
in
enumerate
(
seq_info
):
for
idx
,
monomer
in
enumerate
(
seq_info
):
if
idx
not
in
current_mapping
:
if
idx
not
in
current_mapping
:
current_mapping
[
idx
]
=
ResidueAtPosition
(
position
=
None
,
current_mapping
[
idx
]
=
ResidueAtPosition
(
position
=
None
,
name
=
monomer
.
id
,
name
=
monomer
.
id
,
is_missing
=
True
,
is_missing
=
True
,
hetflag
=
' '
)
hetflag
=
" "
,
)
author_chain_to_sequence
=
{}
author_chain_to_sequence
=
{}
for
chain_id
,
seq_info
in
valid_chains
.
items
():
for
chain_id
,
seq_info
in
valid_chains
.
items
():
author_chain
=
mmcif_to_author_chain_id
[
chain_id
]
author_chain
=
mmcif_to_author_chain_id
[
chain_id
]
seq
=
[]
seq
=
[]
for
monomer
in
seq_info
:
for
monomer
in
seq_info
:
code
=
SCOPData
.
protein_letters_3to1
.
get
(
monomer
.
id
,
'X'
)
code
=
SCOPData
.
protein_letters_3to1
.
get
(
monomer
.
id
,
"X"
)
seq
.
append
(
code
if
len
(
code
)
==
1
else
'X'
)
seq
.
append
(
code
if
len
(
code
)
==
1
else
"X"
)
seq
=
''
.
join
(
seq
)
seq
=
""
.
join
(
seq
)
author_chain_to_sequence
[
author_chain
]
=
seq
author_chain_to_sequence
[
author_chain
]
=
seq
mmcif_object
=
MmcifObject
(
mmcif_object
=
MmcifObject
(
...
@@ -274,11 +291,12 @@ def parse(*,
...
@@ -274,11 +291,12 @@ def parse(*,
structure
=
first_model_structure
,
structure
=
first_model_structure
,
chain_to_seqres
=
author_chain_to_sequence
,
chain_to_seqres
=
author_chain_to_sequence
,
seqres_to_structure
=
seq_to_structure_mappings
,
seqres_to_structure
=
seq_to_structure_mappings
,
raw_string
=
parsed_info
)
raw_string
=
parsed_info
,
)
return
ParsingResult
(
mmcif_object
=
mmcif_object
,
errors
=
errors
)
return
ParsingResult
(
mmcif_object
=
mmcif_object
,
errors
=
errors
)
except
Exception
as
e
:
# pylint:disable=broad-except
except
Exception
as
e
:
# pylint:disable=broad-except
errors
[(
file_id
,
''
)]
=
e
errors
[(
file_id
,
""
)]
=
e
if
not
catch_all_errors
:
if
not
catch_all_errors
:
raise
raise
return
ParsingResult
(
mmcif_object
=
None
,
errors
=
errors
)
return
ParsingResult
(
mmcif_object
=
None
,
errors
=
errors
)
...
@@ -288,12 +306,13 @@ def _get_first_model(structure: PdbStructure) -> PdbStructure:
...
@@ -288,12 +306,13 @@ def _get_first_model(structure: PdbStructure) -> PdbStructure:
"""Returns the first model in a Biopython structure."""
"""Returns the first model in a Biopython structure."""
return
next
(
structure
.
get_models
())
return
next
(
structure
.
get_models
())
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE
=
21
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE
=
21
def
get_release_date
(
parsed_info
:
MmCIFDict
)
->
str
:
def
get_release_date
(
parsed_info
:
MmCIFDict
)
->
str
:
"""Returns the oldest revision date."""
"""Returns the oldest revision date."""
revision_dates
=
parsed_info
[
'
_pdbx_audit_revision_history.revision_date
'
]
revision_dates
=
parsed_info
[
"
_pdbx_audit_revision_history.revision_date
"
]
return
min
(
revision_dates
)
return
min
(
revision_dates
)
...
@@ -301,47 +320,58 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
...
@@ -301,47 +320,58 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
"""Returns a basic header containing method, release date and resolution."""
"""Returns a basic header containing method, release date and resolution."""
header
=
{}
header
=
{}
experiments
=
mmcif_loop_to_list
(
'_exptl.'
,
parsed_info
)
experiments
=
mmcif_loop_to_list
(
"_exptl."
,
parsed_info
)
header
[
'structure_method'
]
=
','
.
join
([
header
[
"structure_method"
]
=
","
.
join
(
experiment
[
'_exptl.method'
].
lower
()
for
experiment
in
experiments
])
[
experiment
[
"_exptl.method"
].
lower
()
for
experiment
in
experiments
]
)
# Note: The release_date here corresponds to the oldest revision. We prefer to
# Note: The release_date here corresponds to the oldest revision. We prefer to
# use this for dataset filtering over the deposition_date.
# use this for dataset filtering over the deposition_date.
if
'
_pdbx_audit_revision_history.revision_date
'
in
parsed_info
:
if
"
_pdbx_audit_revision_history.revision_date
"
in
parsed_info
:
header
[
'
release_date
'
]
=
get_release_date
(
parsed_info
)
header
[
"
release_date
"
]
=
get_release_date
(
parsed_info
)
else
:
else
:
logging
.
warning
(
'Could not determine release_date: %s'
,
logging
.
warning
(
parsed_info
[
'_entry.id'
])
"Could not determine release_date: %s"
,
parsed_info
[
"_entry.id"
]
)
header
[
'resolution'
]
=
0.00
header
[
"resolution"
]
=
0.00
for
res_key
in
(
'_refine.ls_d_res_high'
,
'_em_3d_reconstruction.resolution'
,
for
res_key
in
(
'_reflns.d_resolution_high'
):
"_refine.ls_d_res_high"
,
"_em_3d_reconstruction.resolution"
,
"_reflns.d_resolution_high"
,
):
if
res_key
in
parsed_info
:
if
res_key
in
parsed_info
:
try
:
try
:
raw_resolution
=
parsed_info
[
res_key
][
0
]
raw_resolution
=
parsed_info
[
res_key
][
0
]
header
[
'
resolution
'
]
=
float
(
raw_resolution
)
header
[
"
resolution
"
]
=
float
(
raw_resolution
)
except
ValueError
:
except
ValueError
:
logging
.
warning
(
'Invalid resolution format: %s'
,
parsed_info
[
res_key
])
logging
.
warning
(
"Invalid resolution format: %s"
,
parsed_info
[
res_key
]
)
return
header
return
header
def
_get_atom_site_list
(
parsed_info
:
MmCIFDict
)
->
Sequence
[
AtomSite
]:
def
_get_atom_site_list
(
parsed_info
:
MmCIFDict
)
->
Sequence
[
AtomSite
]:
"""Returns list of atom sites; contains data not present in the structure."""
"""Returns list of atom sites; contains data not present in the structure."""
return
[
AtomSite
(
*
site
)
for
site
in
zip
(
# pylint:disable=g-complex-comprehension
return
[
parsed_info
[
'_atom_site.label_comp_id'
],
AtomSite
(
*
site
)
parsed_info
[
'_atom_site.auth_asym_id'
],
for
site
in
zip
(
# pylint:disable=g-complex-comprehension
parsed_info
[
'_atom_site.label_asym_id'
],
parsed_info
[
"_atom_site.label_comp_id"
],
parsed_info
[
'_atom_site.auth_seq_id'
],
parsed_info
[
"_atom_site.auth_asym_id"
],
parsed_info
[
'_atom_site.label_seq_id'
],
parsed_info
[
"_atom_site.label_asym_id"
],
parsed_info
[
'_atom_site.pdbx_PDB_ins_code'
],
parsed_info
[
"_atom_site.auth_seq_id"
],
parsed_info
[
'_atom_site.group_PDB'
],
parsed_info
[
"_atom_site.label_seq_id"
],
parsed_info
[
'_atom_site.pdbx_PDB_model_num'
],
parsed_info
[
"_atom_site.pdbx_PDB_ins_code"
],
)]
parsed_info
[
"_atom_site.group_PDB"
],
parsed_info
[
"_atom_site.pdbx_PDB_model_num"
],
)
]
def
_get_protein_chains
(
def
_get_protein_chains
(
*
,
parsed_info
:
Mapping
[
str
,
Any
])
->
Mapping
[
ChainId
,
Sequence
[
Monomer
]]:
*
,
parsed_info
:
Mapping
[
str
,
Any
]
)
->
Mapping
[
ChainId
,
Sequence
[
Monomer
]]:
"""Extracts polymer information for protein chains only.
"""Extracts polymer information for protein chains only.
Args:
Args:
...
@@ -351,26 +381,29 @@ def _get_protein_chains(
...
@@ -351,26 +381,29 @@ def _get_protein_chains(
A dict mapping mmcif chain id to a list of Monomers.
A dict mapping mmcif chain id to a list of Monomers.
"""
"""
# Get polymer information for each entity in the structure.
# Get polymer information for each entity in the structure.
entity_poly_seqs
=
mmcif_loop_to_list
(
'
_entity_poly_seq.
'
,
parsed_info
)
entity_poly_seqs
=
mmcif_loop_to_list
(
"
_entity_poly_seq.
"
,
parsed_info
)
polymers
=
collections
.
defaultdict
(
list
)
polymers
=
collections
.
defaultdict
(
list
)
for
entity_poly_seq
in
entity_poly_seqs
:
for
entity_poly_seq
in
entity_poly_seqs
:
polymers
[
entity_poly_seq
[
'_entity_poly_seq.entity_id'
]].
append
(
polymers
[
entity_poly_seq
[
"_entity_poly_seq.entity_id"
]].
append
(
Monomer
(
id
=
entity_poly_seq
[
'_entity_poly_seq.mon_id'
],
Monomer
(
num
=
int
(
entity_poly_seq
[
'_entity_poly_seq.num'
])))
id
=
entity_poly_seq
[
"_entity_poly_seq.mon_id"
],
num
=
int
(
entity_poly_seq
[
"_entity_poly_seq.num"
]),
)
)
# Get chemical compositions. Will allow us to identify which of these polymers
# Get chemical compositions. Will allow us to identify which of these polymers
# are proteins.
# are proteins.
chem_comps
=
mmcif_loop_to_dict
(
'
_chem_comp.
'
,
'
_chem_comp.id
'
,
parsed_info
)
chem_comps
=
mmcif_loop_to_dict
(
"
_chem_comp.
"
,
"
_chem_comp.id
"
,
parsed_info
)
# Get chains information for each entity. Necessary so that we can return a
# Get chains information for each entity. Necessary so that we can return a
# dict keyed on chain id rather than entity.
# dict keyed on chain id rather than entity.
struct_asyms
=
mmcif_loop_to_list
(
'
_struct_asym.
'
,
parsed_info
)
struct_asyms
=
mmcif_loop_to_list
(
"
_struct_asym.
"
,
parsed_info
)
entity_to_mmcif_chains
=
collections
.
defaultdict
(
list
)
entity_to_mmcif_chains
=
collections
.
defaultdict
(
list
)
for
struct_asym
in
struct_asyms
:
for
struct_asym
in
struct_asyms
:
chain_id
=
struct_asym
[
'
_struct_asym.id
'
]
chain_id
=
struct_asym
[
"
_struct_asym.id
"
]
entity_id
=
struct_asym
[
'
_struct_asym.entity_id
'
]
entity_id
=
struct_asym
[
"
_struct_asym.entity_id
"
]
entity_to_mmcif_chains
[
entity_id
].
append
(
chain_id
)
entity_to_mmcif_chains
[
entity_id
].
append
(
chain_id
)
# Identify and return the valid protein chains.
# Identify and return the valid protein chains.
...
@@ -379,8 +412,12 @@ def _get_protein_chains(
...
@@ -379,8 +412,12 @@ def _get_protein_chains(
chain_ids
=
entity_to_mmcif_chains
[
entity_id
]
chain_ids
=
entity_to_mmcif_chains
[
entity_id
]
# Reject polymers without any peptide-like components, such as DNA/RNA.
# Reject polymers without any peptide-like components, such as DNA/RNA.
if
any
([
'peptide'
in
chem_comps
[
monomer
.
id
][
'_chem_comp.type'
]
if
any
(
for
monomer
in
seq_info
]):
[
"peptide"
in
chem_comps
[
monomer
.
id
][
"_chem_comp.type"
]
for
monomer
in
seq_info
]
):
for
chain_id
in
chain_ids
:
for
chain_id
in
chain_ids
:
valid_chains
[
chain_id
]
=
seq_info
valid_chains
[
chain_id
]
=
seq_info
return
valid_chains
return
valid_chains
...
@@ -388,19 +425,18 @@ def _get_protein_chains(
...
@@ -388,19 +425,18 @@ def _get_protein_chains(
def
_is_set
(
data
:
str
)
->
bool
:
def
_is_set
(
data
:
str
)
->
bool
:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return
data
not
in
(
'.'
,
'?'
)
return
data
not
in
(
"."
,
"?"
)
def
get_atom_coords
(
def
get_atom_coords
(
mmcif_object
:
MmcifObject
,
mmcif_object
:
MmcifObject
,
chain_id
:
str
chain_id
:
str
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
# Locate the right chain
# Locate the right chain
chains
=
list
(
mmcif_object
.
structure
.
get_chains
())
chains
=
list
(
mmcif_object
.
structure
.
get_chains
())
relevant_chains
=
[
c
for
c
in
chains
if
c
.
id
==
chain_id
]
relevant_chains
=
[
c
for
c
in
chains
if
c
.
id
==
chain_id
]
if
len
(
relevant_chains
)
!=
1
:
if
len
(
relevant_chains
)
!=
1
:
raise
MultipleChainsError
(
raise
MultipleChainsError
(
f
'
Expected exactly one chain in structure with id
{
chain_id
}
.
'
f
"
Expected exactly one chain in structure with id
{
chain_id
}
.
"
)
)
chain
=
relevant_chains
[
0
]
chain
=
relevant_chains
[
0
]
...
@@ -417,19 +453,23 @@ def get_atom_coords(
...
@@ -417,19 +453,23 @@ def get_atom_coords(
mask
=
np
.
zeros
([
residue_constants
.
atom_type_num
],
dtype
=
np
.
float32
)
mask
=
np
.
zeros
([
residue_constants
.
atom_type_num
],
dtype
=
np
.
float32
)
res_at_position
=
mmcif_object
.
seqres_to_structure
[
chain_id
][
res_index
]
res_at_position
=
mmcif_object
.
seqres_to_structure
[
chain_id
][
res_index
]
if
not
res_at_position
.
is_missing
:
if
not
res_at_position
.
is_missing
:
res
=
chain
[(
res_at_position
.
hetflag
,
res
=
chain
[
(
res_at_position
.
hetflag
,
res_at_position
.
position
.
residue_number
,
res_at_position
.
position
.
residue_number
,
res_at_position
.
position
.
insertion_code
)]
res_at_position
.
position
.
insertion_code
,
)
]
for
atom
in
res
.
get_atoms
():
for
atom
in
res
.
get_atoms
():
atom_name
=
atom
.
get_name
()
atom_name
=
atom
.
get_name
()
x
,
y
,
z
=
atom
.
get_coord
()
x
,
y
,
z
=
atom
.
get_coord
()
if
atom_name
in
residue_constants
.
atom_order
.
keys
():
if
atom_name
in
residue_constants
.
atom_order
.
keys
():
pos
[
residue_constants
.
atom_order
[
atom_name
]]
=
[
x
,
y
,
z
]
pos
[
residue_constants
.
atom_order
[
atom_name
]]
=
[
x
,
y
,
z
]
mask
[
residue_constants
.
atom_order
[
atom_name
]]
=
1.0
mask
[
residue_constants
.
atom_order
[
atom_name
]]
=
1.0
elif
atom_name
.
upper
()
==
'
SE
'
and
res
.
get_resname
()
==
'
MSE
'
:
elif
atom_name
.
upper
()
==
"
SE
"
and
res
.
get_resname
()
==
"
MSE
"
:
# Put the coords of the selenium atom in the sulphur column
# Put the coords of the selenium atom in the sulphur column
pos
[
residue_constants
.
atom_order
[
'
SD
'
]]
=
[
x
,
y
,
z
]
pos
[
residue_constants
.
atom_order
[
"
SD
"
]]
=
[
x
,
y
,
z
]
mask
[
residue_constants
.
atom_order
[
'
SD
'
]]
=
1.0
mask
[
residue_constants
.
atom_order
[
"
SD
"
]]
=
1.0
all_atom_positions
[
res_index
]
=
pos
all_atom_positions
[
res_index
]
=
pos
all_atom_mask
[
res_index
]
=
mask
all_atom_mask
[
res_index
]
=
mask
...
@@ -440,22 +480,22 @@ def get_atom_coords(
...
@@ -440,22 +480,22 @@ def get_atom_coords(
def
generate_mmcif_cache
(
mmcif_dir
:
str
,
out_path
:
str
):
def
generate_mmcif_cache
(
mmcif_dir
:
str
,
out_path
:
str
):
data
=
{}
data
=
{}
for
f
in
os
.
listdir
(
mmcif_dir
):
for
f
in
os
.
listdir
(
mmcif_dir
):
if
(
f
.
endswith
(
'
.cif
'
)
):
if
f
.
endswith
(
"
.cif
"
):
with
open
(
os
.
path
.
join
(
mmcif_dir
,
f
),
'r'
)
as
fp
:
with
open
(
os
.
path
.
join
(
mmcif_dir
,
f
),
"r"
)
as
fp
:
mmcif_string
=
fp
.
read
()
mmcif_string
=
fp
.
read
()
file_id
=
os
.
path
.
splitext
(
f
)[
0
]
file_id
=
os
.
path
.
splitext
(
f
)[
0
]
mmcif
=
parse
(
file_id
=
file_id
,
mmcif_string
=
mmcif_string
)
mmcif
=
parse
(
file_id
=
file_id
,
mmcif_string
=
mmcif_string
)
if
(
mmcif
.
mmcif_object
is
None
)
:
if
mmcif
.
mmcif_object
is
None
:
logging
.
warning
(
f
'
Could not parse
{
f
}
. Skipping...
'
)
logging
.
warning
(
f
"
Could not parse
{
f
}
. Skipping...
"
)
continue
continue
else
:
else
:
mmcif
=
mmcif
.
mmcif_object
mmcif
=
mmcif
.
mmcif_object
local_data
=
{}
local_data
=
{}
local_data
[
'
release_date
'
]
=
mmcif
.
header
[
"release_date"
]
local_data
[
"
release_date
"
]
=
mmcif
.
header
[
"release_date"
]
local_data
[
'
no_chains
'
]
=
len
(
list
(
mmcif
.
structure
.
get_chains
()))
local_data
[
"
no_chains
"
]
=
len
(
list
(
mmcif
.
structure
.
get_chains
()))
data
[
file_id
]
=
local_data
data
[
file_id
]
=
local_data
with
open
(
out_path
,
'w'
)
as
fp
:
with
open
(
out_path
,
"w"
)
as
fp
:
fp
.
write
(
json
.
dumps
(
data
))
fp
.
write
(
json
.
dumps
(
data
))
openfold/data/parsers.py
View file @
07e64267
...
@@ -23,9 +23,11 @@ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
...
@@ -23,9 +23,11 @@ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
DeletionMatrix
=
Sequence
[
Sequence
[
int
]]
DeletionMatrix
=
Sequence
[
Sequence
[
int
]]
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
TemplateHit
:
class
TemplateHit
:
"""Class representing a template hit."""
"""Class representing a template hit."""
index
:
int
index
:
int
name
:
str
name
:
str
aligned_cols
:
int
aligned_cols
:
int
...
@@ -53,10 +55,10 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
...
@@ -53,10 +55,10 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
index
=
-
1
index
=
-
1
for
line
in
fasta_string
.
splitlines
():
for
line
in
fasta_string
.
splitlines
():
line
=
line
.
strip
()
line
=
line
.
strip
()
if
line
.
startswith
(
'>'
):
if
line
.
startswith
(
">"
):
index
+=
1
index
+=
1
descriptions
.
append
(
line
[
1
:])
# Remove the '>' at the beginning.
descriptions
.
append
(
line
[
1
:])
# Remove the '>' at the beginning.
sequences
.
append
(
''
)
sequences
.
append
(
""
)
continue
continue
elif
not
line
:
elif
not
line
:
continue
# Skip blank lines.
continue
# Skip blank lines.
...
@@ -65,8 +67,9 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
...
@@ -65,8 +67,9 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
return
sequences
,
descriptions
return
sequences
,
descriptions
def
parse_stockholm
(
stockholm_string
:
str
def
parse_stockholm
(
)
->
Tuple
[
Sequence
[
str
],
DeletionMatrix
,
Sequence
[
str
]]:
stockholm_string
:
str
,
)
->
Tuple
[
Sequence
[
str
],
DeletionMatrix
,
Sequence
[
str
]]:
"""Parses sequences and deletion matrix from stockholm format alignment.
"""Parses sequences and deletion matrix from stockholm format alignment.
Args:
Args:
...
@@ -86,26 +89,26 @@ def parse_stockholm(stockholm_string: str
...
@@ -86,26 +89,26 @@ def parse_stockholm(stockholm_string: str
name_to_sequence
=
collections
.
OrderedDict
()
name_to_sequence
=
collections
.
OrderedDict
()
for
line
in
stockholm_string
.
splitlines
():
for
line
in
stockholm_string
.
splitlines
():
line
=
line
.
strip
()
line
=
line
.
strip
()
if
not
line
or
line
.
startswith
((
'#'
,
'
//
'
)):
if
not
line
or
line
.
startswith
((
"#"
,
"
//
"
)):
continue
continue
name
,
sequence
=
line
.
split
()
name
,
sequence
=
line
.
split
()
if
name
not
in
name_to_sequence
:
if
name
not
in
name_to_sequence
:
name_to_sequence
[
name
]
=
''
name_to_sequence
[
name
]
=
""
name_to_sequence
[
name
]
+=
sequence
name_to_sequence
[
name
]
+=
sequence
msa
=
[]
msa
=
[]
deletion_matrix
=
[]
deletion_matrix
=
[]
query
=
''
query
=
""
keep_columns
=
[]
keep_columns
=
[]
for
seq_index
,
sequence
in
enumerate
(
name_to_sequence
.
values
()):
for
seq_index
,
sequence
in
enumerate
(
name_to_sequence
.
values
()):
if
seq_index
==
0
:
if
seq_index
==
0
:
# Gather the columns with gaps from the query
# Gather the columns with gaps from the query
query
=
sequence
query
=
sequence
keep_columns
=
[
i
for
i
,
res
in
enumerate
(
query
)
if
res
!=
'-'
]
keep_columns
=
[
i
for
i
,
res
in
enumerate
(
query
)
if
res
!=
"-"
]
# Remove the columns with gaps in the query from all sequences.
# Remove the columns with gaps in the query from all sequences.
aligned_sequence
=
''
.
join
([
sequence
[
c
]
for
c
in
keep_columns
])
aligned_sequence
=
""
.
join
([
sequence
[
c
]
for
c
in
keep_columns
])
msa
.
append
(
aligned_sequence
)
msa
.
append
(
aligned_sequence
)
...
@@ -113,8 +116,8 @@ def parse_stockholm(stockholm_string: str
...
@@ -113,8 +116,8 @@ def parse_stockholm(stockholm_string: str
deletion_vec
=
[]
deletion_vec
=
[]
deletion_count
=
0
deletion_count
=
0
for
seq_res
,
query_res
in
zip
(
sequence
,
query
):
for
seq_res
,
query_res
in
zip
(
sequence
,
query
):
if
seq_res
!=
'-'
or
query_res
!=
'-'
:
if
seq_res
!=
"-"
or
query_res
!=
"-"
:
if
query_res
==
'-'
:
if
query_res
==
"-"
:
deletion_count
+=
1
deletion_count
+=
1
else
:
else
:
deletion_vec
.
append
(
deletion_count
)
deletion_vec
.
append
(
deletion_count
)
...
@@ -153,47 +156,51 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
...
@@ -153,47 +156,51 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
deletion_matrix
.
append
(
deletion_vec
)
deletion_matrix
.
append
(
deletion_vec
)
# Make the MSA matrix out of aligned (deletion-free) sequences.
# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table
=
str
.
maketrans
(
''
,
''
,
string
.
ascii_lowercase
)
deletion_table
=
str
.
maketrans
(
""
,
""
,
string
.
ascii_lowercase
)
aligned_sequences
=
[
s
.
translate
(
deletion_table
)
for
s
in
sequences
]
aligned_sequences
=
[
s
.
translate
(
deletion_table
)
for
s
in
sequences
]
return
aligned_sequences
,
deletion_matrix
return
aligned_sequences
,
deletion_matrix
def
_convert_sto_seq_to_a3m
(
def
_convert_sto_seq_to_a3m
(
query_non_gaps
:
Sequence
[
bool
],
sto_seq
:
str
)
->
Iterable
[
str
]:
query_non_gaps
:
Sequence
[
bool
],
sto_seq
:
str
)
->
Iterable
[
str
]:
for
is_query_res_non_gap
,
sequence_res
in
zip
(
query_non_gaps
,
sto_seq
):
for
is_query_res_non_gap
,
sequence_res
in
zip
(
query_non_gaps
,
sto_seq
):
if
is_query_res_non_gap
:
if
is_query_res_non_gap
:
yield
sequence_res
yield
sequence_res
elif
sequence_res
!=
'-'
:
elif
sequence_res
!=
"-"
:
yield
sequence_res
.
lower
()
yield
sequence_res
.
lower
()
def
convert_stockholm_to_a3m
(
stockholm_format
:
str
,
def
convert_stockholm_to_a3m
(
max_sequences
:
Optional
[
int
]
=
None
)
->
str
:
stockholm_format
:
str
,
max_sequences
:
Optional
[
int
]
=
None
)
->
str
:
"""Converts MSA in Stockholm format to the A3M format."""
"""Converts MSA in Stockholm format to the A3M format."""
descriptions
=
{}
descriptions
=
{}
sequences
=
{}
sequences
=
{}
reached_max_sequences
=
False
reached_max_sequences
=
False
for
line
in
stockholm_format
.
splitlines
():
for
line
in
stockholm_format
.
splitlines
():
reached_max_sequences
=
max_sequences
and
len
(
sequences
)
>=
max_sequences
reached_max_sequences
=
(
if
line
.
strip
()
and
not
line
.
startswith
((
'#'
,
'//'
)):
max_sequences
and
len
(
sequences
)
>=
max_sequences
)
if
line
.
strip
()
and
not
line
.
startswith
((
"#"
,
"//"
)):
# Ignore blank lines, markup and end symbols - remainder are alignment
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
# sequence parts.
seqname
,
aligned_seq
=
line
.
split
(
maxsplit
=
1
)
seqname
,
aligned_seq
=
line
.
split
(
maxsplit
=
1
)
if
seqname
not
in
sequences
:
if
seqname
not
in
sequences
:
if
reached_max_sequences
:
if
reached_max_sequences
:
continue
continue
sequences
[
seqname
]
=
''
sequences
[
seqname
]
=
""
sequences
[
seqname
]
+=
aligned_seq
sequences
[
seqname
]
+=
aligned_seq
for
line
in
stockholm_format
.
splitlines
():
for
line
in
stockholm_format
.
splitlines
():
if
line
[:
4
]
==
'
#=GS
'
:
if
line
[:
4
]
==
"
#=GS
"
:
# Description row - example format is:
# Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns
=
line
.
split
(
maxsplit
=
3
)
columns
=
line
.
split
(
maxsplit
=
3
)
seqname
,
feature
=
columns
[
1
:
3
]
seqname
,
feature
=
columns
[
1
:
3
]
value
=
columns
[
3
]
if
len
(
columns
)
==
4
else
''
value
=
columns
[
3
]
if
len
(
columns
)
==
4
else
""
if
feature
!=
'
DE
'
:
if
feature
!=
"
DE
"
:
continue
continue
if
reached_max_sequences
and
seqname
not
in
sequences
:
if
reached_max_sequences
and
seqname
not
in
sequences
:
continue
continue
...
@@ -205,30 +212,35 @@ def convert_stockholm_to_a3m(stockholm_format: str,
...
@@ -205,30 +212,35 @@ def convert_stockholm_to_a3m(stockholm_format: str,
a3m_sequences
=
{}
a3m_sequences
=
{}
# query_sequence is assumed to be the first sequence
# query_sequence is assumed to be the first sequence
query_sequence
=
next
(
iter
(
sequences
.
values
()))
query_sequence
=
next
(
iter
(
sequences
.
values
()))
query_non_gaps
=
[
res
!=
'-'
for
res
in
query_sequence
]
query_non_gaps
=
[
res
!=
"-"
for
res
in
query_sequence
]
for
seqname
,
sto_sequence
in
sequences
.
items
():
for
seqname
,
sto_sequence
in
sequences
.
items
():
a3m_sequences
[
seqname
]
=
''
.
join
(
a3m_sequences
[
seqname
]
=
""
.
join
(
_convert_sto_seq_to_a3m
(
query_non_gaps
,
sto_sequence
))
_convert_sto_seq_to_a3m
(
query_non_gaps
,
sto_sequence
)
)
fasta_chunks
=
(
f
">
{
k
}
{
descriptions
.
get
(
k
,
''
)
}
\n
{
a3m_sequences
[
k
]
}
"
fasta_chunks
=
(
for
k
in
a3m_sequences
)
f
">
{
k
}
{
descriptions
.
get
(
k
,
''
)
}
\n
{
a3m_sequences
[
k
]
}
"
return
'
\n
'
.
join
(
fasta_chunks
)
+
'
\n
'
# Include terminating newline.
for
k
in
a3m_sequences
)
return
"
\n
"
.
join
(
fasta_chunks
)
+
"
\n
"
# Include terminating newline.
def
_get_hhr_line_regex_groups
(
def
_get_hhr_line_regex_groups
(
regex_pattern
:
str
,
line
:
str
)
->
Sequence
[
Optional
[
str
]]:
regex_pattern
:
str
,
line
:
str
)
->
Sequence
[
Optional
[
str
]]:
match
=
re
.
match
(
regex_pattern
,
line
)
match
=
re
.
match
(
regex_pattern
,
line
)
if
match
is
None
:
if
match
is
None
:
raise
RuntimeError
(
f
'
Could not parse query line
{
line
}
'
)
raise
RuntimeError
(
f
"
Could not parse query line
{
line
}
"
)
return
match
.
groups
()
return
match
.
groups
()
def
_update_hhr_residue_indices_list
(
def
_update_hhr_residue_indices_list
(
sequence
:
str
,
start_index
:
int
,
indices_list
:
List
[
int
]):
sequence
:
str
,
start_index
:
int
,
indices_list
:
List
[
int
]
):
"""Computes the relative indices for each residue with respect to the original sequence."""
"""Computes the relative indices for each residue with respect to the original sequence."""
counter
=
start_index
counter
=
start_index
for
symbol
in
sequence
:
for
symbol
in
sequence
:
if
symbol
==
'-'
:
if
symbol
==
"-"
:
indices_list
.
append
(
-
1
)
indices_list
.
append
(
-
1
)
else
:
else
:
indices_list
.
append
(
counter
)
indices_list
.
append
(
counter
)
...
@@ -256,36 +268,42 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
...
@@ -256,36 +268,42 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
# Parse the summary line.
# Parse the summary line.
pattern
=
(
pattern
=
(
'Probab=(.*)[
\t
]*E-value=(.*)[
\t
]*Score=(.*)[
\t
]*Aligned_cols=(.*)[
\t
'
"Probab=(.*)[
\t
]*E-value=(.*)[
\t
]*Score=(.*)[
\t
]*Aligned_cols=(.*)[
\t
"
' ]*Identities=(.*)%[
\t
]*Similarity=(.*)[
\t
]*Sum_probs=(.*)[
\t
'
" ]*Identities=(.*)%[
\t
]*Similarity=(.*)[
\t
]*Sum_probs=(.*)[
\t
"
']*Template_Neff=(.*)'
)
"]*Template_Neff=(.*)"
)
match
=
re
.
match
(
pattern
,
detailed_lines
[
2
])
match
=
re
.
match
(
pattern
,
detailed_lines
[
2
])
if
match
is
None
:
if
match
is
None
:
raise
RuntimeError
(
raise
RuntimeError
(
'Could not parse section: %s. Expected this:
\n
%s to contain summary.'
%
"Could not parse section: %s. Expected this:
\n
%s to contain summary."
(
detailed_lines
,
detailed_lines
[
2
]))
%
(
detailed_lines
,
detailed_lines
[
2
])
(
prob_true
,
e_value
,
_
,
aligned_cols
,
_
,
_
,
sum_probs
,
)
neff
)
=
[
float
(
x
)
for
x
in
match
.
groups
()]
(
prob_true
,
e_value
,
_
,
aligned_cols
,
_
,
_
,
sum_probs
,
neff
)
=
[
float
(
x
)
for
x
in
match
.
groups
()
]
# The next section reads the detailed comparisons. These are in a 'human
# The next section reads the detailed comparisons. These are in a 'human
# readable' format which has a fixed length. The strategy employed is to
# readable' format which has a fixed length. The strategy employed is to
# assume that each block starts with the query sequence line, and to parse
# assume that each block starts with the query sequence line, and to parse
# that with a regexp in order to deduce the fixed length used for that block.
# that with a regexp in order to deduce the fixed length used for that block.
query
=
''
query
=
""
hit_sequence
=
''
hit_sequence
=
""
indices_query
=
[]
indices_query
=
[]
indices_hit
=
[]
indices_hit
=
[]
length_block
=
None
length_block
=
None
for
line
in
detailed_lines
[
3
:]:
for
line
in
detailed_lines
[
3
:]:
# Parse the query sequence line
# Parse the query sequence line
if
(
line
.
startswith
(
'Q '
)
and
not
line
.
startswith
(
'Q ss_dssp'
)
and
if
(
not
line
.
startswith
(
'Q ss_pred'
)
and
line
.
startswith
(
"Q "
)
not
line
.
startswith
(
'Q Consensus'
)):
and
not
line
.
startswith
(
"Q ss_dssp"
)
and
not
line
.
startswith
(
"Q ss_pred"
)
and
not
line
.
startswith
(
"Q Consensus"
)
):
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse
# everything after that.
# everything after that.
# start sequence end total_sequence_length
# start sequence end total_sequence_length
patt
=
r
'
[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)
'
patt
=
r
"
[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)
"
groups
=
_get_hhr_line_regex_groups
(
patt
,
line
[
17
:])
groups
=
_get_hhr_line_regex_groups
(
patt
,
line
[
17
:])
# Get the length of the parsed block using the start and finish indices,
# Get the length of the parsed block using the start and finish indices,
...
@@ -293,7 +311,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
...
@@ -293,7 +311,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
start
=
int
(
groups
[
0
])
-
1
# Make index zero based.
start
=
int
(
groups
[
0
])
-
1
# Make index zero based.
delta_query
=
groups
[
1
]
delta_query
=
groups
[
1
]
end
=
int
(
groups
[
2
])
end
=
int
(
groups
[
2
])
num_insertions
=
len
([
x
for
x
in
delta_query
if
x
==
'-'
])
num_insertions
=
len
([
x
for
x
in
delta_query
if
x
==
"-"
])
length_block
=
end
-
start
+
num_insertions
length_block
=
end
-
start
+
num_insertions
assert
length_block
==
len
(
delta_query
)
assert
length_block
==
len
(
delta_query
)
...
@@ -301,15 +319,17 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
...
@@ -301,15 +319,17 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
query
+=
delta_query
query
+=
delta_query
_update_hhr_residue_indices_list
(
delta_query
,
start
,
indices_query
)
_update_hhr_residue_indices_list
(
delta_query
,
start
,
indices_query
)
elif
line
.
startswith
(
'
T
'
):
elif
line
.
startswith
(
"
T
"
):
# Parse the hit sequence.
# Parse the hit sequence.
if
(
not
line
.
startswith
(
'T ss_dssp'
)
and
if
(
not
line
.
startswith
(
'T ss_pred'
)
and
not
line
.
startswith
(
"T ss_dssp"
)
not
line
.
startswith
(
'T Consensus'
)):
and
not
line
.
startswith
(
"T ss_pred"
)
and
not
line
.
startswith
(
"T Consensus"
)
):
# Thus the first 17 characters must be 'T <hit_name> ', and we can
# Thus the first 17 characters must be 'T <hit_name> ', and we can
# parse everything after that.
# parse everything after that.
# start sequence end total_sequence_length
# start sequence end total_sequence_length
patt
=
r
'
[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)
'
patt
=
r
"
[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)
"
groups
=
_get_hhr_line_regex_groups
(
patt
,
line
[
17
:])
groups
=
_get_hhr_line_regex_groups
(
patt
,
line
[
17
:])
start
=
int
(
groups
[
0
])
-
1
# Make index zero based.
start
=
int
(
groups
[
0
])
-
1
# Make index zero based.
delta_hit_sequence
=
groups
[
1
]
delta_hit_sequence
=
groups
[
1
]
...
@@ -317,7 +337,9 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
...
@@ -317,7 +337,9 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
# Update the hit sequence and indices list.
# Update the hit sequence and indices list.
hit_sequence
+=
delta_hit_sequence
hit_sequence
+=
delta_hit_sequence
_update_hhr_residue_indices_list
(
delta_hit_sequence
,
start
,
indices_hit
)
_update_hhr_residue_indices_list
(
delta_hit_sequence
,
start
,
indices_hit
)
return
TemplateHit
(
return
TemplateHit
(
index
=
number_of_hit
,
index
=
number_of_hit
,
...
@@ -339,20 +361,22 @@ def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
...
@@ -339,20 +361,22 @@ def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# iterate through each paragraph to parse each hit.
# iterate through each paragraph to parse each hit.
block_starts
=
[
i
for
i
,
line
in
enumerate
(
lines
)
if
line
.
startswith
(
'
No
'
)]
block_starts
=
[
i
for
i
,
line
in
enumerate
(
lines
)
if
line
.
startswith
(
"
No
"
)]
hits
=
[]
hits
=
[]
if
block_starts
:
if
block_starts
:
block_starts
.
append
(
len
(
lines
))
# Add the end of the final block.
block_starts
.
append
(
len
(
lines
))
# Add the end of the final block.
for
i
in
range
(
len
(
block_starts
)
-
1
):
for
i
in
range
(
len
(
block_starts
)
-
1
):
hits
.
append
(
_parse_hhr_hit
(
lines
[
block_starts
[
i
]:
block_starts
[
i
+
1
]]))
hits
.
append
(
_parse_hhr_hit
(
lines
[
block_starts
[
i
]
:
block_starts
[
i
+
1
]])
)
return
hits
return
hits
def
parse_e_values_from_tblout
(
tblout
:
str
)
->
Dict
[
str
,
float
]:
def
parse_e_values_from_tblout
(
tblout
:
str
)
->
Dict
[
str
,
float
]:
"""Parse target to e-value mapping parsed from Jackhmmer tblout string."""
"""Parse target to e-value mapping parsed from Jackhmmer tblout string."""
e_values
=
{
'
query
'
:
0
}
e_values
=
{
"
query
"
:
0
}
lines
=
[
line
for
line
in
tblout
.
splitlines
()
if
line
[
0
]
!=
'#'
]
lines
=
[
line
for
line
in
tblout
.
splitlines
()
if
line
[
0
]
!=
"#"
]
# As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
# As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
# space-delimited. Relevant fields are (1) target name: and
# space-delimited. Relevant fields are (1) target name: and
# (5) E-value (full sequence) (numbering from 1).
# (5) E-value (full sequence) (numbering from 1).
...
...
openfold/data/templates.py
View file @
07e64267
...
@@ -89,29 +89,30 @@ class LengthError(PrefilterError):
...
@@ -89,29 +89,30 @@ class LengthError(PrefilterError):
# Numpy dtypes used when materializing template features as arrays.
TEMPLATE_FEATURES = {
    "template_aatype": np.int64,
    "template_all_atom_mask": np.float32,
    "template_all_atom_positions": np.float32,
    # `np.object` was a deprecated alias of the builtin `object` and was
    # removed in NumPy 1.24; the builtin is the drop-in equivalent.
    "template_domain_names": object,
    "template_sequence": object,
    "template_sum_probs": np.float32,
}
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
    """Extracts (lowercased PDB id, chain id) from an HHSearch hit name.

    The hit name is expected to begin with ``<PDBID>_<chain>``: a 4-character
    alphanumeric PDB id, an underscore, then one or more alphanumeric
    characters (or "." when the chain is unknown).

    Raises:
        ValueError: If ``hit.name`` does not start with ``PDBID_chain``.
    """
    match = re.match(r"[a-zA-Z\d]{4}_[a-zA-Z0-9.]+", hit.name)
    if match is None:
        raise ValueError(f"hit.name did not start with PDBID_chain: {hit.name}")
    pdb_code, chain = match.group(0).split("_")
    return pdb_code.lower(), chain
def
_is_after_cutoff
(
def
_is_after_cutoff
(
pdb_id
:
str
,
pdb_id
:
str
,
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
release_date_cutoff
:
Optional
[
datetime
.
datetime
])
->
bool
:
release_date_cutoff
:
Optional
[
datetime
.
datetime
],
)
->
bool
:
"""Checks if the template date is after the release date cutoff.
"""Checks if the template date is after the release date cutoff.
Args:
Args:
...
@@ -123,13 +124,15 @@ def _is_after_cutoff(
...
@@ -123,13 +124,15 @@ def _is_after_cutoff(
True if the template release date is after the cutoff, False otherwise.
True if the template release date is after the cutoff, False otherwise.
"""
"""
if
release_date_cutoff
is
None
:
if
release_date_cutoff
is
None
:
raise
ValueError
(
'
The release_date_cutoff must not be None.
'
)
raise
ValueError
(
"
The release_date_cutoff must not be None.
"
)
if
pdb_id
in
release_dates
:
if
pdb_id
in
release_dates
:
return
release_dates
[
pdb_id
]
>
release_date_cutoff
return
release_dates
[
pdb_id
]
>
release_date_cutoff
else
:
else
:
# Since this is just a quick prefilter to reduce the number of mmCIF files
# Since this is just a quick prefilter to reduce the number of mmCIF files
# we need to parse, we don't have to worry about returning True here.
# we need to parse, we don't have to worry about returning True here.
logging
.
warning
(
'Template structure not in release dates dict: %s'
,
pdb_id
)
logging
.
warning
(
"Template structure not in release dates dict: %s"
,
pdb_id
)
return
False
return
False
...
@@ -140,7 +143,7 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
...
@@ -140,7 +143,7 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
for
line
in
f
:
for
line
in
f
:
line
=
line
.
strip
()
line
=
line
.
strip
()
# We skip obsolete entries that don't contain a mapping to a new entry.
# We skip obsolete entries that don't contain a mapping to a new entry.
if
line
.
startswith
(
'
OBSLTE
'
)
and
len
(
line
)
>
30
:
if
line
.
startswith
(
"
OBSLTE
"
)
and
len
(
line
)
>
30
:
# Format: Date From To
# Format: Date From To
# 'OBSLTE 31-JUL-94 116L 216L'
# 'OBSLTE 31-JUL-94 116L 216L'
from_id
=
line
[
20
:
24
].
lower
()
from_id
=
line
[
20
:
24
].
lower
()
...
@@ -152,38 +155,41 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
...
@@ -152,38 +155,41 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
    """Writes a JSON cache mapping PDB ids to mmCIF release dates.

    Parses every ``*.cif`` file in ``mmcif_dir`` and records its release
    date, keyed by the file's basename (the PDB id). The mapping is
    serialized as JSON to ``out_path``.

    Args:
        mmcif_dir: Directory containing ``.cif`` files.
        out_path: Path of the JSON cache file to write.
    """
    dates = {}
    for fname in os.listdir(mmcif_dir):
        if not fname.endswith(".cif"):
            continue
        path = os.path.join(mmcif_dir, fname)
        with open(path, "r") as fp:
            mmcif_string = fp.read()
        file_id = os.path.splitext(fname)[0]
        mmcif = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )
        if mmcif.mmcif_object is None:
            logging.warning("Failed to parse %s. Skipping...", fname)
            continue
        dates[file_id] = mmcif.mmcif_object.header["release_date"]
    # BUG FIX: the output file was opened with mode "r", so fp.write()
    # raised io.UnsupportedOperation (and open() failed outright when the
    # file did not exist). Open for writing, matching generate_mmcif_cache.
    with open(out_path, "w") as fp:
        fp.write(json.dumps(dates))
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
    """Parses release dates file, returns a mapping from PDBs to release dates."""
    with open(path, "r") as fp:
        data = json.load(fp)
    release_dates = {}
    for pdb, entry in data.items():
        for key, value in entry.items():
            if key == "release_date":
                release_dates[pdb] = to_date(value)
    return release_dates
def
_assess_hhsearch_hit
(
def
_assess_hhsearch_hit
(
hit
:
parsers
.
TemplateHit
,
hit
:
parsers
.
TemplateHit
,
hit_pdb_code
:
str
,
hit_pdb_code
:
str
,
...
@@ -192,7 +198,8 @@ def _assess_hhsearch_hit(
...
@@ -192,7 +198,8 @@ def _assess_hhsearch_hit(
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
release_date_cutoff
:
datetime
.
datetime
,
release_date_cutoff
:
datetime
.
datetime
,
max_subsequence_ratio
:
float
=
0.95
,
max_subsequence_ratio
:
float
=
0.95
,
min_align_ratio
:
float
=
0.1
)
->
bool
:
min_align_ratio
:
float
=
0.1
,
)
->
bool
:
"""Determines if template is valid (without parsing the template mmcif file).
"""Determines if template is valid (without parsing the template mmcif file).
Args:
Args:
...
@@ -221,32 +228,42 @@ def _assess_hhsearch_hit(
...
@@ -221,32 +228,42 @@ def _assess_hhsearch_hit(
aligned_cols
=
hit
.
aligned_cols
aligned_cols
=
hit
.
aligned_cols
align_ratio
=
aligned_cols
/
len
(
query_sequence
)
align_ratio
=
aligned_cols
/
len
(
query_sequence
)
template_sequence
=
hit
.
hit_sequence
.
replace
(
'-'
,
''
)
template_sequence
=
hit
.
hit_sequence
.
replace
(
"-"
,
""
)
length_ratio
=
float
(
len
(
template_sequence
))
/
len
(
query_sequence
)
length_ratio
=
float
(
len
(
template_sequence
))
/
len
(
query_sequence
)
# Check whether the template is a large subsequence or duplicate of original
# Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database.
# query. This can happen due to duplicate entries in the PDB database.
duplicate
=
(
template_sequence
in
query_sequence
and
duplicate
=
(
length_ratio
>
max_subsequence_ratio
)
template_sequence
in
query_sequence
and
length_ratio
>
max_subsequence_ratio
)
if
_is_after_cutoff
(
hit_pdb_code
,
release_dates
,
release_date_cutoff
):
if
_is_after_cutoff
(
hit_pdb_code
,
release_dates
,
release_date_cutoff
):
raise
DateError
(
f
'Date (
{
release_dates
[
hit_pdb_code
]
}
) > max template date '
raise
DateError
(
f
'(
{
release_date_cutoff
}
).'
)
f
"Date (
{
release_dates
[
hit_pdb_code
]
}
) > max template date "
f
"(
{
release_date_cutoff
}
)."
)
if
query_pdb_code
is
not
None
:
if
query_pdb_code
is
not
None
:
if
query_pdb_code
.
lower
()
==
hit_pdb_code
.
lower
():
if
query_pdb_code
.
lower
()
==
hit_pdb_code
.
lower
():
raise
PdbIdError
(
'
PDB code identical to Query PDB code.
'
)
raise
PdbIdError
(
"
PDB code identical to Query PDB code.
"
)
if
align_ratio
<=
min_align_ratio
:
if
align_ratio
<=
min_align_ratio
:
raise
AlignRatioError
(
'Proportion of residues aligned to query too small. '
raise
AlignRatioError
(
f
'Align ratio:
{
align_ratio
}
.'
)
"Proportion of residues aligned to query too small. "
f
"Align ratio:
{
align_ratio
}
."
)
if
duplicate
:
if
duplicate
:
raise
DuplicateError
(
'Template is an exact subsequence of query with large '
raise
DuplicateError
(
f
'coverage. Length ratio:
{
length_ratio
}
.'
)
"Template is an exact subsequence of query with large "
f
"coverage. Length ratio:
{
length_ratio
}
."
)
if
len
(
template_sequence
)
<
10
:
if
len
(
template_sequence
)
<
10
:
raise
LengthError
(
f
'Template too short. Length:
{
len
(
template_sequence
)
}
.'
)
raise
LengthError
(
f
"Template too short. Length:
{
len
(
template_sequence
)
}
."
)
return
True
return
True
...
@@ -254,7 +271,8 @@ def _assess_hhsearch_hit(
...
@@ -254,7 +271,8 @@ def _assess_hhsearch_hit(
def
_find_template_in_pdb
(
def
_find_template_in_pdb
(
template_chain_id
:
str
,
template_chain_id
:
str
,
template_sequence
:
str
,
template_sequence
:
str
,
mmcif_object
:
mmcif_parsing
.
MmcifObject
)
->
Tuple
[
str
,
str
,
int
]:
mmcif_object
:
mmcif_parsing
.
MmcifObject
,
)
->
Tuple
[
str
,
str
,
int
]:
"""Tries to find the template chain in the given pdb file.
"""Tries to find the template chain in the given pdb file.
This method tries the three following things in order:
This method tries the three following things in order:
...
@@ -286,33 +304,42 @@ def _find_template_in_pdb(
...
@@ -286,33 +304,42 @@ def _find_template_in_pdb(
chain_sequence
=
mmcif_object
.
chain_to_seqres
.
get
(
template_chain_id
)
chain_sequence
=
mmcif_object
.
chain_to_seqres
.
get
(
template_chain_id
)
if
chain_sequence
and
(
template_sequence
in
chain_sequence
):
if
chain_sequence
and
(
template_sequence
in
chain_sequence
):
logging
.
info
(
logging
.
info
(
'Found an exact template match %s_%s.'
,
pdb_id
,
template_chain_id
)
"Found an exact template match %s_%s."
,
pdb_id
,
template_chain_id
)
mapping_offset
=
chain_sequence
.
find
(
template_sequence
)
mapping_offset
=
chain_sequence
.
find
(
template_sequence
)
return
chain_sequence
,
template_chain_id
,
mapping_offset
return
chain_sequence
,
template_chain_id
,
mapping_offset
# Try if there is an exact match in the (sub)sequence only.
# Try if there is an exact match in the (sub)sequence only.
for
chain_id
,
chain_sequence
in
mmcif_object
.
chain_to_seqres
.
items
():
for
chain_id
,
chain_sequence
in
mmcif_object
.
chain_to_seqres
.
items
():
if
chain_sequence
and
(
template_sequence
in
chain_sequence
):
if
chain_sequence
and
(
template_sequence
in
chain_sequence
):
logging
.
info
(
'
Found a sequence-only match %s_%s.
'
,
pdb_id
,
chain_id
)
logging
.
info
(
"
Found a sequence-only match %s_%s.
"
,
pdb_id
,
chain_id
)
mapping_offset
=
chain_sequence
.
find
(
template_sequence
)
mapping_offset
=
chain_sequence
.
find
(
template_sequence
)
return
chain_sequence
,
chain_id
,
mapping_offset
return
chain_sequence
,
chain_id
,
mapping_offset
# Return a chain sequence that fuzzy matches (X = wildcard) the template.
# Return a chain sequence that fuzzy matches (X = wildcard) the template.
# Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
# Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
regex
=
[
'.'
if
aa
==
'X'
else
'
(?:%s|X)
'
%
aa
for
aa
in
template_sequence
]
regex
=
[
"."
if
aa
==
"X"
else
"
(?:%s|X)
"
%
aa
for
aa
in
template_sequence
]
regex
=
re
.
compile
(
''
.
join
(
regex
))
regex
=
re
.
compile
(
""
.
join
(
regex
))
for
chain_id
,
chain_sequence
in
mmcif_object
.
chain_to_seqres
.
items
():
for
chain_id
,
chain_sequence
in
mmcif_object
.
chain_to_seqres
.
items
():
match
=
re
.
search
(
regex
,
chain_sequence
)
match
=
re
.
search
(
regex
,
chain_sequence
)
if
match
:
if
match
:
logging
.
info
(
'Found a fuzzy sequence-only match %s_%s.'
,
pdb_id
,
chain_id
)
logging
.
info
(
"Found a fuzzy sequence-only match %s_%s."
,
pdb_id
,
chain_id
)
mapping_offset
=
match
.
start
()
mapping_offset
=
match
.
start
()
return
chain_sequence
,
chain_id
,
mapping_offset
return
chain_sequence
,
chain_id
,
mapping_offset
# No hits, raise an error.
# No hits, raise an error.
raise
SequenceNotInTemplateError
(
raise
SequenceNotInTemplateError
(
'Could not find the template sequence in %s_%s. Template sequence: %s, '
"Could not find the template sequence in %s_%s. Template sequence: %s, "
'chain_to_seqres: %s'
%
(
pdb_id
,
template_chain_id
,
template_sequence
,
"chain_to_seqres: %s"
mmcif_object
.
chain_to_seqres
))
%
(
pdb_id
,
template_chain_id
,
template_sequence
,
mmcif_object
.
chain_to_seqres
,
)
)
def
_realign_pdb_template_to_query
(
def
_realign_pdb_template_to_query
(
...
@@ -320,7 +347,8 @@ def _realign_pdb_template_to_query(
...
@@ -320,7 +347,8 @@ def _realign_pdb_template_to_query(
template_chain_id
:
str
,
template_chain_id
:
str
,
mmcif_object
:
mmcif_parsing
.
MmcifObject
,
mmcif_object
:
mmcif_parsing
.
MmcifObject
,
old_mapping
:
Mapping
[
int
,
int
],
old_mapping
:
Mapping
[
int
,
int
],
kalign_binary_path
:
str
)
->
Tuple
[
str
,
Mapping
[
int
,
int
]]:
kalign_binary_path
:
str
,
)
->
Tuple
[
str
,
Mapping
[
int
,
int
]]:
"""Aligns template from the mmcif_object to the query.
"""Aligns template from the mmcif_object to the query.
In case PDB70 contains a different version of the template sequence, we need
In case PDB70 contains a different version of the template sequence, we need
...
@@ -361,76 +389,104 @@ def _realign_pdb_template_to_query(
...
@@ -361,76 +389,104 @@ def _realign_pdb_template_to_query(
"""
"""
aligner
=
kalign
.
Kalign
(
binary_path
=
kalign_binary_path
)
aligner
=
kalign
.
Kalign
(
binary_path
=
kalign_binary_path
)
new_template_sequence
=
mmcif_object
.
chain_to_seqres
.
get
(
new_template_sequence
=
mmcif_object
.
chain_to_seqres
.
get
(
template_chain_id
,
''
)
template_chain_id
,
""
)
# Sometimes the template chain id is unknown. But if there is only a single
# Sometimes the template chain id is unknown. But if there is only a single
# sequence within the mmcif_object, it is safe to assume it is that one.
# sequence within the mmcif_object, it is safe to assume it is that one.
if
not
new_template_sequence
:
if
not
new_template_sequence
:
if
len
(
mmcif_object
.
chain_to_seqres
)
==
1
:
if
len
(
mmcif_object
.
chain_to_seqres
)
==
1
:
logging
.
info
(
'Could not find %s in %s, but there is only 1 sequence, so '
logging
.
info
(
'using that one.'
,
"Could not find %s in %s, but there is only 1 sequence, so "
"using that one."
,
template_chain_id
,
template_chain_id
,
mmcif_object
.
file_id
)
mmcif_object
.
file_id
,
new_template_sequence
=
list
(
mmcif_object
.
chain_to_seqres
.
values
())[
0
]
)
new_template_sequence
=
list
(
mmcif_object
.
chain_to_seqres
.
values
())[
0
]
else
:
else
:
raise
QueryToTemplateAlignError
(
raise
QueryToTemplateAlignError
(
f
'Could not find chain
{
template_chain_id
}
in
{
mmcif_object
.
file_id
}
. '
f
"Could not find chain
{
template_chain_id
}
in
{
mmcif_object
.
file_id
}
. "
'If there are no mmCIF parsing errors, it is possible it was not a '
"If there are no mmCIF parsing errors, it is possible it was not a "
'protein chain.'
)
"protein chain."
)
try
:
try
:
(
old_aligned_template
,
new_aligned_template
),
_
=
parsers
.
parse_a3m
(
(
old_aligned_template
,
new_aligned_template
),
_
=
parsers
.
parse_a3m
(
aligner
.
align
([
old_template_sequence
,
new_template_sequence
]))
aligner
.
align
([
old_template_sequence
,
new_template_sequence
])
)
except
Exception
as
e
:
except
Exception
as
e
:
raise
QueryToTemplateAlignError
(
raise
QueryToTemplateAlignError
(
'Could not align old template %s to template %s (%s_%s). Error: %s'
%
"Could not align old template %s to template %s (%s_%s). Error: %s"
(
old_template_sequence
,
new_template_sequence
,
mmcif_object
.
file_id
,
%
(
template_chain_id
,
str
(
e
)))
old_template_sequence
,
new_template_sequence
,
mmcif_object
.
file_id
,
template_chain_id
,
str
(
e
),
)
)
logging
.
info
(
'Old aligned template: %s
\n
New aligned template: %s'
,
logging
.
info
(
old_aligned_template
,
new_aligned_template
)
"Old aligned template: %s
\n
New aligned template: %s"
,
old_aligned_template
,
new_aligned_template
,
)
old_to_new_template_mapping
=
{}
old_to_new_template_mapping
=
{}
old_template_index
=
-
1
old_template_index
=
-
1
new_template_index
=
-
1
new_template_index
=
-
1
num_same
=
0
num_same
=
0
for
old_template_aa
,
new_template_aa
in
zip
(
for
old_template_aa
,
new_template_aa
in
zip
(
old_aligned_template
,
new_aligned_template
):
old_aligned_template
,
new_aligned_template
if
old_template_aa
!=
'-'
:
):
if
old_template_aa
!=
"-"
:
old_template_index
+=
1
old_template_index
+=
1
if
new_template_aa
!=
'-'
:
if
new_template_aa
!=
"-"
:
new_template_index
+=
1
new_template_index
+=
1
if
old_template_aa
!=
'-'
and
new_template_aa
!=
'-'
:
if
old_template_aa
!=
"-"
and
new_template_aa
!=
"-"
:
old_to_new_template_mapping
[
old_template_index
]
=
new_template_index
old_to_new_template_mapping
[
old_template_index
]
=
new_template_index
if
old_template_aa
==
new_template_aa
:
if
old_template_aa
==
new_template_aa
:
num_same
+=
1
num_same
+=
1
# Require at least 90 % sequence identity wrt to the shorter of the sequences.
# Require at least 90 % sequence identity wrt to the shorter of the sequences.
if
float
(
num_same
)
/
min
(
if
(
len
(
old_template_sequence
),
len
(
new_template_sequence
))
<
0.9
:
float
(
num_same
)
/
min
(
len
(
old_template_sequence
),
len
(
new_template_sequence
))
<
0.9
):
raise
QueryToTemplateAlignError
(
raise
QueryToTemplateAlignError
(
'Insufficient similarity of the sequence in the database: %s to the '
"Insufficient similarity of the sequence in the database: %s to the "
'actual sequence in the mmCIF file %s_%s: %s. We require at least '
"actual sequence in the mmCIF file %s_%s: %s. We require at least "
'90 %% similarity wrt to the shorter of the sequences. This is not a '
"90 %% similarity wrt to the shorter of the sequences. This is not a "
'problem unless you think this is a template that should be included.'
%
"problem unless you think this is a template that should be included."
(
old_template_sequence
,
mmcif_object
.
file_id
,
template_chain_id
,
%
(
new_template_sequence
))
old_template_sequence
,
mmcif_object
.
file_id
,
template_chain_id
,
new_template_sequence
,
)
)
new_query_to_template_mapping
=
{}
new_query_to_template_mapping
=
{}
for
query_index
,
old_template_index
in
old_mapping
.
items
():
for
query_index
,
old_template_index
in
old_mapping
.
items
():
new_query_to_template_mapping
[
query_index
]
=
(
new_query_to_template_mapping
[
old_to_new_template_mapping
.
get
(
old_template_index
,
-
1
))
query_index
]
=
old_to_new_template_mapping
.
get
(
old_template_index
,
-
1
)
new_template_sequence
=
new_template_sequence
.
replace
(
'-'
,
''
)
new_template_sequence
=
new_template_sequence
.
replace
(
"-"
,
""
)
return
new_template_sequence
,
new_query_to_template_mapping
return
new_template_sequence
,
new_query_to_template_mapping
def
_check_residue_distances
(
all_positions
:
np
.
ndarray
,
def
_check_residue_distances
(
all_positions
:
np
.
ndarray
,
all_positions_mask
:
np
.
ndarray
,
all_positions_mask
:
np
.
ndarray
,
max_ca_ca_distance
:
float
):
max_ca_ca_distance
:
float
,
):
"""Checks if the distance between unmasked neighbor residues is ok."""
"""Checks if the distance between unmasked neighbor residues is ok."""
ca_position
=
residue_constants
.
atom_order
[
'
CA
'
]
ca_position
=
residue_constants
.
atom_order
[
"
CA
"
]
prev_is_unmasked
=
False
prev_is_unmasked
=
False
prev_calpha
=
None
prev_calpha
=
None
for
i
,
(
coords
,
mask
)
in
enumerate
(
zip
(
all_positions
,
all_positions_mask
)):
for
i
,
(
coords
,
mask
)
in
enumerate
(
zip
(
all_positions
,
all_positions_mask
)):
...
@@ -441,8 +497,9 @@ def _check_residue_distances(all_positions: np.ndarray,
...
@@ -441,8 +497,9 @@ def _check_residue_distances(all_positions: np.ndarray,
distance
=
np
.
linalg
.
norm
(
this_calpha
-
prev_calpha
)
distance
=
np
.
linalg
.
norm
(
this_calpha
-
prev_calpha
)
if
distance
>
max_ca_ca_distance
:
if
distance
>
max_ca_ca_distance
:
raise
CaDistanceError
(
raise
CaDistanceError
(
'The distance between residues %d and %d is %f > limit %f.'
%
(
"The distance between residues %d and %d is %f > limit %f."
i
,
i
+
1
,
distance
,
max_ca_ca_distance
))
%
(
i
,
i
+
1
,
distance
,
max_ca_ca_distance
)
)
prev_calpha
=
this_calpha
prev_calpha
=
this_calpha
prev_is_unmasked
=
this_is_unmasked
prev_is_unmasked
=
this_is_unmasked
...
@@ -450,7 +507,8 @@ def _check_residue_distances(all_positions: np.ndarray,
...
@@ -450,7 +507,8 @@ def _check_residue_distances(all_positions: np.ndarray,
def
_get_atom_positions
(
def
_get_atom_positions
(
mmcif_object
:
mmcif_parsing
.
MmcifObject
,
mmcif_object
:
mmcif_parsing
.
MmcifObject
,
auth_chain_id
:
str
,
auth_chain_id
:
str
,
max_ca_ca_distance
:
float
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
max_ca_ca_distance
:
float
,
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""Gets atom positions and mask from a list of Biopython Residues."""
"""Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask
=
mmcif_parsing
.
get_atom_coords
(
coords_with_mask
=
mmcif_parsing
.
get_atom_coords
(
mmcif_object
=
mmcif_object
,
chain_id
=
auth_chain_id
mmcif_object
=
mmcif_object
,
chain_id
=
auth_chain_id
...
@@ -469,7 +527,8 @@ def _extract_template_features(
...
@@ -469,7 +527,8 @@ def _extract_template_features(
template_sequence
:
str
,
template_sequence
:
str
,
query_sequence
:
str
,
query_sequence
:
str
,
template_chain_id
:
str
,
template_chain_id
:
str
,
kalign_binary_path
:
str
)
->
Tuple
[
Dict
[
str
,
Any
],
Optional
[
str
]]:
kalign_binary_path
:
str
,
)
->
Tuple
[
Dict
[
str
,
Any
],
Optional
[
str
]]:
"""Parses atom positions in the target structure and aligns with the query.
"""Parses atom positions in the target structure and aligns with the query.
Atoms for each residue in the template structure are indexed to coincide
Atoms for each residue in the template structure are indexed to coincide
...
@@ -509,21 +568,25 @@ def _extract_template_features(
...
@@ -509,21 +568,25 @@ def _extract_template_features(
unmasked residues.
unmasked residues.
"""
"""
if
mmcif_object
is
None
or
not
mmcif_object
.
chain_to_seqres
:
if
mmcif_object
is
None
or
not
mmcif_object
.
chain_to_seqres
:
raise
NoChainsError
(
'No chains in PDB: %s_%s'
%
(
pdb_id
,
template_chain_id
))
raise
NoChainsError
(
"No chains in PDB: %s_%s"
%
(
pdb_id
,
template_chain_id
)
)
warning
=
None
warning
=
None
try
:
try
:
seqres
,
chain_id
,
mapping_offset
=
_find_template_in_pdb
(
seqres
,
chain_id
,
mapping_offset
=
_find_template_in_pdb
(
template_chain_id
=
template_chain_id
,
template_chain_id
=
template_chain_id
,
template_sequence
=
template_sequence
,
template_sequence
=
template_sequence
,
mmcif_object
=
mmcif_object
)
mmcif_object
=
mmcif_object
,
)
except
SequenceNotInTemplateError
:
except
SequenceNotInTemplateError
:
# If PDB70 contains a different version of the template, we use the sequence
# If PDB70 contains a different version of the template, we use the sequence
# from the mmcif_object.
# from the mmcif_object.
chain_id
=
template_chain_id
chain_id
=
template_chain_id
warning
=
(
warning
=
(
f
'The exact sequence
{
template_sequence
}
was not found in '
f
"The exact sequence
{
template_sequence
}
was not found in "
f
'
{
pdb_id
}
_
{
chain_id
}
. Realigning the template to the actual sequence.'
)
f
"
{
pdb_id
}
_
{
chain_id
}
. Realigning the template to the actual sequence."
)
logging
.
warning
(
warning
)
logging
.
warning
(
warning
)
# This throws an exception if it fails to realign the hit.
# This throws an exception if it fails to realign the hit.
seqres
,
mapping
=
_realign_pdb_template_to_query
(
seqres
,
mapping
=
_realign_pdb_template_to_query
(
...
@@ -531,9 +594,15 @@ def _extract_template_features(
...
@@ -531,9 +594,15 @@ def _extract_template_features(
template_chain_id
=
template_chain_id
,
template_chain_id
=
template_chain_id
,
mmcif_object
=
mmcif_object
,
mmcif_object
=
mmcif_object
,
old_mapping
=
mapping
,
old_mapping
=
mapping
,
kalign_binary_path
=
kalign_binary_path
)
kalign_binary_path
=
kalign_binary_path
,
logging
.
info
(
'Sequence in %s_%s: %s successfully realigned to %s'
,
)
pdb_id
,
chain_id
,
template_sequence
,
seqres
)
logging
.
info
(
"Sequence in %s_%s: %s successfully realigned to %s"
,
pdb_id
,
chain_id
,
template_sequence
,
seqres
,
)
# The template sequence changed.
# The template sequence changed.
template_sequence
=
seqres
template_sequence
=
seqres
# No mapping offset, the query is aligned to the actual sequence.
# No mapping offset, the query is aligned to the actual sequence.
...
@@ -543,13 +612,16 @@ def _extract_template_features(
...
@@ -543,13 +612,16 @@ def _extract_template_features(
# Essentially set to infinity - we don't want to reject templates unless
# Essentially set to infinity - we don't want to reject templates unless
# they're really really bad.
# they're really really bad.
all_atom_positions
,
all_atom_mask
=
_get_atom_positions
(
all_atom_positions
,
all_atom_mask
=
_get_atom_positions
(
mmcif_object
,
chain_id
,
max_ca_ca_distance
=
150.0
)
mmcif_object
,
chain_id
,
max_ca_ca_distance
=
150.0
)
except
(
CaDistanceError
,
KeyError
)
as
ex
:
except
(
CaDistanceError
,
KeyError
)
as
ex
:
raise
NoAtomDataInTemplateError
(
raise
NoAtomDataInTemplateError
(
'
Could not get atom data (%s_%s): %s
'
%
(
pdb_id
,
chain_id
,
str
(
ex
))
"
Could not get atom data (%s_%s): %s
"
%
(
pdb_id
,
chain_id
,
str
(
ex
))
)
from
ex
)
from
ex
all_atom_positions
=
np
.
split
(
all_atom_positions
,
all_atom_positions
.
shape
[
0
])
all_atom_positions
=
np
.
split
(
all_atom_positions
,
all_atom_positions
.
shape
[
0
]
)
all_atom_masks
=
np
.
split
(
all_atom_mask
,
all_atom_mask
.
shape
[
0
])
all_atom_masks
=
np
.
split
(
all_atom_mask
,
all_atom_mask
.
shape
[
0
])
output_templates_sequence
=
[]
output_templates_sequence
=
[]
...
@@ -559,9 +631,12 @@ def _extract_template_features(
...
@@ -559,9 +631,12 @@ def _extract_template_features(
for
_
in
query_sequence
:
for
_
in
query_sequence
:
# Residues in the query_sequence that are not in the template_sequence:
# Residues in the query_sequence that are not in the template_sequence:
templates_all_atom_positions
.
append
(
templates_all_atom_positions
.
append
(
np
.
zeros
((
residue_constants
.
atom_type_num
,
3
)))
np
.
zeros
((
residue_constants
.
atom_type_num
,
3
))
templates_all_atom_masks
.
append
(
np
.
zeros
(
residue_constants
.
atom_type_num
))
)
output_templates_sequence
.
append
(
'-'
)
templates_all_atom_masks
.
append
(
np
.
zeros
(
residue_constants
.
atom_type_num
)
)
output_templates_sequence
.
append
(
"-"
)
for
k
,
v
in
mapping
.
items
():
for
k
,
v
in
mapping
.
items
():
template_index
=
v
+
mapping_offset
template_index
=
v
+
mapping_offset
...
@@ -572,24 +647,33 @@ def _extract_template_features(
...
@@ -572,24 +647,33 @@ def _extract_template_features(
# Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
# Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
if
np
.
sum
(
templates_all_atom_masks
)
<
5
:
if
np
.
sum
(
templates_all_atom_masks
)
<
5
:
raise
TemplateAtomMaskAllZerosError
(
raise
TemplateAtomMaskAllZerosError
(
'Template all atom mask was all zeros: %s_%s. Residue range: %d-%d'
%
"Template all atom mask was all zeros: %s_%s. Residue range: %d-%d"
(
pdb_id
,
chain_id
,
min
(
mapping
.
values
())
+
mapping_offset
,
%
(
max
(
mapping
.
values
())
+
mapping_offset
))
pdb_id
,
chain_id
,
min
(
mapping
.
values
())
+
mapping_offset
,
max
(
mapping
.
values
())
+
mapping_offset
,
)
)
output_templates_sequence
=
''
.
join
(
output_templates_sequence
)
output_templates_sequence
=
""
.
join
(
output_templates_sequence
)
templates_aatype
=
residue_constants
.
sequence_to_onehot
(
templates_aatype
=
residue_constants
.
sequence_to_onehot
(
output_templates_sequence
,
residue_constants
.
HHBLITS_AA_TO_ID
)
output_templates_sequence
,
residue_constants
.
HHBLITS_AA_TO_ID
)
return
(
return
(
{
{
'template_all_atom_positions'
:
np
.
array
(
templates_all_atom_positions
),
"template_all_atom_positions"
:
np
.
array
(
'template_all_atom_mask'
:
np
.
array
(
templates_all_atom_masks
),
templates_all_atom_positions
'template_sequence'
:
output_templates_sequence
.
encode
(),
),
'template_aatype'
:
np
.
array
(
templates_aatype
),
"template_all_atom_mask"
:
np
.
array
(
templates_all_atom_masks
),
'template_domain_names'
:
f
'
{
pdb_id
.
lower
()
}
_
{
chain_id
}
'
.
encode
(),
"template_sequence"
:
output_templates_sequence
.
encode
(),
"template_aatype"
:
np
.
array
(
templates_aatype
),
"template_domain_names"
:
f
"
{
pdb_id
.
lower
()
}
_
{
chain_id
}
"
.
encode
(),
},
},
warning
)
warning
,
)
def
_build_query_to_hit_index_mapping
(
def
_build_query_to_hit_index_mapping
(
...
@@ -597,7 +681,8 @@ def _build_query_to_hit_index_mapping(
...
@@ -597,7 +681,8 @@ def _build_query_to_hit_index_mapping(
hit_sequence
:
str
,
hit_sequence
:
str
,
indices_hit
:
Sequence
[
int
],
indices_hit
:
Sequence
[
int
],
indices_query
:
Sequence
[
int
],
indices_query
:
Sequence
[
int
],
original_query_sequence
:
str
)
->
Mapping
[
int
,
int
]:
original_query_sequence
:
str
,
)
->
Mapping
[
int
,
int
]:
"""Gets mapping from indices in original query sequence to indices in the hit.
"""Gets mapping from indices in original query sequence to indices in the hit.
hit_query_sequence and hit_sequence are two aligned sequences containing gap
hit_query_sequence and hit_sequence are two aligned sequences containing gap
...
@@ -624,15 +709,15 @@ def _build_query_to_hit_index_mapping(
...
@@ -624,15 +709,15 @@ def _build_query_to_hit_index_mapping(
return
{}
return
{}
# Remove gaps and find the offset of hit.query relative to original query.
# Remove gaps and find the offset of hit.query relative to original query.
hhsearch_query_sequence
=
hit_query_sequence
.
replace
(
'-'
,
''
)
hhsearch_query_sequence
=
hit_query_sequence
.
replace
(
"-"
,
""
)
hit_sequence
=
hit_sequence
.
replace
(
'-'
,
''
)
hit_sequence
=
hit_sequence
.
replace
(
"-"
,
""
)
hhsearch_query_offset
=
original_query_sequence
.
find
(
hhsearch_query_sequence
)
hhsearch_query_offset
=
original_query_sequence
.
find
(
hhsearch_query_sequence
)
# Index of -1 used for gap characters. Subtract the min index ignoring gaps.
# Index of -1 used for gap characters. Subtract the min index ignoring gaps.
min_idx
=
min
(
x
for
x
in
indices_hit
if
x
>
-
1
)
min_idx
=
min
(
x
for
x
in
indices_hit
if
x
>
-
1
)
fixed_indices_hit
=
[
fixed_indices_hit
=
[
x
-
min_idx
if
x
>
-
1
else
-
1
for
x
in
indices_hit
]
x
-
min_idx
if
x
>
-
1
else
-
1
for
x
in
indices_hit
]
min_idx
=
min
(
x
for
x
in
indices_query
if
x
>
-
1
)
min_idx
=
min
(
x
for
x
in
indices_query
if
x
>
-
1
)
fixed_indices_query
=
[
x
-
min_idx
if
x
>
-
1
else
-
1
for
x
in
indices_query
]
fixed_indices_query
=
[
x
-
min_idx
if
x
>
-
1
else
-
1
for
x
in
indices_query
]
...
@@ -641,8 +726,9 @@ def _build_query_to_hit_index_mapping(
...
@@ -641,8 +726,9 @@ def _build_query_to_hit_index_mapping(
mapping
=
{}
mapping
=
{}
for
q_i
,
q_t
in
zip
(
fixed_indices_query
,
fixed_indices_hit
):
for
q_i
,
q_t
in
zip
(
fixed_indices_query
,
fixed_indices_hit
):
if
q_t
!=
-
1
and
q_i
!=
-
1
:
if
q_t
!=
-
1
and
q_i
!=
-
1
:
if
(
q_t
>=
len
(
hit_sequence
)
or
if
q_t
>=
len
(
hit_sequence
)
or
q_i
+
hhsearch_query_offset
>=
len
(
q_i
+
hhsearch_query_offset
>=
len
(
original_query_sequence
)):
original_query_sequence
):
continue
continue
mapping
[
q_i
+
hhsearch_query_offset
]
=
q_t
mapping
[
q_i
+
hhsearch_query_offset
]
=
q_t
...
@@ -665,7 +751,8 @@ def _process_single_hit(
...
@@ -665,7 +751,8 @@ def _process_single_hit(
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
obsolete_pdbs
:
Mapping
[
str
,
str
],
obsolete_pdbs
:
Mapping
[
str
,
str
],
kalign_binary_path
:
str
,
kalign_binary_path
:
str
,
strict_error_check
:
bool
=
False
)
->
SingleHitResult
:
strict_error_check
:
bool
=
False
,
)
->
SingleHitResult
:
"""Tries to extract template features from a single HHSearch hit."""
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code
,
hit_chain_id
=
_get_pdb_id_and_chain
(
hit
)
hit_pdb_code
,
hit_chain_id
=
_get_pdb_id_and_chain
(
hit
)
...
@@ -682,41 +769,56 @@ def _process_single_hit(
...
@@ -682,41 +769,56 @@ def _process_single_hit(
query_sequence
=
query_sequence
,
query_sequence
=
query_sequence
,
query_pdb_code
=
query_pdb_code
,
query_pdb_code
=
query_pdb_code
,
release_dates
=
release_dates
,
release_dates
=
release_dates
,
release_date_cutoff
=
max_template_date
)
release_date_cutoff
=
max_template_date
,
)
except
PrefilterError
as
e
:
except
PrefilterError
as
e
:
msg
=
f
'
hit
{
hit_pdb_code
}
_
{
hit_chain_id
}
did not pass prefilter:
{
str
(
e
)
}
'
msg
=
f
"
hit
{
hit_pdb_code
}
_
{
hit_chain_id
}
did not pass prefilter:
{
str
(
e
)
}
"
logging
.
info
(
'
%s: %s
'
,
query_pdb_code
,
msg
)
logging
.
info
(
"
%s: %s
"
,
query_pdb_code
,
msg
)
if
strict_error_check
and
isinstance
(
if
strict_error_check
and
isinstance
(
e
,
(
DateError
,
PdbIdError
,
DuplicateError
)):
e
,
(
DateError
,
PdbIdError
,
DuplicateError
)
):
# In strict mode we treat some prefilter cases as errors.
# In strict mode we treat some prefilter cases as errors.
return
SingleHitResult
(
features
=
None
,
error
=
msg
,
warning
=
None
)
return
SingleHitResult
(
features
=
None
,
error
=
msg
,
warning
=
None
)
return
SingleHitResult
(
features
=
None
,
error
=
None
,
warning
=
None
)
return
SingleHitResult
(
features
=
None
,
error
=
None
,
warning
=
None
)
mapping
=
_build_query_to_hit_index_mapping
(
mapping
=
_build_query_to_hit_index_mapping
(
hit
.
query
,
hit
.
hit_sequence
,
hit
.
indices_hit
,
hit
.
indices_query
,
hit
.
query
,
query_sequence
)
hit
.
hit_sequence
,
hit
.
indices_hit
,
hit
.
indices_query
,
query_sequence
,
)
# The mapping is from the query to the actual hit sequence, so we need to
# The mapping is from the query to the actual hit sequence, so we need to
# remove gaps (which regardless have a missing confidence score).
# remove gaps (which regardless have a missing confidence score).
template_sequence
=
hit
.
hit_sequence
.
replace
(
'-'
,
''
)
template_sequence
=
hit
.
hit_sequence
.
replace
(
"-"
,
""
)
cif_path
=
os
.
path
.
join
(
mmcif_dir
,
hit_pdb_code
+
'.cif'
)
cif_path
=
os
.
path
.
join
(
mmcif_dir
,
hit_pdb_code
+
".cif"
)
logging
.
info
(
'Reading PDB entry from %s. Query: %s, template: %s'
,
logging
.
info
(
cif_path
,
query_sequence
,
template_sequence
)
"Reading PDB entry from %s. Query: %s, template: %s"
,
cif_path
,
query_sequence
,
template_sequence
,
)
# Fail if we can't find the mmCIF file.
# Fail if we can't find the mmCIF file.
with
open
(
cif_path
,
'r'
)
as
cif_file
:
with
open
(
cif_path
,
"r"
)
as
cif_file
:
cif_string
=
cif_file
.
read
()
cif_string
=
cif_file
.
read
()
parsing_result
=
mmcif_parsing
.
parse
(
parsing_result
=
mmcif_parsing
.
parse
(
file_id
=
hit_pdb_code
,
mmcif_string
=
cif_string
)
file_id
=
hit_pdb_code
,
mmcif_string
=
cif_string
)
if
parsing_result
.
mmcif_object
is
not
None
:
if
parsing_result
.
mmcif_object
is
not
None
:
hit_release_date
=
datetime
.
datetime
.
strptime
(
hit_release_date
=
datetime
.
datetime
.
strptime
(
parsing_result
.
mmcif_object
.
header
[
'release_date'
],
'%Y-%m-%d'
)
parsing_result
.
mmcif_object
.
header
[
"release_date"
],
"%Y-%m-%d"
)
if
hit_release_date
>
max_template_date
:
if
hit_release_date
>
max_template_date
:
error
=
(
'Template %s date (%s) > max template date (%s).'
%
error
=
"Template %s date (%s) > max template date (%s)."
%
(
(
hit_pdb_code
,
hit_release_date
,
max_template_date
))
hit_pdb_code
,
hit_release_date
,
max_template_date
,
)
if
strict_error_check
:
if
strict_error_check
:
return
SingleHitResult
(
features
=
None
,
error
=
error
,
warning
=
None
)
return
SingleHitResult
(
features
=
None
,
error
=
error
,
warning
=
None
)
else
:
else
:
...
@@ -731,31 +833,52 @@ def _process_single_hit(
...
@@ -731,31 +833,52 @@ def _process_single_hit(
template_sequence
=
template_sequence
,
template_sequence
=
template_sequence
,
query_sequence
=
query_sequence
,
query_sequence
=
query_sequence
,
template_chain_id
=
hit_chain_id
,
template_chain_id
=
hit_chain_id
,
kalign_binary_path
=
kalign_binary_path
)
kalign_binary_path
=
kalign_binary_path
,
features
[
'template_sum_probs'
]
=
[
hit
.
sum_probs
]
)
features
[
"template_sum_probs"
]
=
[
hit
.
sum_probs
]
# It is possible there were some errors when parsing the other chains in the
# It is possible there were some errors when parsing the other chains in the
# mmCIF file, but the template features for the chain we want were still
# mmCIF file, but the template features for the chain we want were still
# computed. In such case the mmCIF parsing errors are not relevant.
# computed. In such case the mmCIF parsing errors are not relevant.
return
SingleHitResult
(
return
SingleHitResult
(
features
=
features
,
error
=
None
,
warning
=
realign_warning
)
features
=
features
,
error
=
None
,
warning
=
realign_warning
except
(
NoChainsError
,
NoAtomDataInTemplateError
,
)
TemplateAtomMaskAllZerosError
)
as
e
:
except
(
NoChainsError
,
NoAtomDataInTemplateError
,
TemplateAtomMaskAllZerosError
,
)
as
e
:
# These 3 errors indicate missing mmCIF experimental data rather than a
# These 3 errors indicate missing mmCIF experimental data rather than a
# problem with the template search, so turn them into warnings.
# problem with the template search, so turn them into warnings.
warning
=
(
'%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: '
warning
=
(
'%s, mmCIF parsing errors: %s'
"%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
%
(
hit_pdb_code
,
hit_chain_id
,
hit
.
sum_probs
,
hit
.
index
,
"%s, mmCIF parsing errors: %s"
str
(
e
),
parsing_result
.
errors
))
%
(
hit_pdb_code
,
hit_chain_id
,
hit
.
sum_probs
,
hit
.
index
,
str
(
e
),
parsing_result
.
errors
,
)
)
if
strict_error_check
:
if
strict_error_check
:
return
SingleHitResult
(
features
=
None
,
error
=
warning
,
warning
=
None
)
return
SingleHitResult
(
features
=
None
,
error
=
warning
,
warning
=
None
)
else
:
else
:
return
SingleHitResult
(
features
=
None
,
error
=
None
,
warning
=
warning
)
return
SingleHitResult
(
features
=
None
,
error
=
None
,
warning
=
warning
)
except
Error
as
e
:
except
Error
as
e
:
error
=
(
'%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: '
error
=
(
'%s, mmCIF parsing errors: %s'
"%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
%
(
hit_pdb_code
,
hit_chain_id
,
hit
.
sum_probs
,
hit
.
index
,
"%s, mmCIF parsing errors: %s"
str
(
e
),
parsing_result
.
errors
))
%
(
hit_pdb_code
,
hit_chain_id
,
hit
.
sum_probs
,
hit
.
index
,
str
(
e
),
parsing_result
.
errors
,
)
)
return
SingleHitResult
(
features
=
None
,
error
=
error
,
warning
=
None
)
return
SingleHitResult
(
features
=
None
,
error
=
error
,
warning
=
None
)
...
@@ -777,7 +900,8 @@ class TemplateHitFeaturizer:
...
@@ -777,7 +900,8 @@ class TemplateHitFeaturizer:
kalign_binary_path
:
str
,
kalign_binary_path
:
str
,
release_dates_path
:
Optional
[
str
],
release_dates_path
:
Optional
[
str
],
obsolete_pdbs_path
:
Optional
[
str
],
obsolete_pdbs_path
:
Optional
[
str
],
strict_error_check
:
bool
=
False
):
strict_error_check
:
bool
=
False
,
):
"""Initializes the Template Search.
"""Initializes the Template Search.
Args:
Args:
...
@@ -802,28 +926,34 @@ class TemplateHitFeaturizer:
...
@@ -802,28 +926,34 @@ class TemplateHitFeaturizer:
* Any feature computation errors.
* Any feature computation errors.
"""
"""
self
.
_mmcif_dir
=
mmcif_dir
self
.
_mmcif_dir
=
mmcif_dir
if
not
glob
.
glob
(
os
.
path
.
join
(
self
.
_mmcif_dir
,
'
*.cif
'
)):
if
not
glob
.
glob
(
os
.
path
.
join
(
self
.
_mmcif_dir
,
"
*.cif
"
)):
logging
.
error
(
'
Could not find CIFs in %s
'
,
self
.
_mmcif_dir
)
logging
.
error
(
"
Could not find CIFs in %s
"
,
self
.
_mmcif_dir
)
raise
ValueError
(
f
'
Could not find CIFs in
{
self
.
_mmcif_dir
}
'
)
raise
ValueError
(
f
"
Could not find CIFs in
{
self
.
_mmcif_dir
}
"
)
try
:
try
:
self
.
_max_template_date
=
datetime
.
datetime
.
strptime
(
self
.
_max_template_date
=
datetime
.
datetime
.
strptime
(
max_template_date
,
'%Y-%m-%d'
)
max_template_date
,
"%Y-%m-%d"
)
except
ValueError
:
except
ValueError
:
raise
ValueError
(
raise
ValueError
(
'max_template_date must be set and have format YYYY-MM-DD.'
)
"max_template_date must be set and have format YYYY-MM-DD."
)
self
.
_max_hits
=
max_hits
self
.
_max_hits
=
max_hits
self
.
_kalign_binary_path
=
kalign_binary_path
self
.
_kalign_binary_path
=
kalign_binary_path
self
.
_strict_error_check
=
strict_error_check
self
.
_strict_error_check
=
strict_error_check
if
release_dates_path
:
if
release_dates_path
:
logging
.
info
(
'Using precomputed release dates %s.'
,
release_dates_path
)
logging
.
info
(
"Using precomputed release dates %s."
,
release_dates_path
)
self
.
_release_dates
=
_parse_release_dates
(
release_dates_path
)
self
.
_release_dates
=
_parse_release_dates
(
release_dates_path
)
else
:
else
:
self
.
_release_dates
=
{}
self
.
_release_dates
=
{}
if
obsolete_pdbs_path
:
if
obsolete_pdbs_path
:
logging
.
info
(
'Using precomputed obsolete pdbs %s.'
,
obsolete_pdbs_path
)
logging
.
info
(
"Using precomputed obsolete pdbs %s."
,
obsolete_pdbs_path
)
self
.
_obsolete_pdbs
=
_parse_obsolete
(
obsolete_pdbs_path
)
self
.
_obsolete_pdbs
=
_parse_obsolete
(
obsolete_pdbs_path
)
else
:
else
:
self
.
_obsolete_pdbs
=
{}
self
.
_obsolete_pdbs
=
{}
...
@@ -833,9 +963,10 @@ class TemplateHitFeaturizer:
...
@@ -833,9 +963,10 @@ class TemplateHitFeaturizer:
query_sequence
:
str
,
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
query_pdb_code
:
Optional
[
str
],
query_release_date
:
Optional
[
datetime
.
datetime
],
query_release_date
:
Optional
[
datetime
.
datetime
],
hits
:
Sequence
[
parsers
.
TemplateHit
])
->
TemplateSearchResult
:
hits
:
Sequence
[
parsers
.
TemplateHit
],
)
->
TemplateSearchResult
:
"""Computes the templates for given query sequence (more details above)."""
"""Computes the templates for given query sequence (more details above)."""
logging
.
info
(
'
Searching for template for: %s
'
,
query_pdb_code
)
logging
.
info
(
"
Searching for template for: %s
"
,
query_pdb_code
)
template_features
=
{}
template_features
=
{}
for
template_feature_name
in
TEMPLATE_FEATURES
:
for
template_feature_name
in
TEMPLATE_FEATURES
:
...
@@ -869,7 +1000,8 @@ class TemplateHitFeaturizer:
...
@@ -869,7 +1000,8 @@ class TemplateHitFeaturizer:
release_dates
=
self
.
_release_dates
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
strict_error_check
=
self
.
_strict_error_check
,
kalign_binary_path
=
self
.
_kalign_binary_path
)
kalign_binary_path
=
self
.
_kalign_binary_path
,
)
if
result
.
error
:
if
result
.
error
:
errors
.
append
(
result
.
error
)
errors
.
append
(
result
.
error
)
...
@@ -880,8 +1012,12 @@ class TemplateHitFeaturizer:
...
@@ -880,8 +1012,12 @@ class TemplateHitFeaturizer:
warnings
.
append
(
result
.
warning
)
warnings
.
append
(
result
.
warning
)
if
result
.
features
is
None
:
if
result
.
features
is
None
:
logging
.
info
(
'Skipped invalid hit %s, error: %s, warning: %s'
,
logging
.
info
(
hit
.
name
,
result
.
error
,
result
.
warning
)
"Skipped invalid hit %s, error: %s, warning: %s"
,
hit
.
name
,
result
.
error
,
result
.
warning
,
)
else
:
else
:
# Increment the hit counter, since we got features out of this hit.
# Increment the hit counter, since we got features out of this hit.
num_hits
+=
1
num_hits
+=
1
...
@@ -891,10 +1027,14 @@ class TemplateHitFeaturizer:
...
@@ -891,10 +1027,14 @@ class TemplateHitFeaturizer:
for
name
in
template_features
:
for
name
in
template_features
:
if
num_hits
>
0
:
if
num_hits
>
0
:
template_features
[
name
]
=
np
.
stack
(
template_features
[
name
]
=
np
.
stack
(
template_features
[
name
],
axis
=
0
).
astype
(
TEMPLATE_FEATURES
[
name
])
template_features
[
name
],
axis
=
0
).
astype
(
TEMPLATE_FEATURES
[
name
])
else
:
else
:
# Make sure the feature has correct dtype even if empty.
# Make sure the feature has correct dtype even if empty.
template_features
[
name
]
=
np
.
array
([],
dtype
=
TEMPLATE_FEATURES
[
name
])
template_features
[
name
]
=
np
.
array
(
[],
dtype
=
TEMPLATE_FEATURES
[
name
]
)
return
TemplateSearchResult
(
return
TemplateSearchResult
(
features
=
template_features
,
errors
=
errors
,
warnings
=
warnings
)
features
=
template_features
,
errors
=
errors
,
warnings
=
warnings
)
openfold/data/tools/hhblits.py
View file @
07e64267
...
@@ -30,7 +30,8 @@ _HHBLITS_DEFAULT_Z = 500
...
@@ -30,7 +30,8 @@ _HHBLITS_DEFAULT_Z = 500
class
HHBlits
:
class
HHBlits
:
"""Python wrapper of the HHblits binary."""
"""Python wrapper of the HHblits binary."""
def
__init__
(
self
,
def
__init__
(
self
,
*
,
*
,
binary_path
:
str
,
binary_path
:
str
,
databases
:
Sequence
[
str
],
databases
:
Sequence
[
str
],
...
@@ -44,7 +45,8 @@ class HHBlits:
...
@@ -44,7 +45,8 @@ class HHBlits:
all_seqs
:
bool
=
False
,
all_seqs
:
bool
=
False
,
alt
:
Optional
[
int
]
=
None
,
alt
:
Optional
[
int
]
=
None
,
p
:
int
=
_HHBLITS_DEFAULT_P
,
p
:
int
=
_HHBLITS_DEFAULT_P
,
z
:
int
=
_HHBLITS_DEFAULT_Z
):
z
:
int
=
_HHBLITS_DEFAULT_Z
,
):
"""Initializes the Python HHblits wrapper.
"""Initializes the Python HHblits wrapper.
Args:
Args:
...
@@ -77,9 +79,13 @@ class HHBlits:
...
@@ -77,9 +79,13 @@ class HHBlits:
self
.
databases
=
databases
self
.
databases
=
databases
for
database_path
in
self
.
databases
:
for
database_path
in
self
.
databases
:
if
not
glob
.
glob
(
database_path
+
'_*'
):
if
not
glob
.
glob
(
database_path
+
"_*"
):
logging
.
error
(
'Could not find HHBlits database %s'
,
database_path
)
logging
.
error
(
raise
ValueError
(
f
'Could not find HHBlits database
{
database_path
}
'
)
"Could not find HHBlits database %s"
,
database_path
)
raise
ValueError
(
f
"Could not find HHBlits database
{
database_path
}
"
)
self
.
n_cpu
=
n_cpu
self
.
n_cpu
=
n_cpu
self
.
n_iter
=
n_iter
self
.
n_iter
=
n_iter
...
@@ -95,52 +101,66 @@ class HHBlits:
...
@@ -95,52 +101,66 @@ class HHBlits:
def
query
(
self
,
input_fasta_path
:
str
)
->
Mapping
[
str
,
Any
]:
def
query
(
self
,
input_fasta_path
:
str
)
->
Mapping
[
str
,
Any
]:
"""Queries the database using HHblits."""
"""Queries the database using HHblits."""
with
utils
.
tmpdir_manager
(
base_dir
=
'
/tmp
'
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
(
base_dir
=
"
/tmp
"
)
as
query_tmp_dir
:
a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
'
output.a3m
'
)
a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
"
output.a3m
"
)
db_cmd
=
[]
db_cmd
=
[]
for
db_path
in
self
.
databases
:
for
db_path
in
self
.
databases
:
db_cmd
.
append
(
'
-d
'
)
db_cmd
.
append
(
"
-d
"
)
db_cmd
.
append
(
db_path
)
db_cmd
.
append
(
db_path
)
cmd
=
[
cmd
=
[
self
.
binary_path
,
self
.
binary_path
,
'-i'
,
input_fasta_path
,
"-i"
,
'-cpu'
,
str
(
self
.
n_cpu
),
input_fasta_path
,
'-oa3m'
,
a3m_path
,
"-cpu"
,
'-o'
,
'/dev/null'
,
str
(
self
.
n_cpu
),
'-n'
,
str
(
self
.
n_iter
),
"-oa3m"
,
'-e'
,
str
(
self
.
e_value
),
a3m_path
,
'-maxseq'
,
str
(
self
.
maxseq
),
"-o"
,
'-realign_max'
,
str
(
self
.
realign_max
),
"/dev/null"
,
'-maxfilt'
,
str
(
self
.
maxfilt
),
"-n"
,
'-min_prefilter_hits'
,
str
(
self
.
min_prefilter_hits
)]
str
(
self
.
n_iter
),
"-e"
,
str
(
self
.
e_value
),
"-maxseq"
,
str
(
self
.
maxseq
),
"-realign_max"
,
str
(
self
.
realign_max
),
"-maxfilt"
,
str
(
self
.
maxfilt
),
"-min_prefilter_hits"
,
str
(
self
.
min_prefilter_hits
),
]
if
self
.
all_seqs
:
if
self
.
all_seqs
:
cmd
+=
[
'
-all
'
]
cmd
+=
[
"
-all
"
]
if
self
.
alt
:
if
self
.
alt
:
cmd
+=
[
'
-alt
'
,
str
(
self
.
alt
)]
cmd
+=
[
"
-alt
"
,
str
(
self
.
alt
)]
if
self
.
p
!=
_HHBLITS_DEFAULT_P
:
if
self
.
p
!=
_HHBLITS_DEFAULT_P
:
cmd
+=
[
'
-p
'
,
str
(
self
.
p
)]
cmd
+=
[
"
-p
"
,
str
(
self
.
p
)]
if
self
.
z
!=
_HHBLITS_DEFAULT_Z
:
if
self
.
z
!=
_HHBLITS_DEFAULT_Z
:
cmd
+=
[
'
-Z
'
,
str
(
self
.
z
)]
cmd
+=
[
"
-Z
"
,
str
(
self
.
z
)]
cmd
+=
db_cmd
cmd
+=
db_cmd
logging
.
info
(
'Launching subprocess "%s"'
,
' '
.
join
(
cmd
))
logging
.
info
(
'Launching subprocess "%s"'
,
" "
.
join
(
cmd
))
process
=
subprocess
.
Popen
(
process
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
with
utils
.
timing
(
'
HHblits query
'
):
with
utils
.
timing
(
"
HHblits query
"
):
stdout
,
stderr
=
process
.
communicate
()
stdout
,
stderr
=
process
.
communicate
()
retcode
=
process
.
wait
()
retcode
=
process
.
wait
()
if
retcode
:
if
retcode
:
# Logs have a 15k character limit, so log HHblits error line by line.
# Logs have a 15k character limit, so log HHblits error line by line.
logging
.
error
(
'
HHblits failed. HHblits stderr begin:
'
)
logging
.
error
(
"
HHblits failed. HHblits stderr begin:
"
)
for
error_line
in
stderr
.
decode
(
'
utf-8
'
).
splitlines
():
for
error_line
in
stderr
.
decode
(
"
utf-8
"
).
splitlines
():
if
error_line
.
strip
():
if
error_line
.
strip
():
logging
.
error
(
error_line
.
strip
())
logging
.
error
(
error_line
.
strip
())
logging
.
error
(
'HHblits stderr end'
)
logging
.
error
(
"HHblits stderr end"
)
raise
RuntimeError
(
'HHblits failed
\n
stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
'
%
(
raise
RuntimeError
(
stdout
.
decode
(
'utf-8'
),
stderr
[:
500_000
].
decode
(
'utf-8'
)))
"HHblits failed
\n
stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
"
%
(
stdout
.
decode
(
"utf-8"
),
stderr
[:
500_000
].
decode
(
"utf-8"
))
)
with
open
(
a3m_path
)
as
f
:
with
open
(
a3m_path
)
as
f
:
a3m
=
f
.
read
()
a3m
=
f
.
read
()
...
@@ -150,5 +170,6 @@ class HHBlits:
...
@@ -150,5 +170,6 @@ class HHBlits:
output
=
stdout
,
output
=
stdout
,
stderr
=
stderr
,
stderr
=
stderr
,
n_iter
=
self
.
n_iter
,
n_iter
=
self
.
n_iter
,
e_value
=
self
.
e_value
)
e_value
=
self
.
e_value
,
)
return
raw_output
return
raw_output
openfold/data/tools/hhsearch.py
View file @
07e64267
...
@@ -26,12 +26,14 @@ from openfold.data.np import utils
...
@@ -26,12 +26,14 @@ from openfold.data.np import utils
class
HHSearch
:
class
HHSearch
:
"""Python wrapper of the HHsearch binary."""
"""Python wrapper of the HHsearch binary."""
def
__init__
(
self
,
def
__init__
(
self
,
*
,
*
,
binary_path
:
str
,
binary_path
:
str
,
databases
:
Sequence
[
str
],
databases
:
Sequence
[
str
],
n_cpu
:
int
=
2
,
n_cpu
:
int
=
2
,
maxseq
:
int
=
1_000_000
):
maxseq
:
int
=
1_000_000
,
):
"""Initializes the Python HHsearch wrapper.
"""Initializes the Python HHsearch wrapper.
Args:
Args:
...
@@ -52,41 +54,52 @@ class HHSearch:
...
@@ -52,41 +54,52 @@ class HHSearch:
self
.
maxseq
=
maxseq
self
.
maxseq
=
maxseq
for
database_path
in
self
.
databases
:
for
database_path
in
self
.
databases
:
if
not
glob
.
glob
(
database_path
+
'_*'
):
if
not
glob
.
glob
(
database_path
+
"_*"
):
logging
.
error
(
'Could not find HHsearch database %s'
,
database_path
)
logging
.
error
(
raise
ValueError
(
f
'Could not find HHsearch database
{
database_path
}
'
)
"Could not find HHsearch database %s"
,
database_path
)
raise
ValueError
(
f
"Could not find HHsearch database
{
database_path
}
"
)
def
query
(
self
,
a3m
:
str
)
->
str
:
def
query
(
self
,
a3m
:
str
)
->
str
:
"""Queries the database using HHsearch using a given a3m."""
"""Queries the database using HHsearch using a given a3m."""
with
utils
.
tmpdir_manager
(
base_dir
=
'
/tmp
'
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
(
base_dir
=
"
/tmp
"
)
as
query_tmp_dir
:
input_path
=
os
.
path
.
join
(
query_tmp_dir
,
'
query.a3m
'
)
input_path
=
os
.
path
.
join
(
query_tmp_dir
,
"
query.a3m
"
)
hhr_path
=
os
.
path
.
join
(
query_tmp_dir
,
'
output.hhr
'
)
hhr_path
=
os
.
path
.
join
(
query_tmp_dir
,
"
output.hhr
"
)
with
open
(
input_path
,
'w'
)
as
f
:
with
open
(
input_path
,
"w"
)
as
f
:
f
.
write
(
a3m
)
f
.
write
(
a3m
)
db_cmd
=
[]
db_cmd
=
[]
for
db_path
in
self
.
databases
:
for
db_path
in
self
.
databases
:
db_cmd
.
append
(
'
-d
'
)
db_cmd
.
append
(
"
-d
"
)
db_cmd
.
append
(
db_path
)
db_cmd
.
append
(
db_path
)
cmd
=
[
self
.
binary_path
,
cmd
=
[
'-i'
,
input_path
,
self
.
binary_path
,
'-o'
,
hhr_path
,
"-i"
,
'-maxseq'
,
str
(
self
.
maxseq
),
input_path
,
'-cpu'
,
str
(
self
.
n_cpu
),
"-o"
,
hhr_path
,
"-maxseq"
,
str
(
self
.
maxseq
),
"-cpu"
,
str
(
self
.
n_cpu
),
]
+
db_cmd
]
+
db_cmd
logging
.
info
(
'Launching subprocess "%s"'
,
' '
.
join
(
cmd
))
logging
.
info
(
'Launching subprocess "%s"'
,
" "
.
join
(
cmd
))
process
=
subprocess
.
Popen
(
process
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
with
utils
.
timing
(
'HHsearch query'
):
)
with
utils
.
timing
(
"HHsearch query"
):
stdout
,
stderr
=
process
.
communicate
()
stdout
,
stderr
=
process
.
communicate
()
retcode
=
process
.
wait
()
retcode
=
process
.
wait
()
if
retcode
:
if
retcode
:
# Stderr is truncated to prevent proto size errors in Beam.
# Stderr is truncated to prevent proto size errors in Beam.
raise
RuntimeError
(
raise
RuntimeError
(
'HHSearch failed:
\n
stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
'
%
(
"HHSearch failed:
\n
stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
"
stdout
.
decode
(
'utf-8'
),
stderr
[:
100_000
].
decode
(
'utf-8'
)))
%
(
stdout
.
decode
(
"utf-8"
),
stderr
[:
100_000
].
decode
(
"utf-8"
))
)
with
open
(
hhr_path
)
as
f
:
with
open
(
hhr_path
)
as
f
:
hhr
=
f
.
read
()
hhr
=
f
.
read
()
...
...
openfold/data/tools/jackhmmer.py
View file @
07e64267
...
@@ -29,7 +29,8 @@ from openfold.data.tools import utils
...
@@ -29,7 +29,8 @@ from openfold.data.tools import utils
class
Jackhmmer
:
class
Jackhmmer
:
"""Python wrapper of the Jackhmmer binary."""
"""Python wrapper of the Jackhmmer binary."""
def
__init__
(
self
,
def
__init__
(
self
,
*
,
*
,
binary_path
:
str
,
binary_path
:
str
,
database_path
:
str
,
database_path
:
str
,
...
@@ -44,7 +45,8 @@ class Jackhmmer:
...
@@ -44,7 +45,8 @@ class Jackhmmer:
incdom_e
:
Optional
[
float
]
=
None
,
incdom_e
:
Optional
[
float
]
=
None
,
dom_e
:
Optional
[
float
]
=
None
,
dom_e
:
Optional
[
float
]
=
None
,
num_streamed_chunks
:
Optional
[
int
]
=
None
,
num_streamed_chunks
:
Optional
[
int
]
=
None
,
streaming_callback
:
Optional
[
Callable
[[
int
],
None
]]
=
None
):
streaming_callback
:
Optional
[
Callable
[[
int
],
None
]]
=
None
,
):
"""Initializes the Python Jackhmmer wrapper.
"""Initializes the Python Jackhmmer wrapper.
Args:
Args:
...
@@ -69,9 +71,14 @@ class Jackhmmer:
...
@@ -69,9 +71,14 @@ class Jackhmmer:
self
.
database_path
=
database_path
self
.
database_path
=
database_path
self
.
num_streamed_chunks
=
num_streamed_chunks
self
.
num_streamed_chunks
=
num_streamed_chunks
if
not
os
.
path
.
exists
(
self
.
database_path
)
and
num_streamed_chunks
is
None
:
if
(
logging
.
error
(
'Could not find Jackhmmer database %s'
,
database_path
)
not
os
.
path
.
exists
(
self
.
database_path
)
raise
ValueError
(
f
'Could not find Jackhmmer database
{
database_path
}
'
)
and
num_streamed_chunks
is
None
):
logging
.
error
(
"Could not find Jackhmmer database %s"
,
database_path
)
raise
ValueError
(
f
"Could not find Jackhmmer database
{
database_path
}
"
)
self
.
n_cpu
=
n_cpu
self
.
n_cpu
=
n_cpu
self
.
n_iter
=
n_iter
self
.
n_iter
=
n_iter
...
@@ -85,11 +92,12 @@ class Jackhmmer:
...
@@ -85,11 +92,12 @@ class Jackhmmer:
self
.
get_tblout
=
get_tblout
self
.
get_tblout
=
get_tblout
self
.
streaming_callback
=
streaming_callback
self
.
streaming_callback
=
streaming_callback
def
_query_chunk
(
self
,
input_fasta_path
:
str
,
database_path
:
str
def
_query_chunk
(
self
,
input_fasta_path
:
str
,
database_path
:
str
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
"""Queries the database chunk using Jackhmmer."""
"""Queries the database chunk using Jackhmmer."""
with
utils
.
tmpdir_manager
(
base_dir
=
'
/tmp
'
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
(
base_dir
=
"
/tmp
"
)
as
query_tmp_dir
:
sto_path
=
os
.
path
.
join
(
query_tmp_dir
,
'
output.sto
'
)
sto_path
=
os
.
path
.
join
(
query_tmp_dir
,
"
output.sto
"
)
# The F1/F2/F3 are the expected proportion to pass each of the filtering
# The F1/F2/F3 are the expected proportion to pass each of the filtering
# stages (which get progressively more expensive), reducing these
# stages (which get progressively more expensive), reducing these
...
@@ -98,48 +106,63 @@ class Jackhmmer:
...
@@ -98,48 +106,63 @@ class Jackhmmer:
# amount of time.
# amount of time.
cmd_flags
=
[
cmd_flags
=
[
# Don't pollute stdout with Jackhmmer output.
# Don't pollute stdout with Jackhmmer output.
'-o'
,
'/dev/null'
,
"-o"
,
'-A'
,
sto_path
,
"/dev/null"
,
'--noali'
,
"-A"
,
'--F1'
,
str
(
self
.
filter_f1
),
sto_path
,
'--F2'
,
str
(
self
.
filter_f2
),
"--noali"
,
'--F3'
,
str
(
self
.
filter_f3
),
"--F1"
,
'--incE'
,
str
(
self
.
e_value
),
str
(
self
.
filter_f1
),
"--F2"
,
str
(
self
.
filter_f2
),
"--F3"
,
str
(
self
.
filter_f3
),
"--incE"
,
str
(
self
.
e_value
),
# Report only sequences with E-values <= x in per-sequence output.
# Report only sequences with E-values <= x in per-sequence output.
'-E'
,
str
(
self
.
e_value
),
"-E"
,
'--cpu'
,
str
(
self
.
n_cpu
),
str
(
self
.
e_value
),
'-N'
,
str
(
self
.
n_iter
)
"--cpu"
,
str
(
self
.
n_cpu
),
"-N"
,
str
(
self
.
n_iter
),
]
]
if
self
.
get_tblout
:
if
self
.
get_tblout
:
tblout_path
=
os
.
path
.
join
(
query_tmp_dir
,
'
tblout.txt
'
)
tblout_path
=
os
.
path
.
join
(
query_tmp_dir
,
"
tblout.txt
"
)
cmd_flags
.
extend
([
'
--tblout
'
,
tblout_path
])
cmd_flags
.
extend
([
"
--tblout
"
,
tblout_path
])
if
self
.
z_value
:
if
self
.
z_value
:
cmd_flags
.
extend
([
'
-Z
'
,
str
(
self
.
z_value
)])
cmd_flags
.
extend
([
"
-Z
"
,
str
(
self
.
z_value
)])
if
self
.
dom_e
is
not
None
:
if
self
.
dom_e
is
not
None
:
cmd_flags
.
extend
([
'
--domE
'
,
str
(
self
.
dom_e
)])
cmd_flags
.
extend
([
"
--domE
"
,
str
(
self
.
dom_e
)])
if
self
.
incdom_e
is
not
None
:
if
self
.
incdom_e
is
not
None
:
cmd_flags
.
extend
([
'
--incdomE
'
,
str
(
self
.
incdom_e
)])
cmd_flags
.
extend
([
"
--incdomE
"
,
str
(
self
.
incdom_e
)])
cmd
=
[
self
.
binary_path
]
+
cmd_flags
+
[
input_fasta_path
,
cmd
=
(
database_path
]
[
self
.
binary_path
]
+
cmd_flags
+
[
input_fasta_path
,
database_path
]
)
logging
.
info
(
'Launching subprocess "%s"'
,
' '
.
join
(
cmd
))
logging
.
info
(
'Launching subprocess "%s"'
,
" "
.
join
(
cmd
))
process
=
subprocess
.
Popen
(
process
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
with
utils
.
timing
(
with
utils
.
timing
(
f
'Jackhmmer (
{
os
.
path
.
basename
(
database_path
)
}
) query'
):
f
"Jackhmmer (
{
os
.
path
.
basename
(
database_path
)
}
) query"
):
_
,
stderr
=
process
.
communicate
()
_
,
stderr
=
process
.
communicate
()
retcode
=
process
.
wait
()
retcode
=
process
.
wait
()
if
retcode
:
if
retcode
:
raise
RuntimeError
(
raise
RuntimeError
(
'Jackhmmer failed
\n
stderr:
\n
%s
\n
'
%
stderr
.
decode
(
'utf-8'
))
"Jackhmmer failed
\n
stderr:
\n
%s
\n
"
%
stderr
.
decode
(
"utf-8"
)
)
# Get e-values for each target name
# Get e-values for each target name
tbl
=
''
tbl
=
""
if
self
.
get_tblout
:
if
self
.
get_tblout
:
with
open
(
tblout_path
)
as
f
:
with
open
(
tblout_path
)
as
f
:
tbl
=
f
.
read
()
tbl
=
f
.
read
()
...
@@ -152,7 +175,8 @@ class Jackhmmer:
...
@@ -152,7 +175,8 @@ class Jackhmmer:
tbl
=
tbl
,
tbl
=
tbl
,
stderr
=
stderr
,
stderr
=
stderr
,
n_iter
=
self
.
n_iter
,
n_iter
=
self
.
n_iter
,
e_value
=
self
.
e_value
)
e_value
=
self
.
e_value
,
)
return
raw_output
return
raw_output
...
@@ -162,15 +186,15 @@ class Jackhmmer:
...
@@ -162,15 +186,15 @@ class Jackhmmer:
return
[
self
.
_query_chunk
(
input_fasta_path
,
self
.
database_path
)]
return
[
self
.
_query_chunk
(
input_fasta_path
,
self
.
database_path
)]
db_basename
=
os
.
path
.
basename
(
self
.
database_path
)
db_basename
=
os
.
path
.
basename
(
self
.
database_path
)
db_remote_chunk
=
lambda
db_idx
:
f
'
{
self
.
database_path
}
.
{
db_idx
}
'
db_remote_chunk
=
lambda
db_idx
:
f
"
{
self
.
database_path
}
.
{
db_idx
}
"
db_local_chunk
=
lambda
db_idx
:
f
'
/tmp/ramdisk/
{
db_basename
}
.
{
db_idx
}
'
db_local_chunk
=
lambda
db_idx
:
f
"
/tmp/ramdisk/
{
db_basename
}
.
{
db_idx
}
"
# Remove existing files to prevent OOM
# Remove existing files to prevent OOM
for
f
in
glob
.
glob
(
db_local_chunk
(
'
[0-9]*
'
)):
for
f
in
glob
.
glob
(
db_local_chunk
(
"
[0-9]*
"
)):
try
:
try
:
os
.
remove
(
f
)
os
.
remove
(
f
)
except
OSError
:
except
OSError
:
print
(
f
'
OSError while deleting
{
f
}
'
)
print
(
f
"
OSError while deleting
{
f
}
"
)
# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
with
futures
.
ThreadPoolExecutor
(
max_workers
=
2
)
as
executor
:
with
futures
.
ThreadPoolExecutor
(
max_workers
=
2
)
as
executor
:
...
@@ -179,15 +203,22 @@ class Jackhmmer:
...
@@ -179,15 +203,22 @@ class Jackhmmer:
# Copy the chunk locally
# Copy the chunk locally
if
i
==
1
:
if
i
==
1
:
future
=
executor
.
submit
(
future
=
executor
.
submit
(
request
.
urlretrieve
,
db_remote_chunk
(
i
),
db_local_chunk
(
i
))
request
.
urlretrieve
,
db_remote_chunk
(
i
),
db_local_chunk
(
i
),
)
if
i
<
self
.
num_streamed_chunks
:
if
i
<
self
.
num_streamed_chunks
:
next_future
=
executor
.
submit
(
next_future
=
executor
.
submit
(
request
.
urlretrieve
,
db_remote_chunk
(
i
+
1
),
db_local_chunk
(
i
+
1
))
request
.
urlretrieve
,
db_remote_chunk
(
i
+
1
),
db_local_chunk
(
i
+
1
),
)
# Run Jackhmmer with the chunk
# Run Jackhmmer with the chunk
future
.
result
()
future
.
result
()
chunked_output
.
append
(
chunked_output
.
append
(
self
.
_query_chunk
(
input_fasta_path
,
db_local_chunk
(
i
)))
self
.
_query_chunk
(
input_fasta_path
,
db_local_chunk
(
i
))
)
# Remove the local copy of the chunk
# Remove the local copy of the chunk
os
.
remove
(
db_local_chunk
(
i
))
os
.
remove
(
db_local_chunk
(
i
))
...
...
openfold/data/tools/kalign.py
View file @
07e64267
...
@@ -25,12 +25,12 @@ from openfold.data.tools import utils
...
@@ -25,12 +25,12 @@ from openfold.data.tools import utils
def
_to_a3m
(
sequences
:
Sequence
[
str
])
->
str
:
def
_to_a3m
(
sequences
:
Sequence
[
str
])
->
str
:
"""Converts sequences to an a3m file."""
"""Converts sequences to an a3m file."""
names
=
[
'
sequence %d
'
%
i
for
i
in
range
(
1
,
len
(
sequences
)
+
1
)]
names
=
[
"
sequence %d
"
%
i
for
i
in
range
(
1
,
len
(
sequences
)
+
1
)]
a3m
=
[]
a3m
=
[]
for
sequence
,
name
in
zip
(
sequences
,
names
):
for
sequence
,
name
in
zip
(
sequences
,
names
):
a3m
.
append
(
u
'>'
+
name
+
u
'
\n
'
)
a3m
.
append
(
u
">"
+
name
+
u
"
\n
"
)
a3m
.
append
(
sequence
+
u
'
\n
'
)
a3m
.
append
(
sequence
+
u
"
\n
"
)
return
''
.
join
(
a3m
)
return
""
.
join
(
a3m
)
class
Kalign
:
class
Kalign
:
...
@@ -63,40 +63,51 @@ class Kalign:
...
@@ -63,40 +63,51 @@ class Kalign:
RuntimeError: If Kalign fails.
RuntimeError: If Kalign fails.
ValueError: If any of the sequences is less than 6 residues long.
ValueError: If any of the sequences is less than 6 residues long.
"""
"""
logging
.
info
(
'
Aligning %d sequences
'
,
len
(
sequences
))
logging
.
info
(
"
Aligning %d sequences
"
,
len
(
sequences
))
for
s
in
sequences
:
for
s
in
sequences
:
if
len
(
s
)
<
6
:
if
len
(
s
)
<
6
:
raise
ValueError
(
'Kalign requires all sequences to be at least 6 '
raise
ValueError
(
'residues long. Got %s (%d residues).'
%
(
s
,
len
(
s
)))
"Kalign requires all sequences to be at least 6 "
"residues long. Got %s (%d residues)."
%
(
s
,
len
(
s
))
)
with
utils
.
tmpdir_manager
(
base_dir
=
'
/tmp
'
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
(
base_dir
=
"
/tmp
"
)
as
query_tmp_dir
:
input_fasta_path
=
os
.
path
.
join
(
query_tmp_dir
,
'
input.fasta
'
)
input_fasta_path
=
os
.
path
.
join
(
query_tmp_dir
,
"
input.fasta
"
)
output_a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
'
output.a3m
'
)
output_a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
"
output.a3m
"
)
with
open
(
input_fasta_path
,
'w'
)
as
f
:
with
open
(
input_fasta_path
,
"w"
)
as
f
:
f
.
write
(
_to_a3m
(
sequences
))
f
.
write
(
_to_a3m
(
sequences
))
cmd
=
[
cmd
=
[
self
.
binary_path
,
self
.
binary_path
,
'-i'
,
input_fasta_path
,
"-i"
,
'-o'
,
output_a3m_path
,
input_fasta_path
,
'-format'
,
'fasta'
,
"-o"
,
output_a3m_path
,
"-format"
,
"fasta"
,
]
]
logging
.
info
(
'Launching subprocess "%s"'
,
' '
.
join
(
cmd
))
logging
.
info
(
'Launching subprocess "%s"'
,
" "
.
join
(
cmd
))
process
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
process
=
subprocess
.
Popen
(
stderr
=
subprocess
.
PIPE
)
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
with
utils
.
timing
(
'
Kalign query
'
):
with
utils
.
timing
(
"
Kalign query
"
):
stdout
,
stderr
=
process
.
communicate
()
stdout
,
stderr
=
process
.
communicate
()
retcode
=
process
.
wait
()
retcode
=
process
.
wait
()
logging
.
info
(
'Kalign stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
'
,
logging
.
info
(
stdout
.
decode
(
'utf-8'
),
stderr
.
decode
(
'utf-8'
))
"Kalign stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
"
,
stdout
.
decode
(
"utf-8"
),
stderr
.
decode
(
"utf-8"
),
)
if
retcode
:
if
retcode
:
raise
RuntimeError
(
'Kalign failed
\n
stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
'
raise
RuntimeError
(
%
(
stdout
.
decode
(
'utf-8'
),
stderr
.
decode
(
'utf-8'
)))
"Kalign failed
\n
stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
"
%
(
stdout
.
decode
(
"utf-8"
),
stderr
.
decode
(
"utf-8"
))
)
with
open
(
output_a3m_path
)
as
f
:
with
open
(
output_a3m_path
)
as
f
:
a3m
=
f
.
read
()
a3m
=
f
.
read
()
...
...
openfold/data/tools/utils.py
View file @
07e64267
...
@@ -35,11 +35,11 @@ def tmpdir_manager(base_dir: Optional[str] = None):
...
@@ -35,11 +35,11 @@ def tmpdir_manager(base_dir: Optional[str] = None):
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
timing
(
msg
:
str
):
def
timing
(
msg
:
str
):
logging
.
info
(
'
Started %s
'
,
msg
)
logging
.
info
(
"
Started %s
"
,
msg
)
tic
=
time
.
time
()
tic
=
time
.
time
()
yield
yield
toc
=
time
.
time
()
toc
=
time
.
time
()
logging
.
info
(
'
Finished %s in %.3f seconds
'
,
msg
,
toc
-
tic
)
logging
.
info
(
"
Finished %s in %.3f seconds
"
,
msg
,
toc
-
tic
)
def
to_date
(
s
:
str
):
def
to_date
(
s
:
str
):
...
...
openfold/model/__init__.py
View file @
07e64267
...
@@ -3,13 +3,14 @@ import glob
...
@@ -3,13 +3,14 @@ import glob
import
importlib
as
importlib
import
importlib
as
importlib
_files
=
glob
.
glob
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"*.py"
))
_files
=
glob
.
glob
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"*.py"
))
__all__
=
[
os
.
path
.
basename
(
f
)[:
-
3
]
for
f
in
_files
if
os
.
path
.
isfile
(
f
)
and
not
f
.
endswith
(
"__init__.py"
)]
__all__
=
[
_modules
=
[(
m
,
importlib
.
import_module
(
'.'
+
m
,
__name__
))
for
m
in
__all__
]
os
.
path
.
basename
(
f
)[:
-
3
]
for
f
in
_files
if
os
.
path
.
isfile
(
f
)
and
not
f
.
endswith
(
"__init__.py"
)
]
_modules
=
[(
m
,
importlib
.
import_module
(
"."
+
m
,
__name__
))
for
m
in
__all__
]
for
_m
in
_modules
:
for
_m
in
_modules
:
globals
()[
_m
[
0
]]
=
_m
[
1
]
globals
()[
_m
[
0
]]
=
_m
[
1
]
# Avoid needlessly cluttering the global namespace
# Avoid needlessly cluttering the global namespace
del
_files
,
_m
,
_modules
del
_files
,
_m
,
_modules
openfold/model/dropout.py
View file @
07e64267
...
@@ -26,6 +26,7 @@ class Dropout(nn.Module):
...
@@ -26,6 +26,7 @@ class Dropout(nn.Module):
If not in training mode, this module computes the identity function.
If not in training mode, this module computes the identity function.
"""
"""
def
__init__
(
self
,
r
:
float
,
batch_dim
:
Union
[
int
,
List
[
int
]]):
def
__init__
(
self
,
r
:
float
,
batch_dim
:
Union
[
int
,
List
[
int
]]):
"""
"""
Args:
Args:
...
@@ -37,7 +38,7 @@ class Dropout(nn.Module):
...
@@ -37,7 +38,7 @@ class Dropout(nn.Module):
super
(
Dropout
,
self
).
__init__
()
super
(
Dropout
,
self
).
__init__
()
self
.
r
=
r
self
.
r
=
r
if
(
type
(
batch_dim
)
==
int
)
:
if
type
(
batch_dim
)
==
int
:
batch_dim
=
[
batch_dim
]
batch_dim
=
[
batch_dim
]
self
.
batch_dim
=
batch_dim
self
.
batch_dim
=
batch_dim
self
.
dropout
=
nn
.
Dropout
(
self
.
r
)
self
.
dropout
=
nn
.
Dropout
(
self
.
r
)
...
@@ -50,7 +51,7 @@ class Dropout(nn.Module):
...
@@ -50,7 +51,7 @@ class Dropout(nn.Module):
compatible with self.batch_dim
compatible with self.batch_dim
"""
"""
shape
=
list
(
x
.
shape
)
shape
=
list
(
x
.
shape
)
if
(
self
.
batch_dim
is
not
None
)
:
if
self
.
batch_dim
is
not
None
:
for
bd
in
self
.
batch_dim
:
for
bd
in
self
.
batch_dim
:
shape
[
bd
]
=
1
shape
[
bd
]
=
1
mask
=
x
.
new_ones
(
shape
)
mask
=
x
.
new_ones
(
shape
)
...
@@ -64,6 +65,7 @@ class DropoutRowwise(Dropout):
...
@@ -64,6 +65,7 @@ class DropoutRowwise(Dropout):
Convenience class for rowwise dropout as described in subsection
Convenience class for rowwise dropout as described in subsection
1.11.6.
1.11.6.
"""
"""
__init__
=
partialmethod
(
Dropout
.
__init__
,
batch_dim
=-
3
)
__init__
=
partialmethod
(
Dropout
.
__init__
,
batch_dim
=-
3
)
...
@@ -72,4 +74,5 @@ class DropoutColumnwise(Dropout):
...
@@ -72,4 +74,5 @@ class DropoutColumnwise(Dropout):
Convenience class for columnwise dropout as described in subsection
Convenience class for columnwise dropout as described in subsection
1.11.6.
1.11.6.
"""
"""
__init__
=
partialmethod
(
Dropout
.
__init__
,
batch_dim
=-
2
)
__init__
=
partialmethod
(
Dropout
.
__init__
,
batch_dim
=-
2
)
openfold/model/embedders.py
View file @
07e64267
...
@@ -27,6 +27,7 @@ class InputEmbedder(nn.Module):
...
@@ -27,6 +27,7 @@ class InputEmbedder(nn.Module):
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
tf_dim
:
int
,
tf_dim
:
int
,
...
@@ -67,9 +68,7 @@ class InputEmbedder(nn.Module):
...
@@ -67,9 +68,7 @@ class InputEmbedder(nn.Module):
self
.
no_bins
=
2
*
relpos_k
+
1
self
.
no_bins
=
2
*
relpos_k
+
1
self
.
linear_relpos
=
Linear
(
self
.
no_bins
,
c_z
)
self
.
linear_relpos
=
Linear
(
self
.
no_bins
,
c_z
)
def
relpos
(
self
,
def
relpos
(
self
,
ri
:
torch
.
Tensor
):
ri
:
torch
.
Tensor
):
"""
"""
Computes relative positional encodings
Computes relative positional encodings
...
@@ -86,7 +85,8 @@ class InputEmbedder(nn.Module):
...
@@ -86,7 +85,8 @@ class InputEmbedder(nn.Module):
oh
=
one_hot
(
d
,
boundaries
).
type
(
ri
.
dtype
)
oh
=
one_hot
(
d
,
boundaries
).
type
(
ri
.
dtype
)
return
self
.
linear_relpos
(
oh
)
return
self
.
linear_relpos
(
oh
)
def
forward
(
self
,
def
forward
(
self
,
tf
:
torch
.
Tensor
,
tf
:
torch
.
Tensor
,
ri
:
torch
.
Tensor
,
ri
:
torch
.
Tensor
,
msa
:
torch
.
Tensor
,
msa
:
torch
.
Tensor
,
...
@@ -132,14 +132,16 @@ class RecyclingEmbedder(nn.Module):
...
@@ -132,14 +132,16 @@ class RecyclingEmbedder(nn.Module):
Implements Algorithm 32.
Implements Algorithm 32.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
c_m
:
int
,
c_m
:
int
,
c_z
:
int
,
c_z
:
int
,
min_bin
:
float
,
min_bin
:
float
,
max_bin
:
float
,
max_bin
:
float
,
no_bins
:
int
,
no_bins
:
int
,
inf
:
float
=
1e8
,
inf
:
float
=
1e8
,
**
kwargs
**
kwargs
,
):
):
"""
"""
Args:
Args:
...
@@ -169,7 +171,8 @@ class RecyclingEmbedder(nn.Module):
...
@@ -169,7 +171,8 @@ class RecyclingEmbedder(nn.Module):
self
.
layer_norm_m
=
nn
.
LayerNorm
(
self
.
c_m
)
self
.
layer_norm_m
=
nn
.
LayerNorm
(
self
.
c_m
)
self
.
layer_norm_z
=
nn
.
LayerNorm
(
self
.
c_z
)
self
.
layer_norm_z
=
nn
.
LayerNorm
(
self
.
c_z
)
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
@@ -188,13 +191,13 @@ class RecyclingEmbedder(nn.Module):
...
@@ -188,13 +191,13 @@ class RecyclingEmbedder(nn.Module):
z:
z:
[*, N_res, N_res, C_z] pair embedding update
[*, N_res, N_res, C_z] pair embedding update
"""
"""
if
(
self
.
bins
is
None
)
:
if
self
.
bins
is
None
:
self
.
bins
=
torch
.
linspace
(
self
.
bins
=
torch
.
linspace
(
self
.
min_bin
,
self
.
min_bin
,
self
.
max_bin
,
self
.
max_bin
,
self
.
no_bins
,
self
.
no_bins
,
dtype
=
x
.
dtype
,
dtype
=
x
.
dtype
,
device
=
x
.
device
device
=
x
.
device
,
)
)
# [*, N, C_m]
# [*, N, C_m]
...
@@ -205,15 +208,10 @@ class RecyclingEmbedder(nn.Module):
...
@@ -205,15 +208,10 @@ class RecyclingEmbedder(nn.Module):
# couldn't find in time.
# couldn't find in time.
squared_bins
=
self
.
bins
**
2
squared_bins
=
self
.
bins
**
2
upper
=
torch
.
cat
(
upper
=
torch
.
cat
(
[
[
squared_bins
[
1
:],
squared_bins
.
new_tensor
([
self
.
inf
])],
dim
=-
1
squared_bins
[
1
:],
squared_bins
.
new_tensor
([
self
.
inf
])
],
dim
=-
1
)
)
d
=
torch
.
sum
(
d
=
torch
.
sum
(
(
x
[...,
None
,
:]
-
x
[...,
None
,
:,
:])
**
2
,
(
x
[...,
None
,
:]
-
x
[...,
None
,
:,
:])
**
2
,
dim
=-
1
,
keepdims
=
True
dim
=-
1
,
keepdims
=
True
)
)
# [*, N, N, no_bins]
# [*, N, N, no_bins]
...
@@ -232,7 +230,9 @@ class TemplateAngleEmbedder(nn.Module):
...
@@ -232,7 +230,9 @@ class TemplateAngleEmbedder(nn.Module):
Implements Algorithm 2, line 7.
Implements Algorithm 2, line 7.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
c_in
:
int
,
c_in
:
int
,
c_out
:
int
,
c_out
:
int
,
**
kwargs
,
**
kwargs
,
...
@@ -253,9 +253,7 @@ class TemplateAngleEmbedder(nn.Module):
...
@@ -253,9 +253,7 @@ class TemplateAngleEmbedder(nn.Module):
self
.
relu
=
nn
.
ReLU
()
self
.
relu
=
nn
.
ReLU
()
self
.
linear_2
=
Linear
(
self
.
c_out
,
self
.
c_out
,
init
=
"relu"
)
self
.
linear_2
=
Linear
(
self
.
c_out
,
self
.
c_out
,
init
=
"relu"
)
def
forward
(
self
,
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
...
@@ -275,7 +273,9 @@ class TemplatePairEmbedder(nn.Module):
...
@@ -275,7 +273,9 @@ class TemplatePairEmbedder(nn.Module):
Implements Algorithm 2, line 9.
Implements Algorithm 2, line 9.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
c_in
:
int
,
c_in
:
int
,
c_out
:
int
,
c_out
:
int
,
**
kwargs
,
**
kwargs
,
...
@@ -295,7 +295,8 @@ class TemplatePairEmbedder(nn.Module):
...
@@ -295,7 +295,8 @@ class TemplatePairEmbedder(nn.Module):
# Despite there being no relu nearby, the source uses that initializer
# Despite there being no relu nearby, the source uses that initializer
self
.
linear
=
Linear
(
self
.
c_in
,
self
.
c_out
,
init
=
"relu"
)
self
.
linear
=
Linear
(
self
.
c_in
,
self
.
c_out
,
init
=
"relu"
)
def
forward
(
self
,
def
forward
(
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
...
@@ -316,7 +317,9 @@ class ExtraMSAEmbedder(nn.Module):
...
@@ -316,7 +317,9 @@ class ExtraMSAEmbedder(nn.Module):
Implements Algorithm 2, line 15
Implements Algorithm 2, line 15
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
c_in
:
int
,
c_in
:
int
,
c_out
:
int
,
c_out
:
int
,
**
kwargs
,
**
kwargs
,
...
@@ -335,9 +338,7 @@ class ExtraMSAEmbedder(nn.Module):
...
@@ -335,9 +338,7 @@ class ExtraMSAEmbedder(nn.Module):
self
.
linear
=
Linear
(
self
.
c_in
,
self
.
c_out
)
self
.
linear
=
Linear
(
self
.
c_in
,
self
.
c_out
)
def
forward
(
self
,
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
x:
x:
...
...
openfold/model/evoformer.py
View file @
07e64267
...
@@ -45,6 +45,7 @@ class MSATransition(nn.Module):
...
@@ -45,6 +45,7 @@ class MSATransition(nn.Module):
Implements Algorithm 9
Implements Algorithm 9
"""
"""
def
__init__
(
self
,
c_m
,
n
,
chunk_size
):
def
__init__
(
self
,
c_m
,
n
,
chunk_size
):
"""
"""
Args:
Args:
...
@@ -71,7 +72,8 @@ class MSATransition(nn.Module):
...
@@ -71,7 +72,8 @@ class MSATransition(nn.Module):
m
=
self
.
linear_2
(
m
)
*
mask
m
=
self
.
linear_2
(
m
)
*
mask
return
m
return
m
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
=
None
,
mask
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -86,7 +88,7 @@ class MSATransition(nn.Module):
...
@@ -86,7 +88,7 @@ class MSATransition(nn.Module):
[*, N_seq, N_res, C_m] MSA activation update
[*, N_seq, N_res, C_m] MSA activation update
"""
"""
# DISCREPANCY: DeepMind forgets to apply the MSA mask here.
# DISCREPANCY: DeepMind forgets to apply the MSA mask here.
if
(
mask
is
None
)
:
if
mask
is
None
:
mask
=
m
.
new_ones
(
m
.
shape
[:
-
1
])
mask
=
m
.
new_ones
(
m
.
shape
[:
-
1
])
mask
=
mask
.
unsqueeze
(
-
1
)
mask
=
mask
.
unsqueeze
(
-
1
)
...
@@ -94,7 +96,7 @@ class MSATransition(nn.Module):
...
@@ -94,7 +96,7 @@ class MSATransition(nn.Module):
m
=
self
.
layer_norm
(
m
)
m
=
self
.
layer_norm
(
m
)
inp
=
{
"m"
:
m
,
"mask"
:
mask
}
inp
=
{
"m"
:
m
,
"mask"
:
mask
}
if
(
self
.
chunk_size
is
not
None
)
:
if
self
.
chunk_size
is
not
None
:
m
=
chunk_layer
(
m
=
chunk_layer
(
self
.
_transition
,
self
.
_transition
,
inp
,
inp
,
...
@@ -108,7 +110,8 @@ class MSATransition(nn.Module):
...
@@ -108,7 +110,8 @@ class MSATransition(nn.Module):
class
EvoformerBlock
(
nn
.
Module
):
class
EvoformerBlock
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
c_m
:
int
,
c_m
:
int
,
c_z
:
int
,
c_z
:
int
,
c_hidden_msa_att
:
int
,
c_hidden_msa_att
:
int
,
...
@@ -136,7 +139,7 @@ class EvoformerBlock(nn.Module):
...
@@ -136,7 +139,7 @@ class EvoformerBlock(nn.Module):
inf
=
inf
,
inf
=
inf
,
)
)
if
(
_is_extra_msa_stack
)
:
if
_is_extra_msa_stack
:
self
.
msa_att_col
=
MSAColumnGlobalAttention
(
self
.
msa_att_col
=
MSAColumnGlobalAttention
(
c_in
=
c_m
,
c_in
=
c_m
,
c_hidden
=
c_hidden_msa_att
,
c_hidden
=
c_hidden_msa_att
,
...
@@ -201,7 +204,8 @@ class EvoformerBlock(nn.Module):
...
@@ -201,7 +204,8 @@ class EvoformerBlock(nn.Module):
self
.
ps_dropout_row_layer
=
DropoutRowwise
(
pair_dropout
)
self
.
ps_dropout_row_layer
=
DropoutRowwise
(
pair_dropout
)
self
.
ps_dropout_col_layer
=
DropoutColumnwise
(
pair_dropout
)
self
.
ps_dropout_col_layer
=
DropoutColumnwise
(
pair_dropout
)
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
...
@@ -233,7 +237,9 @@ class EvoformerStack(nn.Module):
...
@@ -233,7 +237,9 @@ class EvoformerStack(nn.Module):
Implements Algorithm 6.
Implements Algorithm 6.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
c_m
:
int
,
c_m
:
int
,
c_z
:
int
,
c_z
:
int
,
c_hidden_msa_att
:
int
,
c_hidden_msa_att
:
int
,
...
@@ -313,10 +319,11 @@ class EvoformerStack(nn.Module):
...
@@ -313,10 +319,11 @@ class EvoformerStack(nn.Module):
)
)
self
.
blocks
.
append
(
block
)
self
.
blocks
.
append
(
block
)
if
(
not
self
.
_is_extra_msa_stack
)
:
if
not
self
.
_is_extra_msa_stack
:
self
.
linear
=
Linear
(
c_m
,
c_s
)
self
.
linear
=
Linear
(
c_m
,
c_s
)
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
...
@@ -348,14 +355,15 @@ class EvoformerStack(nn.Module):
...
@@ -348,14 +355,15 @@ class EvoformerStack(nn.Module):
msa_mask
=
msa_mask
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
,
_mask_trans
=
_mask_trans
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
)
for
b
in
self
.
blocks
],
],
args
=
(
m
,
z
),
args
=
(
m
,
z
),
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
self
.
training
else
None
,
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
self
.
training
else
None
,
)
)
s
=
None
s
=
None
if
(
not
self
.
_is_extra_msa_stack
)
:
if
not
self
.
_is_extra_msa_stack
:
seq_dim
=
-
3
seq_dim
=
-
3
index
=
torch
.
tensor
([
0
],
device
=
m
.
device
)
index
=
torch
.
tensor
([
0
],
device
=
m
.
device
)
s
=
self
.
linear
(
torch
.
index_select
(
m
,
dim
=
seq_dim
,
index
=
index
))
s
=
self
.
linear
(
torch
.
index_select
(
m
,
dim
=
seq_dim
,
index
=
index
))
...
@@ -368,7 +376,9 @@ class ExtraMSAStack(nn.Module):
...
@@ -368,7 +376,9 @@ class ExtraMSAStack(nn.Module):
"""
"""
Implements Algorithm 18.
Implements Algorithm 18.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
c_m
:
int
,
c_m
:
int
,
c_z
:
int
,
c_z
:
int
,
c_hidden_msa_att
:
int
,
c_hidden_msa_att
:
int
,
...
@@ -411,12 +421,13 @@ class ExtraMSAStack(nn.Module):
...
@@ -411,12 +421,13 @@ class ExtraMSAStack(nn.Module):
_is_extra_msa_stack
=
True
,
_is_extra_msa_stack
=
True
,
)
)
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_mask_trans
:
bool
=
True
_mask_trans
:
bool
=
True
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -436,6 +447,6 @@ class ExtraMSAStack(nn.Module):
...
@@ -436,6 +447,6 @@ class ExtraMSAStack(nn.Module):
z
,
z
,
msa_mask
=
msa_mask
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
,
_mask_trans
=
_mask_trans
_mask_trans
=
_mask_trans
,
)
)
return
z
return
z
openfold/model/heads.py
View file @
07e64267
...
@@ -44,7 +44,7 @@ class AuxiliaryHeads(nn.Module):
...
@@ -44,7 +44,7 @@ class AuxiliaryHeads(nn.Module):
**
config
[
"experimentally_resolved"
],
**
config
[
"experimentally_resolved"
],
)
)
if
(
config
.
tm
.
enabled
)
:
if
config
.
tm
.
enabled
:
self
.
tm
=
TMScoreHead
(
self
.
tm
=
TMScoreHead
(
**
config
.
tm
,
**
config
.
tm
,
)
)
...
@@ -68,19 +68,22 @@ class AuxiliaryHeads(nn.Module):
...
@@ -68,19 +68,22 @@ class AuxiliaryHeads(nn.Module):
experimentally_resolved_logits
=
self
.
experimentally_resolved
(
experimentally_resolved_logits
=
self
.
experimentally_resolved
(
outputs
[
"single"
]
outputs
[
"single"
]
)
)
aux_out
[
"experimentally_resolved_logits"
]
=
(
aux_out
[
experimentally_resolved_logits
"
experimentally_resolved_logits
"
)
]
=
experimentally_resolved_logits
if
(
self
.
config
.
tm
.
enabled
)
:
if
self
.
config
.
tm
.
enabled
:
tm_logits
=
self
.
tm
(
outputs
[
"pair"
])
tm_logits
=
self
.
tm
(
outputs
[
"pair"
])
aux_out
[
"tm_logits"
]
=
tm_logits
aux_out
[
"tm_logits"
]
=
tm_logits
aux_out
[
"predicted_tm_score"
]
=
compute_tm
(
aux_out
[
"predicted_tm_score"
]
=
compute_tm
(
tm_logits
,
**
self
.
config
.
tm
tm_logits
,
**
self
.
config
.
tm
)
)
aux_out
.
update
(
compute_predicted_aligned_error
(
aux_out
.
update
(
tm_logits
,
**
self
.
config
.
tm
,
compute_predicted_aligned_error
(
))
tm_logits
,
**
self
.
config
.
tm
,
)
)
return
aux_out
return
aux_out
...
@@ -118,6 +121,7 @@ class DistogramHead(nn.Module):
...
@@ -118,6 +121,7 @@ class DistogramHead(nn.Module):
For use in computation of distogram loss, subsection 1.9.8
For use in computation of distogram loss, subsection 1.9.8
"""
"""
def
__init__
(
self
,
c_z
,
no_bins
,
**
kwargs
):
def
__init__
(
self
,
c_z
,
no_bins
,
**
kwargs
):
"""
"""
Args:
Args:
...
@@ -133,9 +137,7 @@ class DistogramHead(nn.Module):
...
@@ -133,9 +137,7 @@ class DistogramHead(nn.Module):
self
.
linear
=
Linear
(
self
.
c_z
,
self
.
no_bins
,
init
=
"final"
)
self
.
linear
=
Linear
(
self
.
c_z
,
self
.
no_bins
,
init
=
"final"
)
def
forward
(
self
,
def
forward
(
self
,
z
):
# [*, N, N, C_z]
z
# [*, N, N, C_z]
):
"""
"""
Args:
Args:
z:
z:
...
@@ -153,6 +155,7 @@ class TMScoreHead(nn.Module):
...
@@ -153,6 +155,7 @@ class TMScoreHead(nn.Module):
"""
"""
For use in computation of TM-score, subsection 1.9.7
For use in computation of TM-score, subsection 1.9.7
"""
"""
def
__init__
(
self
,
c_z
,
no_bins
,
**
kwargs
):
def
__init__
(
self
,
c_z
,
no_bins
,
**
kwargs
):
"""
"""
Args:
Args:
...
@@ -185,6 +188,7 @@ class MaskedMSAHead(nn.Module):
...
@@ -185,6 +188,7 @@ class MaskedMSAHead(nn.Module):
"""
"""
For use in computation of masked MSA loss, subsection 1.9.9
For use in computation of masked MSA loss, subsection 1.9.9
"""
"""
def
__init__
(
self
,
c_m
,
c_out
,
**
kwargs
):
def
__init__
(
self
,
c_m
,
c_out
,
**
kwargs
):
"""
"""
Args:
Args:
...
@@ -218,6 +222,7 @@ class ExperimentallyResolvedHead(nn.Module):
...
@@ -218,6 +222,7 @@ class ExperimentallyResolvedHead(nn.Module):
For use in computation of "experimentally resolved" loss, subsection
For use in computation of "experimentally resolved" loss, subsection
1.9.10
1.9.10
"""
"""
def
__init__
(
self
,
c_s
,
c_out
,
**
kwargs
):
def
__init__
(
self
,
c_s
,
c_out
,
**
kwargs
):
"""
"""
Args:
Args:
...
...
openfold/model/model.py
View file @
07e64267
...
@@ -54,6 +54,7 @@ class AlphaFold(nn.Module):
...
@@ -54,6 +54,7 @@ class AlphaFold(nn.Module):
Implements Algorithm 2 (but with training).
Implements Algorithm 2 (but with training).
"""
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
"""
"""
Args:
Args:
...
@@ -115,7 +116,7 @@ class AlphaFold(nn.Module):
...
@@ -115,7 +116,7 @@ class AlphaFold(nn.Module):
)
)
single_template_embeds
=
{}
single_template_embeds
=
{}
if
(
self
.
config
.
template
.
embed_angles
)
:
if
self
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
template_angle_feat
=
build_template_angle_feat
(
single_template_feats
,
single_template_feats
,
)
)
...
@@ -130,18 +131,18 @@ class AlphaFold(nn.Module):
...
@@ -130,18 +131,18 @@ class AlphaFold(nn.Module):
single_template_feats
,
single_template_feats
,
inf
=
self
.
config
.
template
.
inf
,
inf
=
self
.
config
.
template
.
inf
,
eps
=
self
.
config
.
template
.
eps
,
eps
=
self
.
config
.
template
.
eps
,
**
self
.
config
.
template
.
distogram
**
self
.
config
.
template
.
distogram
,
)
)
t
=
self
.
template_pair_embedder
(
t
)
t
=
self
.
template_pair_embedder
(
t
)
t
=
self
.
template_pair_stack
(
t
=
self
.
template_pair_stack
(
t
,
t
,
pair_mask
.
unsqueeze
(
-
3
),
_mask_trans
=
self
.
config
.
_mask_trans
pair_mask
.
unsqueeze
(
-
3
),
_mask_trans
=
self
.
config
.
_mask_trans
)
)
single_template_embeds
.
update
({
single_template_embeds
.
update
(
{
"pair"
:
t
,
"pair"
:
t
,
})
}
)
template_embeds
.
append
(
single_template_embeds
)
template_embeds
.
append
(
single_template_embeds
)
...
@@ -152,19 +153,19 @@ class AlphaFold(nn.Module):
...
@@ -152,19 +153,19 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
# [*, N, N, C_z]
t
=
self
.
template_pointwise_att
(
t
=
self
.
template_pointwise_att
(
template_embeds
[
"pair"
],
template_embeds
[
"pair"
],
z
,
template_mask
=
batch
[
"template_mask"
]
z
,
template_mask
=
batch
[
"template_mask"
]
)
)
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
])
>
0
)
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
])
>
0
)
ret
=
{}
ret
=
{}
if
(
self
.
config
.
template
.
embed_angles
)
:
if
self
.
config
.
template
.
embed_angles
:
ret
[
"template_angle_embedding"
]
=
template_embeds
[
"angle"
]
ret
[
"template_angle_embedding"
]
=
template_embeds
[
"angle"
]
ret
.
update
({
ret
.
update
(
{
"template_pair_embedding"
:
t
,
"template_pair_embedding"
:
t
,
})
}
)
return
ret
return
ret
...
@@ -195,9 +196,9 @@ class AlphaFold(nn.Module):
...
@@ -195,9 +196,9 @@ class AlphaFold(nn.Module):
)
)
# Inject information from previous recycling iterations
# Inject information from previous recycling iterations
if
(
self
.
config
.
num_recycle
>
0
)
:
if
self
.
config
.
num_recycle
>
0
:
# Initialize the recycling embeddings, if needs be
# Initialize the recycling embeddings, if needs be
if
(
None
in
[
m_1_prev
,
z_prev
,
x_prev
]
)
:
if
None
in
[
m_1_prev
,
z_prev
,
x_prev
]:
# [*, N, C_m]
# [*, N, C_m]
m_1_prev
=
m
.
new_zeros
(
m_1_prev
=
m
.
new_zeros
(
(
*
batch_dims
,
n
,
self
.
config
.
input_embedder
.
c_m
),
(
*
batch_dims
,
n
,
self
.
config
.
input_embedder
.
c_m
),
...
@@ -213,11 +214,7 @@ class AlphaFold(nn.Module):
...
@@ -213,11 +214,7 @@ class AlphaFold(nn.Module):
(
*
batch_dims
,
n
,
residue_constants
.
atom_type_num
,
3
),
(
*
batch_dims
,
n
,
residue_constants
.
atom_type_num
,
3
),
)
)
x_prev
=
pseudo_beta_fn
(
x_prev
=
pseudo_beta_fn
(
feats
[
"aatype"
],
x_prev
,
None
)
feats
[
"aatype"
],
x_prev
,
None
)
# m_1_prev_emb: [*, N, C_m]
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
# z_prev_emb: [*, N, N, C_z]
...
@@ -237,9 +234,9 @@ class AlphaFold(nn.Module):
...
@@ -237,9 +234,9 @@ class AlphaFold(nn.Module):
del
m_1_prev_emb
,
z_prev_emb
del
m_1_prev_emb
,
z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
# Embed the templates + merge with MSA/pair embeddings
if
(
self
.
config
.
template
.
enabled
)
:
if
self
.
config
.
template
.
enabled
:
template_feats
=
{
template_feats
=
{
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
}
}
template_embeds
=
self
.
embed_templates
(
template_embeds
=
self
.
embed_templates
(
template_feats
,
template_feats
,
...
@@ -251,11 +248,10 @@ class AlphaFold(nn.Module):
...
@@ -251,11 +248,10 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
# [*, N, N, C_z]
z
=
z
+
template_embeds
[
"template_pair_embedding"
]
z
=
z
+
template_embeds
[
"template_pair_embedding"
]
if
(
self
.
config
.
template
.
embed_angles
)
:
if
self
.
config
.
template
.
embed_angles
:
# [*, S = S_c + S_t, N, C_m]
# [*, S = S_c + S_t, N, C_m]
m
=
torch
.
cat
(
m
=
torch
.
cat
(
[
m
,
template_embeds
[
"template_angle_embedding"
]],
[
m
,
template_embeds
[
"template_angle_embedding"
]],
dim
=-
3
dim
=-
3
)
)
# [*, S, N]
# [*, S, N]
...
@@ -265,7 +261,7 @@ class AlphaFold(nn.Module):
...
@@ -265,7 +261,7 @@ class AlphaFold(nn.Module):
)
)
# Embed extra MSA features + merge with pairwise embeddings
# Embed extra MSA features + merge with pairwise embeddings
if
(
self
.
config
.
extra_msa
.
enabled
)
:
if
self
.
config
.
extra_msa
.
enabled
:
# [*, S_e, N, C_e]
# [*, S_e, N, C_e]
a
=
self
.
extra_msa_embedder
(
build_extra_msa_feat
(
feats
))
a
=
self
.
extra_msa_embedder
(
build_extra_msa_feat
(
feats
))
...
@@ -287,7 +283,7 @@ class AlphaFold(nn.Module):
...
@@ -287,7 +283,7 @@ class AlphaFold(nn.Module):
z
,
z
,
msa_mask
=
msa_mask
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
,
_mask_trans
=
self
.
config
.
_mask_trans
_mask_trans
=
self
.
config
.
_mask_trans
,
)
)
outputs
[
"msa"
]
=
m
[...,
:
n_seq
,
:,
:]
outputs
[
"msa"
]
=
m
[...,
:
n_seq
,
:,
:]
...
@@ -296,7 +292,10 @@ class AlphaFold(nn.Module):
...
@@ -296,7 +292,10 @@ class AlphaFold(nn.Module):
# Predict 3D structure
# Predict 3D structure
outputs
[
"sm"
]
=
self
.
structure_module
(
outputs
[
"sm"
]
=
self
.
structure_module
(
s
,
z
,
feats
[
"aatype"
],
mask
=
feats
[
"seq_mask"
],
s
,
z
,
feats
[
"aatype"
],
mask
=
feats
[
"seq_mask"
],
)
)
outputs
[
"final_atom_positions"
]
=
atom14_to_atom37
(
outputs
[
"final_atom_positions"
]
=
atom14_to_atom37
(
outputs
[
"sm"
][
"positions"
][
-
1
],
feats
outputs
[
"sm"
][
"positions"
][
-
1
],
feats
...
@@ -397,16 +396,19 @@ class AlphaFold(nn.Module):
...
@@ -397,16 +396,19 @@ class AlphaFold(nn.Module):
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
# Enable grad iff we're training and it's the final recycling layer
# Enable grad iff we're training and it's the final recycling layer
is_final_iter
=
(
cycle_no
==
self
.
config
.
num_recycle
)
is_final_iter
=
cycle_no
==
self
.
config
.
num_recycle
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
# Sidestep AMP bug discussed in pytorch issue #65766
# Sidestep AMP bug discussed in pytorch issue #65766
if
(
is_final_iter
)
:
if
is_final_iter
:
self
.
_enable_activation_checkpointing
()
self
.
_enable_activation_checkpointing
()
if
(
torch
.
is_autocast_enabled
()
)
:
if
torch
.
is_autocast_enabled
():
torch
.
clear_autocast_cache
()
torch
.
clear_autocast_cache
()
# Run the next iteration of the model
# Run the next iteration of the model
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
feats
,
m_1_prev
,
z_prev
,
x_prev
,
feats
,
m_1_prev
,
z_prev
,
x_prev
,
)
)
# Run auxiliary heads
# Run auxiliary heads
...
...
openfold/model/msa.py
View file @
07e64267
...
@@ -27,7 +27,8 @@ from openfold.utils.tensor_utils import (
...
@@ -27,7 +27,8 @@ from openfold.utils.tensor_utils import (
class
MSAAttention
(
nn
.
Module
):
class
MSAAttention
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
c_in
,
c_in
,
c_hidden
,
c_hidden
,
no_heads
,
no_heads
,
...
@@ -64,17 +65,14 @@ class MSAAttention(nn.Module):
...
@@ -64,17 +65,14 @@ class MSAAttention(nn.Module):
self
.
layer_norm_m
=
nn
.
LayerNorm
(
self
.
c_in
)
self
.
layer_norm_m
=
nn
.
LayerNorm
(
self
.
c_in
)
if
(
self
.
pair_bias
)
:
if
self
.
pair_bias
:
self
.
layer_norm_z
=
nn
.
LayerNorm
(
self
.
c_z
)
self
.
layer_norm_z
=
nn
.
LayerNorm
(
self
.
c_z
)
self
.
linear_z
=
Linear
(
self
.
linear_z
=
Linear
(
self
.
c_z
,
self
.
no_heads
,
bias
=
False
,
init
=
"normal"
self
.
c_z
,
self
.
no_heads
,
bias
=
False
,
init
=
"normal"
)
)
self
.
mha
=
Attention
(
self
.
mha
=
Attention
(
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_hidden
,
self
.
no_heads
self
.
c_hidden
,
self
.
no_heads
)
)
def
forward
(
self
,
m
,
z
=
None
,
mask
=
None
):
def
forward
(
self
,
m
,
z
=
None
,
mask
=
None
):
...
@@ -92,7 +90,7 @@ class MSAAttention(nn.Module):
...
@@ -92,7 +90,7 @@ class MSAAttention(nn.Module):
m
=
self
.
layer_norm_m
(
m
)
m
=
self
.
layer_norm_m
(
m
)
n_seq
,
n_res
=
m
.
shape
[
-
3
:
-
1
]
n_seq
,
n_res
=
m
.
shape
[
-
3
:
-
1
]
if
(
mask
is
None
)
:
if
mask
is
None
:
# [*, N_seq, N_res]
# [*, N_seq, N_res]
mask
=
m
.
new_ones
(
mask
=
m
.
new_ones
(
m
.
shape
[:
-
3
]
+
(
n_seq
,
n_res
),
m
.
shape
[:
-
3
]
+
(
n_seq
,
n_res
),
...
@@ -106,7 +104,7 @@ class MSAAttention(nn.Module):
...
@@ -106,7 +104,7 @@ class MSAAttention(nn.Module):
((
-
1
,)
*
len
(
bias
.
shape
[:
-
4
]))
+
(
-
1
,
self
.
no_heads
,
n_res
,
-
1
)
((
-
1
,)
*
len
(
bias
.
shape
[:
-
4
]))
+
(
-
1
,
self
.
no_heads
,
n_res
,
-
1
)
)
)
biases
=
[
bias
]
biases
=
[
bias
]
if
(
self
.
pair_bias
)
:
if
self
.
pair_bias
:
# [*, N_res, N_res, C_z]
# [*, N_res, N_res, C_z]
z
=
self
.
layer_norm_z
(
z
)
z
=
self
.
layer_norm_z
(
z
)
...
@@ -118,18 +116,13 @@ class MSAAttention(nn.Module):
...
@@ -118,18 +116,13 @@ class MSAAttention(nn.Module):
biases
.
append
(
z
)
biases
.
append
(
z
)
mha_inputs
=
{
mha_inputs
=
{
"q_x"
:
m
,
"k_x"
:
m
,
"v_x"
:
m
,
"biases"
:
biases
}
"q_x"
:
m
,
if
self
.
chunk_size
is
not
None
:
"k_x"
:
m
,
"v_x"
:
m
,
"biases"
:
biases
}
if
(
self
.
chunk_size
is
not
None
):
m
=
chunk_layer
(
m
=
chunk_layer
(
self
.
mha
,
self
.
mha
,
mha_inputs
,
mha_inputs
,
chunk_size
=
self
.
chunk_size
,
chunk_size
=
self
.
chunk_size
,
no_batch_dims
=
len
(
m
.
shape
[:
-
2
])
no_batch_dims
=
len
(
m
.
shape
[:
-
2
])
,
)
)
else
:
else
:
m
=
self
.
mha
(
**
mha_inputs
)
m
=
self
.
mha
(
**
mha_inputs
)
...
@@ -141,6 +134,7 @@ class MSARowAttentionWithPairBias(MSAAttention):
...
@@ -141,6 +134,7 @@ class MSARowAttentionWithPairBias(MSAAttention):
"""
"""
Implements Algorithm 7.
Implements Algorithm 7.
"""
"""
def
__init__
(
self
,
c_m
,
c_z
,
c_hidden
,
no_heads
,
chunk_size
,
inf
=
1e9
):
def
__init__
(
self
,
c_m
,
c_z
,
c_hidden
,
no_heads
,
chunk_size
,
inf
=
1e9
):
"""
"""
Args:
Args:
...
@@ -170,6 +164,7 @@ class MSAColumnAttention(MSAAttention):
...
@@ -170,6 +164,7 @@ class MSAColumnAttention(MSAAttention):
"""
"""
Implements Algorithm 8.
Implements Algorithm 8.
"""
"""
def
__init__
(
self
,
c_m
,
c_hidden
,
no_heads
,
chunk_size
=
4
,
inf
=
1e9
):
def
__init__
(
self
,
c_m
,
c_hidden
,
no_heads
,
chunk_size
=
4
,
inf
=
1e9
):
"""
"""
Args:
Args:
...
@@ -192,7 +187,6 @@ class MSAColumnAttention(MSAAttention):
...
@@ -192,7 +187,6 @@ class MSAColumnAttention(MSAAttention):
inf
=
inf
,
inf
=
inf
,
)
)
def
forward
(
self
,
m
,
mask
=
None
):
def
forward
(
self
,
m
,
mask
=
None
):
"""
"""
Args:
Args:
...
@@ -203,26 +197,21 @@ class MSAColumnAttention(MSAAttention):
...
@@ -203,26 +197,21 @@ class MSAColumnAttention(MSAAttention):
"""
"""
# [*, N_res, N_seq, C_in]
# [*, N_res, N_seq, C_in]
m
=
m
.
transpose
(
-
2
,
-
3
)
m
=
m
.
transpose
(
-
2
,
-
3
)
if
(
mask
is
not
None
)
:
if
mask
is
not
None
:
mask
=
mask
.
transpose
(
-
1
,
-
2
)
mask
=
mask
.
transpose
(
-
1
,
-
2
)
m
=
super
().
forward
(
m
,
mask
=
mask
)
m
=
super
().
forward
(
m
,
mask
=
mask
)
# [*, N_seq, N_res, C_in]
# [*, N_seq, N_res, C_in]
m
=
m
.
transpose
(
-
2
,
-
3
)
m
=
m
.
transpose
(
-
2
,
-
3
)
if
(
mask
is
not
None
)
:
if
mask
is
not
None
:
mask
=
mask
.
transpose
(
-
1
,
-
2
)
mask
=
mask
.
transpose
(
-
1
,
-
2
)
return
m
return
m
class
MSAColumnGlobalAttention
(
nn
.
Module
):
class
MSAColumnGlobalAttention
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
c_in
,
self
,
c_in
,
c_hidden
,
no_heads
,
chunk_size
=
4
,
inf
=
1e9
,
eps
=
1e-10
c_hidden
,
no_heads
,
chunk_size
=
4
,
inf
=
1e9
,
eps
=
1e-10
):
):
super
(
MSAColumnGlobalAttention
,
self
).
__init__
()
super
(
MSAColumnGlobalAttention
,
self
).
__init__
()
...
@@ -243,13 +232,12 @@ class MSAColumnGlobalAttention(nn.Module):
...
@@ -243,13 +232,12 @@ class MSAColumnGlobalAttention(nn.Module):
eps
=
eps
,
eps
=
eps
,
)
)
def
forward
(
self
,
def
forward
(
m
:
torch
.
Tensor
,
self
,
m
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
mask
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
n_seq
,
n_res
,
c_in
=
m
.
shape
[
-
3
:]
n_seq
,
n_res
,
c_in
=
m
.
shape
[
-
3
:]
if
(
mask
is
None
)
:
if
mask
is
None
:
# [*, N_seq, N_res]
# [*, N_seq, N_res]
mask
=
torch
.
ones
(
mask
=
torch
.
ones
(
m
.
shape
[:
-
1
],
m
.
shape
[:
-
1
],
...
@@ -268,12 +256,12 @@ class MSAColumnGlobalAttention(nn.Module):
...
@@ -268,12 +256,12 @@ class MSAColumnGlobalAttention(nn.Module):
"m"
:
m
,
"m"
:
m
,
"mask"
:
mask
,
"mask"
:
mask
,
}
}
if
(
self
.
chunk_size
is
not
None
)
:
if
self
.
chunk_size
is
not
None
:
m
=
chunk_layer
(
m
=
chunk_layer
(
self
.
global_attention
,
self
.
global_attention
,
mha_input
,
mha_input
,
chunk_size
=
self
.
chunk_size
,
chunk_size
=
self
.
chunk_size
,
no_batch_dims
=
len
(
m
.
shape
[:
-
2
])
no_batch_dims
=
len
(
m
.
shape
[:
-
2
])
,
)
)
else
:
else
:
m
=
self
.
global_attention
(
m
=
mha_input
[
"m"
],
mask
=
mha_input
[
"mask"
])
m
=
self
.
global_attention
(
m
=
mha_input
[
"m"
],
mask
=
mha_input
[
"mask"
])
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment