OpenDAS / OpenFold · Commits

Commit d8ee9c5f, authored Feb 17, 2023 by Christina Floristean
Parent: 260db67f

    All non-cuda tests passing for monomer/multimer. Tri mul/attn and OPM order switched.

Changes: 33 files. Showing 20 changed files with 406 additions and 278 deletions (+406 −278; page 1 of 2).
.gitignore                                       +12   −0
openfold/config.py                               +51   −0
openfold/data/data_transforms.py                  +3   −3
openfold/data/templates.py                       +21   −8
openfold/model/embedders.py                      +23 −110
openfold/model/evoformer.py                      +78  −47
openfold/model/structure_module.py               +53  −24
openfold/model/template.py                       +47  −26
openfold/utils/feats.py                          +15   −7
openfold/utils/geometry/rigid_matrix_vector.py    +4   −1
openfold/utils/geometry/rotation_matrix.py        +5   −4
openfold/utils/geometry/vector.py                 +1   −1
openfold/utils/import_weights.py                 +21  −20
openfold/utils/loss.py                            +1   −3
openfold/utils/rigid_utils.py                    +10   −0
tests/compare_utils.py                            +4   −4
tests/config.py                                   +4   −0
tests/data_utils.py                              +27   −0
tests/test_data_pipeline.py                      +25  −15
tests/test_data_transforms.py                     +1   −5
.gitignore  (new file @ d8ee9c5f, 0 → 100644)

+.DS_Store
+*.DS_Store
+**/.DS_Store
+.idea/
+**/__pycache__
+*.pyc
+build/
+dist/
+*.egg-info/
+openfold/resources
+**/stereo_chemical_props.txt
+**/sample_feats.pickle
openfold/config.py  (view file @ d8ee9c5f)

@@ -331,6 +331,7 @@ config = mlc.ConfigDict(
             "no_heads": 4,
             "pair_transition_n": 2,
             "dropout_rate": 0.25,
+            "tri_mul_first": False,
             "blocks_per_ckpt": blocks_per_ckpt,
             "inf": 1e9,
         },
@@ -367,6 +368,7 @@ config = mlc.ConfigDict(
             "transition_n": 4,
             "msa_dropout": 0.15,
             "pair_dropout": 0.25,
+            "opm_first": False,
             "clear_cache_between_blocks": True,
             "inf": 1e9,
             "eps": eps,  # 1e-10,
@@ -388,6 +390,7 @@ config = mlc.ConfigDict(
             "transition_n": 4,
             "msa_dropout": 0.15,
             "pair_dropout": 0.25,
+            "opm_first": False,
             "blocks_per_ckpt": blocks_per_ckpt,
             "clear_cache_between_blocks": False,
             "inf": 1e9,
@@ -546,6 +549,7 @@ multimer_model_config_update = {
             "no_heads": 4,
             "pair_transition_n": 2,
             "dropout_rate": 0.25,
+            "tri_mul_first": True,
             "blocks_per_ckpt": blocks_per_ckpt,
             "inf": 1e9,
         },
@@ -555,6 +559,53 @@ multimer_model_config_update = {
             "eps": eps,  # 1e-6,
             "enabled": templates_enabled,
             "embed_angles": embed_template_torsion_angles,
+            "use_unit_vector": True
         },
+        "extra_msa": {
+            "extra_msa_embedder": {
+                "c_in": 25,
+                "c_out": c_e,
+            },
+            "extra_msa_stack": {
+                "c_m": c_e,
+                "c_z": c_z,
+                "c_hidden_msa_att": 8,
+                "c_hidden_opm": 32,
+                "c_hidden_mul": 128,
+                "c_hidden_pair_att": 32,
+                "no_heads_msa": 8,
+                "no_heads_pair": 4,
+                "no_blocks": 4,
+                "transition_n": 4,
+                "msa_dropout": 0.15,
+                "pair_dropout": 0.25,
+                "opm_first": True,
+                "clear_cache_between_blocks": True,
+                "inf": 1e9,
+                "eps": eps,  # 1e-10,
+                "ckpt": blocks_per_ckpt is not None,
+            },
+            "enabled": True,
+        },
+        "evoformer_stack": {
+            "c_m": c_m,
+            "c_z": c_z,
+            "c_hidden_msa_att": 32,
+            "c_hidden_opm": 32,
+            "c_hidden_mul": 128,
+            "c_hidden_pair_att": 32,
+            "c_s": c_s,
+            "no_heads_msa": 8,
+            "no_heads_pair": 4,
+            "no_blocks": 48,
+            "transition_n": 4,
+            "msa_dropout": 0.15,
+            "pair_dropout": 0.25,
+            "opm_first": True,
+            "blocks_per_ckpt": blocks_per_ckpt,
+            "clear_cache_between_blocks": False,
+            "inf": 1e9,
+            "eps": eps,  # 1e-10,
+        },
     },
     "heads": {
         "lddt": {
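These two flags carry the commit's headline change: monomer configs keep AlphaFold's original ordering (tri_mul_first: False, opm_first: False), while the multimer update flips both, running triangle multiplication before triangle attention and the outer-product mean before MSA attention. A minimal, hedged sketch of how such an override resolves with ml_collections (imported as mlc in this file); the two-key config is illustrative, not the real one:

    import ml_collections as mlc

    # Illustrative two-key config; the real ConfigDict above has many more keys.
    config = mlc.ConfigDict({
        "template_pair_stack": {"tri_mul_first": False},  # monomer: attention first
        "evoformer_stack": {"opm_first": False},          # monomer: OPM after the MSA ops
    })

    multimer_model_config_update = {
        "template_pair_stack": {"tri_mul_first": True},   # multimer: multiplication first
        "evoformer_stack": {"opm_first": True},           # multimer: OPM first
    }
    config.update(multimer_model_config_update)  # recursive, key-wise merge

    assert config.template_pair_stack.tri_mul_first
    assert config.evoformer_stack.opm_first

ConfigDict.update merges nested dicts key by key, which is what lets the multimer dict override only the keys it names.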
openfold/data/data_transforms.py  (view file @ d8ee9c5f)

@@ -93,7 +93,7 @@ def fix_templates_aatype(protein):
     # Map hhsearch-aatype to our aatype.
     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
     new_order = torch.tensor(
-        new_order_list, dtype=torch.int64, device=protein["aatype"].device,
+        new_order_list, dtype=torch.int64, device=protein["template_aatype"].device,
     ).expand(num_templates, -1)
     protein["template_aatype"] = torch.gather(
         new_order, 1, index=protein["template_aatype"]
@@ -669,8 +669,8 @@ def make_atom14_masks(protein):
 def make_atom14_masks_np(batch):
     batch = tree_map(
-        lambda n: torch.tensor(n, device=batch["aatype"].device),
+        lambda n: torch.tensor(n, device="cpu"),
         batch,
         np.ndarray
     )
     out = make_atom14_masks(batch)
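The first hunk fixes a device bug: the remapping table must be created on the same device as protein["template_aatype"], the tensor it is gathered against. The second pins make_atom14_masks_np to CPU, since its NumPy inputs carry no device to inherit. For intuition, a hedged toy of the gather-based remap the fixed line feeds into (the alphabet and values are made up):

    import torch

    # Toy remap: new_order[i] = "our" index for foreign index i.
    new_order_list = [3, 0, 2, 1]  # hypothetical 4-letter alphabet
    template_aatype = torch.tensor([[0, 1, 3], [2, 2, 0]])  # [num_templates, num_res]

    new_order = torch.tensor(
        new_order_list, dtype=torch.int64, device=template_aatype.device
    ).expand(template_aatype.shape[0], -1)  # one row of the table per template

    remapped = torch.gather(new_order, 1, index=template_aatype)
    print(remapped)  # tensor([[3, 0, 1], [2, 2, 3]])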
openfold/data/templates.py  (view file @ d8ee9c5f)

@@ -1048,7 +1048,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
         for i in idx:
             # We got all the templates we wanted, stop processing hits.
-            if len(already_seen) >= self.max_hits:
+            if len(already_seen) >= self._max_hits:
                 break
             hit = filtered[i]
@@ -1088,16 +1088,29 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
             for k in template_features:
                 template_features[k].append(result.features[k])

-        for name in template_features:
-            if already_seen:
+        if num_hits > 0:
+            for name in template_features:
                 template_features[name] = np.stack(
                     template_features[name], axis=0
                 ).astype(TEMPLATE_FEATURES[name])
-            else:
-                # Make sure the feature has correct dtype even if empty.
-                template_features[name] = np.array(
-                    [], dtype=TEMPLATE_FEATURES[name]
-                )
+        else:
+            num_res = len(query_sequence)
+            # Construct a default template with all zeros.
+            template_features = {
+                "template_aatype": np.zeros(
+                    (1, num_res, len(residue_constants.restypes_with_x_and_gap)),
+                    np.float32,
+                ),
+                "template_all_atom_masks": np.zeros(
+                    (1, num_res, residue_constants.atom_type_num), np.float32
+                ),
+                "template_all_atom_positions": np.zeros(
+                    (1, num_res, residue_constants.atom_type_num, 3), np.float32
+                ),
+                "template_domain_names": np.array([''.encode()], dtype=np.object),
+                "template_sequence": np.array([''.encode()], dtype=np.object),
+                "template_sum_probs": np.array([0], dtype=np.float32),
+            }
         return TemplateSearchResult(
             features=template_features, errors=errors, warnings=warnings
openfold/model/embedders.py  (view file @ d8ee9c5f)

@@ -216,10 +216,12 @@ class InputEmbedderMultimer(nn.Module):
                 (2 * self.max_relative_idx + 1)
                 * torch.ones_like(clipped_offset)
             )
-            rel_pos = torch.nn.functional.one_hot(
+            boundaries = torch.arange(
+                start=0, end=2 * self.max_relative_idx + 2, device=final_offset.device
+            )
+            rel_pos = one_hot(
                 final_offset,
-                2 * self.max_relative_idx + 2,
+                boundaries,
             )
             rel_feats.append(rel_pos)
@@ -245,15 +247,21 @@ class InputEmbedderMultimer(nn.Module):
                 torch.ones_like(clipped_rel_chain)
             )
-            rel_chain = torch.nn.functional.one_hot(
+            boundaries = torch.arange(
+                start=0, end=2 * max_rel_chain + 2, device=final_rel_chain.device
+            )
+            rel_chain = one_hot(
                 final_rel_chain,
-                2 * max_rel_chain + 2,
+                boundaries,
             )
             rel_feats.append(rel_chain)
         else:
-            rel_pos = torch.nn.functional.one_hot(
+            boundaries = torch.arange(
+                start=0, end=2 * self.max_relative_idx + 1, device=clipped_offset.device
+            )
+            rel_pos = one_hot(
                 clipped_offset,
-                2 * self.max_relative_idx + 1,
+                boundaries,
             )
             rel_feats.append(rel_pos)
@@ -471,102 +479,6 @@ class TemplatePairEmbedder(nn.Module):
         return x

-class TemplateEmbedder(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.template_angle_embedder = TemplateAngleEmbedder(
-            **config["template_angle_embedder"],
-        )
-        self.template_pair_embedder = TemplatePairEmbedder(
-            **config["template_pair_embedder"],
-        )
-        self.template_pair_stack = TemplatePairStack(
-            **config["template_pair_stack"],
-        )
-        self.template_pointwise_att = TemplatePointwiseAttention(
-            **config["template_pointwise_attention"],
-        )
-
-    def forward(
-        self,
-        batch,
-        z,
-        pair_mask,
-        templ_dim,
-        chunk_size,
-        _mask_trans=True,
-    ):
-        # Embed the templates one at a time (with a poor man's vmap)
-        template_embeds = []
-        n_templ = batch["template_aatype"].shape[templ_dim]
-        for i in range(n_templ):
-            idx = batch["template_aatype"].new_tensor(i)
-            single_template_feats = tensor_tree_map(
-                lambda t: torch.index_select(t, templ_dim, idx),
-                batch,
-            )
-
-            single_template_embeds = {}
-            if self.config.embed_angles:
-                template_angle_feat = build_template_angle_feat(
-                    single_template_feats,
-                )
-
-                # [*, S_t, N, C_m]
-                a = self.template_angle_embedder(template_angle_feat)
-
-                single_template_embeds["angle"] = a
-
-            # [*, S_t, N, N, C_t]
-            t = build_template_pair_feat(
-                single_template_feats,
-                use_unit_vector=self.config.use_unit_vector,
-                inf=self.config.inf,
-                eps=self.config.eps,
-                **self.config.distogram,
-            ).to(z.dtype)
-            t = self.template_pair_embedder(t)
-
-            single_template_embeds.update({"pair": t})
-
-            template_embeds.append(single_template_embeds)
-
-        template_embeds = dict_multimap(
-            partial(torch.cat, dim=templ_dim),
-            template_embeds,
-        )
-
-        # [*, S_t, N, N, C_z]
-        t = self.template_pair_stack(
-            template_embeds["pair"],
-            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
-            chunk_size=chunk_size,
-            _mask_trans=_mask_trans,
-        )
-
-        # [*, N, N, C_z]
-        t = self.template_pointwise_att(
-            t,
-            z,
-            template_mask=batch["template_mask"].to(dtype=z.dtype),
-            chunk_size=chunk_size,
-        )
-        t = t * (torch.sum(batch["template_mask"]) > 0)
-
-        ret = {}
-        if self.config.embed_angles:
-            ret["template_pair_embedding"] = template_embeds["angle"]
-
-        ret.update({"template_pair_embedding": t})
-
-        return ret
-
 class ExtraMSAEmbedder(nn.Module):
     """
     Embeds unclustered MSA sequences.
@@ -625,12 +537,13 @@ class TemplateEmbedder(nn.Module):
             **config["template_pointwise_attention"],
         )

-    def forward(self,
+    def forward(
+        self,
         batch,
         z,
         pair_mask,
         templ_dim,
         chunk_size,
         _mask_trans=True
     ):
         # Embed the templates one at a time (with a poor man's vmap)
@@ -706,7 +619,7 @@ class TemplatePairEmbedderMultimer(nn.Module):
         c_dgram: int,
         c_aatype: int,
     ):
-        super().__init__()
+        super(TemplatePairEmbedderMultimer, self).__init__()
         self.dgram_linear = Linear(c_dgram, c_out)
         self.aatype_linear_1 = Linear(c_aatype, c_out)
@@ -765,7 +678,7 @@ class TemplateSingleEmbedderMultimer(nn.Module):
         c_in: int,
         c_m: int,
     ):
-        super().__init__()
+        super(TemplateSingleEmbedderMultimer, self).__init__()
         self.template_single_embedder = Linear(c_in, c_m)
         self.template_projector = Linear(c_m, c_m)
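The relative-position hunks swap torch.nn.functional.one_hot for OpenFold's boundary-based one_hot, matching the multimer AlphaFold formulation: each offset is assigned to the nearest entry of a boundaries vector instead of being used directly as a class index, so out-of-range values snap to the edge bins instead of erroring. A sketch modeled on openfold.utils.tensor_utils.one_hot; treat the exact implementation details as an assumption:

    import torch

    def one_hot(x: torch.Tensor, v_bins: torch.Tensor) -> torch.Tensor:
        # Snap each value to its nearest bin, then one-hot encode the bin index.
        reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
        diffs = x[..., None] - reshaped_bins
        am = torch.argmin(torch.abs(diffs), dim=-1)
        return torch.nn.functional.one_hot(am, num_classes=len(v_bins)).float()

    max_relative_idx = 32
    boundaries = torch.arange(start=0, end=2 * max_relative_idx + 2)  # 66 bins
    final_offset = torch.tensor([0, 5, 65, 70])  # 70 snaps to the edge bin, 65
    rel_pos = one_hot(final_offset, boundaries)
    print(rel_pos.shape)  # torch.Size([4, 66])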
openfold/model/evoformer.py  (view file @ d8ee9c5f)

@@ -117,34 +117,19 @@ class MSATransition(nn.Module):
         return m

-class EvoformerBlockCore(nn.Module):
+class PairStack(nn.Module):
     def __init__(
         self,
-        c_m: int,
         c_z: int,
-        c_hidden_opm: int,
         c_hidden_mul: int,
         c_hidden_pair_att: int,
-        no_heads_msa: int,
         no_heads_pair: int,
         transition_n: int,
         pair_dropout: float,
         inf: float,
-        eps: float,
-        _is_extra_msa_stack: bool = False,
+        eps: float
     ):
-        super(EvoformerBlockCore, self).__init__()
+        super(PairStack, self).__init__()

-        self.msa_transition = MSATransition(
-            c_m=c_m,
-            n=transition_n,
-        )
-
-        self.outer_product_mean = OuterProductMean(
-            c_m,
-            c_z,
-            c_hidden_opm,
-        )
-
         self.tri_mul_out = TriangleMultiplicationOutgoing(
             c_z,
@@ -178,25 +163,15 @@ class EvoformerBlockCore(nn.Module):
     def forward(
         self,
-        m: torch.Tensor,
         z: torch.Tensor,
-        msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
-        msa_trans_mask = msa_mask if _mask_trans else None
         pair_trans_mask = pair_mask if _mask_trans else None

-        m = m + self.msa_transition(m, mask=msa_trans_mask, chunk_size=chunk_size)
-        z = z + self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size)
         z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
         z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
         z = z + self.ps_dropout_row_layer(
@@ -209,7 +184,7 @@ class EvoformerBlockCore(nn.Module):
             z, mask=pair_trans_mask, chunk_size=chunk_size
         )

-        return m, z
+        return z

 class EvoformerBlock(nn.Module):
@@ -225,11 +200,14 @@ class EvoformerBlock(nn.Module):
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
+        opm_first: bool,
         inf: float,
         eps: float,
     ):
         super(EvoformerBlock, self).__init__()

+        self.opm_first = opm_first
+
         self.msa_att_row = MSARowAttentionWithPairBias(
             c_m=c_m,
             c_z=c_z,
@@ -247,18 +225,26 @@ class EvoformerBlock(nn.Module):
         self.msa_dropout_layer = DropoutRowwise(msa_dropout)

-        self.core = EvoformerBlockCore(
+        self.msa_transition = MSATransition(
             c_m=c_m,
+            n=transition_n,
+        )
+
+        self.outer_product_mean = OuterProductMean(
+            c_m,
+            c_z,
+            c_hidden_opm,
+        )
+
+        self.pair_stack = PairStack(
             c_z=c_z,
-            c_hidden_opm=c_hidden_opm,
             c_hidden_mul=c_hidden_mul,
             c_hidden_pair_att=c_hidden_pair_att,
-            no_heads_msa=no_heads_msa,
             no_heads_pair=no_heads_pair,
             transition_n=transition_n,
             pair_dropout=pair_dropout,
             inf=inf,
-            eps=eps,
+            eps=eps
         )

     def forward(self,
@@ -269,17 +255,34 @@ class EvoformerBlock(nn.Module):
         chunk_size: Optional[int] = None,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # DeepMind doesn't mask these transitions in the source, so _mask_trans
+        # should be disabled to better approximate the exact activations of
+        # the original.
+        msa_trans_mask = msa_mask if _mask_trans else None
+
+        if self.opm_first:
+            z = z + self.outer_product_mean(
+                m, mask=msa_mask, chunk_size=chunk_size
+            )
+
         m = m + self.msa_dropout_layer(
             self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
         )
         m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
-        m, z = self.core(
-            m,
-            z,
-            msa_mask=msa_mask,
-            pair_mask=pair_mask,
-            chunk_size=chunk_size,
-            _mask_trans=_mask_trans,
+        m = m + self.msa_transition(
+            m, mask=msa_trans_mask, chunk_size=chunk_size
+        )
+
+        if not self.opm_first:
+            z = z + self.outer_product_mean(
+                m, mask=msa_mask, chunk_size=chunk_size
+            )
+
+        z = self.pair_stack(
+            z,
+            pair_mask=pair_mask,
+            chunk_size=chunk_size,
         )

         return m, z
@@ -304,12 +307,14 @@ class ExtraMSABlock(nn.Module):
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
+        opm_first: bool,
         inf: float,
         eps: float,
         ckpt: bool,
     ):
         super(ExtraMSABlock, self).__init__()

+        self.opm_first = opm_first
         self.ckpt = ckpt
         self.msa_att_row = MSARowAttentionWithPairBias(
@@ -330,13 +335,21 @@ class ExtraMSABlock(nn.Module):
         self.msa_dropout_layer = DropoutRowwise(msa_dropout)

-        self.core = EvoformerBlockCore(
+        self.msa_transition = MSATransition(
             c_m=c_m,
+            n=transition_n,
+        )
+
+        self.outer_product_mean = OuterProductMean(
+            c_m,
+            c_z,
+            c_hidden_opm,
+        )
+
+        self.pair_stack = PairStack(
             c_z=c_z,
-            c_hidden_opm=c_hidden_opm,
             c_hidden_mul=c_hidden_mul,
             c_hidden_pair_att=c_hidden_pair_att,
-            no_heads_msa=no_heads_msa,
             no_heads_pair=no_heads_pair,
             transition_n=transition_n,
             pair_dropout=pair_dropout,
@@ -361,6 +374,11 @@ class ExtraMSABlock(nn.Module):
                 m1 += m2
                 return m1

+        if self.opm_first:
+            z = z + self.outer_product_mean(
+                m, mask=msa_mask, chunk_size=chunk_size
+            )
+
         m = add(m, self.msa_dropout_layer(
             self.msa_att_row(
@@ -377,8 +395,17 @@ class ExtraMSABlock(nn.Module):
         def fn(m, z):
             m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
-            m, z = self.core(
-                m, z,
-                msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
-            )
+            m = add(m, self.msa_transition(
+                m, mask=msa_mask, chunk_size=chunk_size
+            ))
+
+            if not self.opm_first:
+                z = z + self.outer_product_mean(
+                    m, mask=msa_mask, chunk_size=chunk_size
+                )
+
+            z = self.pair_stack(
+                z, pair_mask=pair_mask, chunk_size=chunk_size
+            )

             return m, z
@@ -414,6 +441,7 @@ class EvoformerStack(nn.Module):
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
+        opm_first: bool,
         blocks_per_ckpt: int,
         inf: float,
         eps: float,
@@ -475,6 +503,7 @@ class EvoformerStack(nn.Module):
             transition_n=transition_n,
             msa_dropout=msa_dropout,
             pair_dropout=pair_dropout,
+            opm_first=opm_first,
             inf=inf,
             eps=eps,
         )
@@ -555,6 +584,7 @@ class ExtraMSAStack(nn.Module):
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
+        opm_first: bool,
         inf: float,
         eps: float,
         ckpt: bool,
@@ -581,6 +611,7 @@ class ExtraMSAStack(nn.Module):
             transition_n=transition_n,
             msa_dropout=msa_dropout,
             pair_dropout=pair_dropout,
+            opm_first=opm_first,
             inf=inf,
             eps=eps,
             ckpt=ckpt if chunk_msa_attn else False,
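The refactor dissolves EvoformerBlockCore: the MSA transition and outer-product mean move up into EvoformerBlock/ExtraMSABlock, the pair-only updates become the new PairStack, and opm_first decides whether the outer-product mean runs before MSA attention (multimer) or after the MSA transition (monomer). A runnable toy of that control flow; the linear layers are stand-ins for the real attention, transition, OPM, and pair-stack modules, not OpenFold code:

    import torch
    import torch.nn as nn

    class ToyBlock(nn.Module):
        # Toy illustration of the opm_first reordering only.
        def __init__(self, c_m: int, c_z: int, opm_first: bool):
            super().__init__()
            self.opm_first = opm_first
            self.msa_update = nn.Linear(c_m, c_m)  # stands in for MSA att. + transition
            self.opm = nn.Linear(c_m, c_z)         # stands in for OuterProductMean
            self.pair_stack = nn.Linear(c_z, c_z)  # stands in for PairStack

        def forward(self, m, z):
            if self.opm_first:          # multimer ordering: z updated from the old m
                z = z + self.opm(m)
            m = m + self.msa_update(m)
            if not self.opm_first:      # monomer ordering: z updated from the new m
                z = z + self.opm(m)
            z = self.pair_stack(z)
            return m, z

    m, z = torch.randn(5, 8), torch.randn(5, 16)
    m1, z1 = ToyBlock(8, 16, opm_first=False)(m, z)  # monomer order
    m2, z2 = ToyBlock(8, 16, opm_first=True)(m, z)   # multimer order

Either ordering keeps the residual structure; only the point at which z picks up information from m changes.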
openfold/model/structure_module.py  (view file @ d8ee9c5f)

@@ -169,8 +169,8 @@ class PointProjection(nn.Module):
     def forward(
         self,
         activations: torch.Tensor,
-        rigids: Rigid3Array,
-    ) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array]]:
+        rigids: Union[Rigid, Rigid3Array],
+    ) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array], torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         # TODO: Needs to run in high precision during training
         points_local = self.linear(activations)
         points_local = points_local.reshape(
@@ -181,8 +181,9 @@ class PointProjection(nn.Module):
         points_local = torch.split(
             points_local, points_local.shape[-1] // 3, dim=-1
         )
-        points_local = Vec3Array(*points_local)
-        points_global = rigids[..., None, None].apply_to_point(points_local)
+        points_local = torch.stack(points_local, dim=-1)
+        points_global = rigids[..., None, None].apply(points_local)

         if(self.return_local_points):
             return points_global, points_local
@@ -285,7 +286,7 @@ class InvariantPointAttention(nn.Module):
         self,
         s: torch.Tensor,
         z: torch.Tensor,
-        r: Rigid,
+        r: Union[Rigid, Rigid3Array],
         mask: torch.Tensor,
     ) -> torch.Tensor:
         """
@@ -340,9 +341,6 @@ class InvariantPointAttention(nn.Module):
         k, v = torch.split(kv, self.c_hidden, dim=-1)

         kv_pts = self.linear_kv_points(s, r)

-        # [*, N_res, H, (P_q + P_v), 3]
-        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
-
         # [*, N_res, H, P_q/P_v, 3]
         k_pts, v_pts = torch.split(
@@ -364,10 +362,16 @@ class InvariantPointAttention(nn.Module):
         a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

-        # [*, N_res, N_res, H, P_q, 3]
-        pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :]
-
-        # [*, N_res, N_res, H, P_q]
-        pt_att = sum([c**2 for c in pt_att])
+        if self.is_multimer:
+            # [*, N_res, N_res, H, P_q, 3]
+            pt_att = q_pts.unsqueeze(-3) - k_pts.unsqueeze(-4)
+            # [*, N_res, N_res, H, P_q]
+            pt_att = sum([c**2 for c in pt_att])
+        else:
+            pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
+            pt_att = pt_att ** 2
+            pt_att = sum(torch.unbind(pt_att, dim=-1))

         head_weights = self.softplus(self.head_weights).view(
             *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
         )
@@ -399,20 +403,42 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * C_hidden]
         o = flatten_final_dims(o, 2)

-        # As DeepMind explains, this manual matmul ensures that the operation
-        # happens in float32.
-        # [*, N_res, H, P_v]
-        o_pt = v_pts * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
-        o_pt = o_pt.sum(dim=-3)
-
-        # [*, N_res, H, P_v]
-        o_pt = r[..., None, None].apply_inverse_to_point(o_pt)
-
-        # [*, N_res, H * P_v, 3]
-        o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
-
-        # [*, N_res, H * P_v]
-        o_pt_norm = o_pt.norm(self.eps)
+        if self.is_multimer:
+            # As DeepMind explains, this manual matmul ensures that the operation
+            # happens in float32.
+            # [*, N_res, H, P_v]
+            o_pt = v_pts * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
+            o_pt = o_pt.sum(dim=-3)
+
+            # [*, N_res, H, P_v]
+            o_pt = r[..., None, None].apply_inverse_to_point(o_pt)
+
+            # [*, N_res, H * P_v, 3]
+            o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
+
+            # [*, N_res, H * P_v]
+            o_pt_norm = o_pt.norm(self.eps)
+        else:
+            o_pt = torch.sum(
+                (
+                    a[..., None, :, :, None]
+                    * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+                ),
+                dim=-2,
+            )
+
+            # [*, N_res, H, P_v, 3]
+            o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
+            o_pt = r[..., None, None].invert_apply(o_pt)
+
+            # [*, N_res, H * P_v]
+            o_pt_norm = flatten_final_dims(
+                torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
+            )
+
+            # [*, N_res, H * P_v, 3]
+            o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
+            o_pt = torch.unbind(o_pt, dim=-1)

         # [*, N_res, H, C_z]
         o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
@@ -617,7 +643,10 @@ class StructureModule(nn.Module):
             self.dropout_rate,
         )

-        self.bb_update = QuatRigid(self.c_s, full_quat=False)
+        if self.is_multimer:
+            self.bb_update = QuatRigid(self.c_s, full_quat=False)
+        else:
+            self.bb_update = BackboneUpdate(self.c_s)

         self.angle_resnet = AngleResnet(
             self.c_s,
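The IPA changes restore the original monomer point-attention path (plain tensors with an explicit xyz axis) alongside the multimer path (which iterates over Vec3Array components). Both compute the same squared inter-point distances; a hedged toy on plain tensors comparing the monomer formulation from the diff against a direct reference (toy sizes, no batch dims):

    import torch

    # q_pts / k_pts: [N_res, H, P_q, 3] query/key points in the global frame.
    q_pts = torch.randn(11, 4, 8, 3)
    k_pts = torch.randn(11, 4, 8, 3)

    # Monomer branch from the diff: broadcast to [N_res, N_res, H, P_q, 3],
    # square, then sum the unbound x/y/z components.
    pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
    pt_att = pt_att ** 2
    pt_att = sum(torch.unbind(pt_att, dim=-1))  # [N_res, N_res, H, P_q]

    # The same quantity, written directly as a squared Euclidean distance.
    ref = ((q_pts[:, None] - k_pts[None]) ** 2).sum(dim=-1)
    assert torch.allclose(pt_att, ref)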
openfold/model/template.py  (view file @ d8ee9c5f)

@@ -141,6 +141,7 @@ class TemplatePairStackBlock(nn.Module):
         no_heads: int,
         pair_transition_n: int,
         dropout_rate: float,
+        tri_mul_first: bool,
         inf: float,
         **kwargs,
     ):
@@ -153,6 +154,7 @@ class TemplatePairStackBlock(nn.Module):
         self.pair_transition_n = pair_transition_n
         self.dropout_rate = dropout_rate
         self.inf = inf
+        self.tri_mul_first = tri_mul_first

         self.dropout_row = DropoutRowwise(self.dropout_rate)
         self.dropout_col = DropoutColumnwise(self.dropout_rate)
@@ -184,6 +186,38 @@ class TemplatePairStackBlock(nn.Module):
             self.pair_transition_n,
         )

+    def tri_att_start_end(self, single, single_mask, chunk_size):
+        single = single + self.dropout_row(
+            self.tri_att_start(
+                single,
+                chunk_size=chunk_size,
+                mask=single_mask
+            )
+        )
+        single = single + self.dropout_col(
+            self.tri_att_end(
+                single,
+                chunk_size=chunk_size,
+                mask=single_mask
+            )
+        )
+        return single
+
+    def tri_mul_out_in(self, single, single_mask):
+        single = single + self.dropout_row(
+            self.tri_mul_out(
+                single,
+                mask=single_mask
+            )
+        )
+        single = single + self.dropout_row(
+            self.tri_mul_in(
+                single,
+                mask=single_mask
+            )
+        )
+        return single
+
     def forward(
         self,
         z: torch.Tensor,
         mask: torch.Tensor,
@@ -200,32 +234,17 @@ class TemplatePairStackBlock(nn.Module):
             single = single_templates[i]
             single_mask = single_templates_masks[i]

-            single = single + self.dropout_row(
-                self.tri_att_start(
-                    single,
-                    chunk_size=chunk_size,
-                    mask=single_mask
-                )
-            )
-            single = single + self.dropout_col(
-                self.tri_att_end(
-                    single,
-                    chunk_size=chunk_size,
-                    mask=single_mask
-                )
-            )
-            single = single + self.dropout_row(
-                self.tri_mul_out(single, mask=single_mask)
-            )
-            single = single + self.dropout_row(
-                self.tri_mul_in(single, mask=single_mask)
-            )
+            if self.tri_mul_first:
+                single = self.tri_att_start_end(
+                    single=self.tri_mul_out_in(
+                        single=single,
+                        single_mask=single_mask
+                    ),
+                    single_mask=single_mask,
+                    chunk_size=chunk_size
+                )
+            else:
+                single = self.tri_mul_out_in(
+                    single=self.tri_att_start_end(
+                        single=single,
+                        single_mask=single_mask,
+                        chunk_size=chunk_size
+                    ),
+                    single_mask=single_mask
+                )
+
             single = single + self.pair_transition(
                 single,
                 mask=single_mask if _mask_trans else None,
@@ -252,6 +271,7 @@ class TemplatePairStack(nn.Module):
         no_heads,
         pair_transition_n,
         dropout_rate,
+        tri_mul_first,
         blocks_per_ckpt,
         inf=1e9,
         **kwargs,
@@ -287,6 +307,7 @@ class TemplatePairStack(nn.Module):
                 no_heads=no_heads,
                 pair_transition_n=pair_transition_n,
                 dropout_rate=dropout_rate,
+                tri_mul_first=tri_mul_first,
                 inf=inf,
             )
             self.blocks.append(block)
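Factoring the two triangle-attention updates into tri_att_start_end and the two triangle-multiplication updates into tri_mul_out_in lets the block express both orderings as a function composition: multimer models (tri_mul_first=True) multiply first, monomer models attend first. A toy of the dispatch with integer stand-ins for the two update functions:

    # f stands in for tri_att_start_end, g for tri_mul_out_in.
    def tri_att_start_end(single):
        return single + 1  # stand-in for tri_att_start + tri_att_end

    def tri_mul_out_in(single):
        return single * 2  # stand-in for tri_mul_out + tri_mul_in

    def block_update(single, tri_mul_first: bool):
        if tri_mul_first:  # multimer: multiplication, then attention
            return tri_att_start_end(tri_mul_out_in(single))
        else:              # monomer: attention, then multiplication
            return tri_mul_out_in(tri_att_start_end(single))

    assert block_update(3, tri_mul_first=True) == 7   # (3 * 2) + 1
    assert block_update(3, tri_mul_first=False) == 8  # (3 + 1) * 2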
openfold/utils/feats.py  (view file @ d8ee9c5f)

@@ -18,7 +18,7 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-from typing import Dict
+from typing import Dict, Union

 from openfold.np import protein
 import openfold.np.residue_constants as rc
@@ -179,11 +179,11 @@ def build_extra_msa_feat(batch):
         batch["extra_has_deletion"].unsqueeze(-1),
         batch["extra_deletion_value"].unsqueeze(-1),
     ]
-    return msa_feat
+    return torch.cat(msa_feat, dim=-1)

 def torsion_angles_to_frames(
-    r: Rigid,
+    r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
     alpha: torch.Tensor,
     aatype: torch.Tensor,
     rrgdf: torch.Tensor,
@@ -220,8 +220,14 @@ def torsion_angles_to_frames(
     all_rots[..., 1, 2] = -alpha[..., 0]
     all_rots[..., 2, 1:] = alpha

-    all_rots = rotation_matrix.Rot3Array.from_array(all_rots)
-
-    all_frames = default_r.compose_rotation(all_rots)
+    if isinstance(r, Rigid):
+        rigid_type = Rigid
+        all_rots = Rigid(Rotation(rot_mats=all_rots), None)
+        all_frames = default_r.compose(all_rots)
+    else:
+        rigid_type = rigid_matrix_vector.Rigid3Array
+        all_rots = rotation_matrix.Rot3Array.from_array(all_rots)
+        all_frames = default_r.compose_rotation(all_rots)

     chi2_frame_to_frame = all_frames[..., 5]
     chi3_frame_to_frame = all_frames[..., 6]
@@ -232,7 +238,7 @@ def torsion_angles_to_frames(
     chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
     chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

-    all_frames_to_bb = rigid_matrix_vector.Rigid3Array.cat(
+    all_frames_to_bb = rigid_type.cat(
         [
             all_frames[..., :5],
             chi2_frame_to_bb.unsqueeze(-1),
@@ -248,7 +254,7 @@ def torsion_angles_to_frames(
 def frames_and_literature_positions_to_atom14_pos(
-    r: rigid_matrix_vector.Rigid3Array,
+    r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
     aatype: torch.Tensor,
     default_frames,
     group_idx,
@@ -277,6 +283,8 @@ def frames_and_literature_positions_to_atom14_pos(
     # [*, N, 14]
     atom_mask = atom_mask[aatype, ...]
+    if isinstance(r, Rigid):
+        atom_mask = atom_mask.unsqueeze(-1)

     # [*, N, 14, 3]
     lit_positions = lit_positions[aatype, ...]
openfold/utils/geometry/rigid_matrix_vector.py  (view file @ d8ee9c5f)

@@ -142,7 +142,7 @@ class Rigid3Array:
     def reshape(self, new_shape) -> Rigid3Array:
         rots = self.rotation.reshape(new_shape)
         trans = self.translation.reshape(new_shape)
-        return Rigid3Aray(rots, trans)
+        return Rigid3Array(rots, trans)

     def stop_rot_gradient(self) -> Rigid3Array:
         return Rigid3Array(
@@ -174,3 +174,6 @@ class Rigid3Array:
             array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
         )
         return cls(rotation, translation)
+
+    def cuda(self) -> Rigid3Array:
+        return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda())
openfold/utils/geometry/rotation_matrix.py  (view file @ d8ee9c5f)

@@ -15,6 +15,7 @@
 from __future__ import annotations
 import dataclasses
+from typing import List

 import torch
 import numpy as np
@@ -172,10 +173,10 @@ class Rot3Array:
         """Construct Rot3Array from components of quaternion."""
         if normalize:
             inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2)
-            w *= inv_norm
-            x *= inv_norm
-            y *= inv_norm
-            z *= inv_norm
+            w = w * inv_norm
+            x = x * inv_norm
+            y = y * inv_norm
+            z = z * inv_norm
         xx = 1 - 2 * (y**2 + z**2)
         xy = 2 * (x * y - w * z)
         xz = 2 * (x * z + w * y)
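Swapping `w *= inv_norm` for `w = w * inv_norm` avoids mutating the caller's quaternion components in place; PyTorch forbids in-place operations on leaf tensors that require grad, and even where allowed they can invalidate values saved for the backward pass. A small demonstration of the failure mode the out-of-place form avoids:

    import torch

    w = torch.tensor(2.0, requires_grad=True)
    inv_norm = torch.tensor(0.5)

    try:
        w *= inv_norm  # in-place update of a leaf tensor that requires grad
    except RuntimeError as err:
        print(err)     # "... a leaf Variable that requires grad is being used in an in-place operation"

    w2 = w * inv_norm  # out-of-place: allowed, and autograd tracks it normally
    w2.backward()
    print(w.grad)      # tensor(0.5000)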
openfold/utils/geometry/vector.py  (view file @ d8ee9c5f)

@@ -110,7 +110,7 @@ class Vec3Array:
         # To avoid NaN on the backward pass, we must use maximum before the sqrt
         norm2 = self.dot(self)
         if epsilon:
-            norm2 = torch.clamp(norm2, max=epsilon**2)
+            norm2 = torch.clamp(norm2, min=epsilon**2)
         return torch.sqrt(norm2)

     def norm2(self):
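A one-character fix with outsized effect. The comment above it already says the maximum must be taken before the sqrt: the squared norm has to be floored at epsilon**2 (`min=`) so sqrt never sees zero, where its gradient is infinite; the old `max=` capped every norm at epsilon instead. A quick check of both behaviours:

    import torch

    eps = 1e-6
    v = torch.zeros(3, requires_grad=True)

    # Fixed behaviour: floor the squared norm at eps**2 before the sqrt.
    norm2 = torch.clamp((v * v).sum(), min=eps ** 2)
    torch.sqrt(norm2).backward()
    print(v.grad)  # tensor([0., 0., 0.]): finite, no NaN

    # Old behaviour: max=eps**2 caps the squared norm instead, so a zero vector
    # still reaches sqrt(0) and its gradient becomes NaN.
    v2 = torch.zeros(3, requires_grad=True)
    norm2 = torch.clamp((v2 * v2).sum(), max=eps ** 2)
    torch.sqrt(norm2).backward()
    print(v2.grad)  # tensor([nan, nan, nan])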
openfold/utils/import_weights.py  (view file @ d8ee9c5f)

@@ -129,7 +129,7 @@ def assign(translation_dict, orig_weights):
             raise

-def get_translation_dict(model, is_multimer=False):
+def get_translation_dict(model, version, is_multimer=False):
     #######################
     # Some templates
     #######################
@@ -247,7 +247,7 @@ def get_translation_dict(model, is_multimer=False):
     )

     IPAParams = lambda ipa: {
-        "q_scalar_projection": LinearParams(ipa.linear_q),
+        "q_scalar": LinearParams(ipa.linear_q),
         "kv_scalar": LinearParams(ipa.linear_kv),
         "q_point_local": LinearParams(ipa.linear_q_points.linear),
         "kv_point_local": LinearParams(ipa.linear_kv_points.linear),
@@ -331,19 +331,19 @@ def get_translation_dict(model, is_multimer=False):
                 b.msa_att_row
             ),
             col_att_name: msa_col_att_params,
-            "msa_transition": MSATransitionParams(b.core.msa_transition),
+            "msa_transition": MSATransitionParams(b.msa_transition),
             "outer_product_mean":
-                OuterProductMeanParams(b.core.outer_product_mean),
+                OuterProductMeanParams(b.outer_product_mean),
             "triangle_multiplication_outgoing":
-                TriMulOutParams(b.core.tri_mul_out),
+                TriMulOutParams(b.pair_stack.tri_mul_out),
             "triangle_multiplication_incoming":
-                TriMulInParams(b.core.tri_mul_in),
+                TriMulInParams(b.pair_stack.tri_mul_in),
             "triangle_attention_starting_node":
-                TriAttParams(b.core.tri_att_start),
+                TriAttParams(b.pair_stack.tri_att_start),
             "triangle_attention_ending_node":
-                TriAttParams(b.core.tri_att_end),
+                TriAttParams(b.pair_stack.tri_att_end),
             "pair_transition":
-                PairTransitionParams(b.core.pair_transition),
+                PairTransitionParams(b.pair_stack.pair_transition),
         }
         return d
@@ -584,17 +584,6 @@ def get_translation_dict(model, is_multimer=False):
         },
     }

-    return translations
-
-def import_jax_weights_(model, npz_path, version="model_1"):
-    data = np.load(npz_path)
-
-    translations = get_translation_dict(
-        model, is_multimer=("multimer" in version)
-    )
-
     no_templ = [
         "model_3",
         "model_4",
@@ -615,6 +604,18 @@ def import_jax_weights_(model, npz_path, version="model_1"):
             "logits": LinearParams(model.aux_heads.tm.linear)
         }

+    return translations
+
+def import_jax_weights_(model, npz_path, version="model_1"):
+    data = np.load(npz_path)
+
+    translations = get_translation_dict(
+        model, version, is_multimer=("multimer" in version)
+    )
+
     # Flatten keys and insert missing key prefixes
     flat = _process_translations_dict(translations)
openfold/utils/loss.py  (view file @ d8ee9c5f)

@@ -636,9 +636,7 @@ def compute_tm(
     )
     bin_centers = _calculate_bin_centers(boundaries)

-    soft_n = torch.sum(residue_weights, dim=-1).to(torch.int32)
-    other = n.new_zeros() + 19
-    clipped_n = torch.max(soft_n, other, dim=-1)
+    clipped_n = max(torch.sum(residue_weights), 19)

     d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
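The deleted lines were broken (`n` was undefined, and torch.max with `dim=` returns a values/indices pair); the replacement applies the standard TM-score normalization d0(N) = 1.24 * (max(N, 19) - 15)^(1/3) - 1.8, where the floor of 19 residues keeps the cube-root argument, and hence d0, positive. A hedged numeric check:

    import torch

    def d0(residue_weights: torch.Tensor) -> float:
        clipped_n = max(torch.sum(residue_weights), 19)
        return float(1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8)

    print(d0(torch.ones(10)))   # ~0.168: chains under 19 residues clip to N = 19
    print(d0(torch.ones(100)))  # ~3.652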
openfold/utils/rigid_utils.py  (view file @ d8ee9c5f)

@@ -986,6 +986,16 @@ class Rigid:
         """
         return self._trans.device

+    @property
+    def dtype(self) -> torch.dtype:
+        """
+        Returns the dtype of the Rigid tensors.
+
+        Returns:
+            The dtype of the Rigid tensors
+        """
+        return self._rots.dtype
+
     def get_rots(self) -> Rotation:
         """
         Getter for the rotation.
tests/compare_utils.py  (view file @ d8ee9c5f)

@@ -46,26 +46,26 @@ def import_alphafold():
 def get_alphafold_config():
-    config = alphafold.model.config.model_config("model_1_ptm")  # noqa
+    config = alphafold.model.config.model_config(consts.model)  # noqa
     config.model.global_config.deterministic = True
     return config

-_param_path = "openfold/resources/params/params_model_1_ptm.npz"
+_param_path = f"openfold/resources/params/params_{consts.model}.npz"
 _model = None

 def get_global_pretrained_openfold():
     global _model
     if _model is None:
-        _model = AlphaFold(model_config("model_1_ptm"))
+        _model = AlphaFold(model_config(consts.model))
         _model = _model.eval()
         if not os.path.exists(_param_path):
             raise FileNotFoundError(
                 """Cannot load pretrained parameters. Make sure to run the
                 installation script before running tests."""
             )
-        import_jax_weights_(_model, _param_path, version="model_1_ptm")
+        import_jax_weights_(_model, _param_path, version=consts.model)
         _model = _model.cuda()

     return _model
tests/config.py  (view file @ d8ee9c5f)

@@ -2,6 +2,9 @@ import ml_collections as mlc
 consts = mlc.ConfigDict(
     {
+        "model": "model_1_multimer_v2",  # monomer: model_1_ptm, multimer: model_1_multimer_v2
+        "is_multimer": True,  # monomer: False, multimer: True
+        "chunk_size": 4,
         "batch_size": 2,
         "n_res": 11,
         "n_seq": 13,
@@ -15,6 +18,7 @@ consts = mlc.ConfigDict(
         "c_s": 384,
         "c_t": 64,
         "c_e": 64,
+        "msa_logits": 22  # monomer: 23, multimer: 22
     }
 )
tests/data_utils.py  (view file @ d8ee9c5f)

@@ -12,9 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from random import randint

 import numpy as np
 from scipy.spatial.transform import Rotation
+from tests.config import consts
+
+
+def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
+    n_chain = randint(1, n_res // min_chain_len) if consts.is_multimer else 1
+
+    if not split_chains:
+        return [0] * n_res
+
+    assert n_res >= n_chain
+
+    pieces = []
+    asym_ids = []
+    for idx in range(n_chain - 1):
+        piece = randint(
+            min_chain_len,
+            (n_res - sum(pieces) - n_chain + idx - min_chain_len)
+        )
+        pieces.append(piece)
+        asym_ids.extend(piece * [idx])
+
+    asym_ids.extend((n_res - sum(pieces)) * [n_chain - 1])
+
+    return np.array(asym_ids).astype(np.int64)

 def random_template_feats(n_templ, n, batch_size=None):
     b = []
@@ -39,6 +61,11 @@ def random_template_feats(n_templ, n, batch_size=None):
     }
     batch = {k: v.astype(np.float32) for k, v in batch.items()}
     batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
+
+    if consts.is_multimer:
+        asym_ids = np.array(random_asym_ids(n))
+        batch["asym_id"] = np.tile(asym_ids[np.newaxis, :], (*b, n_templ, 1))
+
     return batch
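random_asym_ids fabricates per-residue chain IDs for multimer test batches: the residues are split into a random number of chains, each intended to be at least min_chain_len residues long, with the final chain absorbing the remainder. A hedged usage sketch, assuming the repo layout of this commit and consts.is_multimer = True:

    import numpy as np
    from tests.data_utils import random_asym_ids  # as added in this commit

    ids = random_asym_ids(11, min_chain_len=4)  # e.g. array([0 0 0 0 0 1 1 1 1 1 1])
    assert ids.shape == (11,) and ids.dtype == np.int64
    assert ids[0] == 0 and np.all(np.diff(ids) >= 0)  # contiguous, ascending chain IDs
    print(np.unique(ids, return_counts=True))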
tests/test_data_pipeline.py  (view file @ d8ee9c5f)

@@ -15,19 +15,13 @@
 import pickle
 import shutil
-import torch
 import numpy as np
 import unittest

 from openfold.data.data_pipeline import DataPipeline
-from openfold.data.templates import TemplateHitFeaturizer
-from openfold.model.embedders import (
-    InputEmbedder,
-    RecyclingEmbedder,
-    TemplateAngleEmbedder,
-    TemplatePairEmbedder,
-)
+from openfold.data.templates import HhsearchHitFeaturizer, HmmsearchHitFeaturizer
 import tests.compare_utils as compare_utils
+from tests.config import consts

 if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
@@ -45,13 +39,29 @@ class TestDataPipeline(unittest.TestCase):
         with open("tests/test_data/alphafold_feature_dict.pickle", "rb") as fp:
             alphafold_feature_dict = pickle.load(fp)

-        template_featurizer = TemplateHitFeaturizer(
-            mmcif_dir="tests/test_data/mmcifs",
-            max_template_date="2021-12-20",
-            max_hits=20,
-            kalign_binary_path=shutil.which("kalign"),
-            _zero_center_positions=False,
-        )
+        if consts.is_multimer:
+            # template_featurizer = HmmsearchHitFeaturizer(
+            #     mmcif_dir="tests/test_data/mmcifs",
+            #     max_template_date="2021-12-20",
+            #     max_hits=20,
+            #     kalign_binary_path=shutil.which("kalign"),
+            #     _zero_center_positions=False,
+            # )
+            template_featurizer = HhsearchHitFeaturizer(
+                mmcif_dir="tests/test_data/mmcifs",
+                max_template_date="2021-12-20",
+                max_hits=20,
+                kalign_binary_path=shutil.which("kalign"),
+                _zero_center_positions=False,
+            )
+        else:
+            template_featurizer = HhsearchHitFeaturizer(
+                mmcif_dir="tests/test_data/mmcifs",
+                max_template_date="2021-12-20",
+                max_hits=20,
+                kalign_binary_path=shutil.which("kalign"),
+                _zero_center_positions=False,
+            )

         data_pipeline = DataPipeline(
             template_featurizer=template_featurizer,
tests/test_data_transforms.py  (view file @ d8ee9c5f)

-import copy
 import gzip
-import os
 import pickle
 import numpy as np
@@ -178,7 +174,7 @@ class TestDataTransforms(unittest.TestCase):
             protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}
             protein = make_hhblits_profile(protein)
             masked_msa_config = config.data.common.masked_msa
-            protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15)
+            protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15, seed=42)
             assert 'bert_mask' in protein
             assert 'true_msa' in protein
             assert 'msa' in protein