OpenDAS / OpenFold / Commits

Commit 07e64267, authored Oct 16, 2021 by Gustaf Ahdritz (parent: de07730f)

Standardize code style

Showing 20 changed files with 1350 additions and 1389 deletions (+1350, -1389)
openfold/utils/import_weights.py                  +105  -113
openfold/utils/loss.py                            +414  -471
openfold/utils/tensor_utils.py                     +55   -52
tests/compare_utils.py                             +21   -22
tests/config.py                                    +17   -15
tests/data_utils.py                                +23   -17
tests/test_embedders.py                            +12   -17
tests/test_evoformer.py                            +65   -47
tests/test_feats.py                                +77   -82
tests/test_import_weights.py                       +29   -17
tests/test_loss.py                                +255  -243
tests/test_model.py                                +23   -29
tests/test_msa.py                                  +60   -58
tests/test_outer_product_mean.py                   +25   -21
tests/test_pair_transition.py                      +23   -23
tests/test_structure_module.py                     +54   -66
tests/test_template.py                             +34   -32
tests/test_triangular_attention.py                 +19   -23
tests/test_triangular_multiplicative_update.py     +17   -18
tests/test_utils.py                                +22   -23
openfold/utils/import_weights.py (view file @ 07e64267)
@@ -26,28 +26,25 @@ _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"

# With Param, a poor man's enum with attributes (Rust-style)
class ParamType(Enum):
    LinearWeight = partial(  # hack: partial prevents fns from becoming methods
        lambda w: w.transpose(-1, -2)
    )
    LinearWeightMHA = partial(
        lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
    )
    LinearMHAOutputWeight = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
    LinearBiasMHA = partial(
        lambda w: w.reshape(*w.shape[:-2], -1)
    )
    LinearWeightOPM = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
    Other = partial(lambda w: w)

    def __init__(self, fn):
        self.transformation = fn


@dataclass
class Param:
    param: Union[torch.Tensor, List[torch.Tensor]]
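Aside: the `partial` wrapper above is what lets plain lambdas live inside an `Enum` body. A minimal standalone sketch of the same trick (names invented for illustration, not from this commit):

```python
from enum import Enum
from functools import partial

import torch

class Transform(Enum):
    # Without partial, a function assigned in an Enum body would be
    # treated as a method rather than as a member value.
    Transpose = partial(lambda w: w.transpose(-1, -2))
    Identity = partial(lambda w: w)

    def __init__(self, fn):
        # Enum.__init__ receives each member's value (the partial)
        self.transformation = fn

w = torch.randn(3, 5)
assert Transform.Transpose.transformation(w).shape == (5, 3)
```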
@@ -58,16 +55,17 @@ class Param:

def _process_translations_dict(d, top_layer=True):
    flat = {}
    for k, v in d.items():
        if type(v) == dict:
            prefix = _NPZ_KEY_PREFIX if top_layer else ""
            sub_flat = {
                (prefix + "/".join([k, k_prime])): v_prime
                for k_prime, v_prime in _process_translations_dict(
                    v, top_layer=False
                ).items()
            }
            flat.update(sub_flat)
        else:
            k = "/" + k if not top_layer else k
            flat[k] = v

    return flat
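For reference, a toy illustration of the flattening (the input dict below is invented, not from this commit): the AlphaFold prefix is added once at the top, inner levels are joined with "/", and leaf parameter names keep a leading "/" of their own, matching the double-slash key style of the AlphaFold parameter .npz files.

```python
# Hypothetical input, for illustration only
translations = {"evoformer": {"preprocess_1d": {"weights": 1}}}
flat = _process_translations_dict(translations)
# {'alphafold/alphafold_iteration/evoformer/preprocess_1d//weights': 1}
```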
@@ -75,29 +73,29 @@ def _process_translations_dict(d, top_layer=True):

def stacked(param_dict_list, out=None):
    """
    Args:
        param_dict_list:
            A list of (nested) Param dicts to stack. The structure of
            each dict must be the identical (down to the ParamTypes of
            "parallel" Params). There must be at least one dict
            in the list.
    """
    if out is None:
        out = {}
    template = param_dict_list[0]
    for k, _ in template.items():
        v = [d[k] for d in param_dict_list]
        if type(v[0]) is dict:
            out[k] = {}
            stacked(v, out=out[k])
        elif type(v[0]) is Param:
            stacked_param = Param(
                param=[param.param for param in v],
                param_type=v[0].param_type,
                stacked=True,
            )

            out[k] = stacked_param

    return out
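A small illustration of what `stacked` produces (inputs invented): two parallel Param dicts merge into one dict whose leaves hold lists of tensors with `stacked=True`, which `assign` later pairs element-wise against `torch.unbind(weights, 0)`.

```python
import torch

# Invented example: one Param dict per parallel block
block_params = [
    {"linear": {"w": Param(torch.zeros(4, 4), param_type=ParamType.LinearWeight)}},
    {"linear": {"w": Param(torch.ones(4, 4), param_type=ParamType.LinearWeight)}},
]
merged = stacked(block_params)
leaf = merged["linear"]["w"]
# leaf.param is a two-element list of tensors; leaf.stacked is True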
@@ -107,12 +105,12 @@ def assign(translation_dict, orig_weights):

        with torch.no_grad():
            weights = torch.as_tensor(orig_weights[k])
            ref, param_type = param.param, param.param_type
            if param.stacked:
                weights = torch.unbind(weights, 0)
            else:
                weights = [weights]
                ref = [ref]

            try:
                weights = list(map(param_type.transformation, weights))
                for p, w in zip(ref, weights):
@@ -121,36 +119,25 @@ def assign(translation_dict, orig_weights):

                print(k)
                print(ref[0].shape)
                print(weights[0].shape)
                raise


def import_jax_weights_(model, npz_path, version="model_1"):
    data = np.load(npz_path)

    #######################
    # Some templates
    #######################
    LinearWeight = lambda l: (
        Param(l, param_type=ParamType.LinearWeight)
    )
    LinearBias = lambda l: (Param(l))
    LinearWeightMHA = lambda l: (
        Param(l, param_type=ParamType.LinearWeightMHA)
    )
    LinearBiasMHA = lambda b: (
        Param(b, param_type=ParamType.LinearBiasMHA)
    )
    LinearWeightOPM = lambda l: (
        Param(l, param_type=ParamType.LinearWeightOPM)
    )

    LinearParams = lambda l: {
        "weights": LinearWeight(l.weight),
@@ -167,7 +154,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):

        "key_w": LinearWeightMHA(att.linear_k.weight),
        "value_w": LinearWeightMHA(att.linear_v.weight),
        "output_w": Param(
            att.linear_o.weight,
            param_type=ParamType.LinearMHAOutputWeight,
        ),
        "output_b": LinearBias(att.linear_o.bias),
    }
@@ -205,7 +193,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):

    # see commit b88f8da on the Alphafold repo
    # Alphafold swaps the pseudocode's a and b between the incoming/outcoming
    # iterations of triangle multiplication, which is confusing and not
    # reproduced in our implementation.
    TriMulInParams = lambda tri_mul: {
        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
@@ -231,7 +219,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):

    MSAGlobalAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": GlobalAttentionParams(matt.global_attention),
    }

    MSAAttPairBiasParams = lambda matt: dict(
@@ -247,8 +235,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):

        "kv_scalar": LinearParams(ipa.linear_kv),
        "q_point_local": LinearParams(ipa.linear_q_points),
        "kv_point_local": LinearParams(ipa.linear_kv_points),
        "trainable_point_weights": Param(
            param=ipa.head_weights, param_type=ParamType.Other
        ),
        "attention_2d": LinearParams(ipa.linear_b),
        "output_projection": LinearParams(ipa.linear_out),
    }
@@ -276,7 +265,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):

    }

    def EvoformerBlockParams(b, is_extra_msa=False):
        if is_extra_msa:
            col_att_name = "msa_column_global_attention"
            msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
        else:
@@ -284,8 +273,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):

            msa_col_att_params = MSAAttParams(b.msa_att_col)

        d = {
            "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
                b.msa_att_row
            ),
            col_att_name: msa_col_att_params,
            "msa_transition": MSATransitionParams(b.msa_transition),
            "outer_product_mean": OuterProductMeanParams(b.outer_product_mean),
@@ -316,10 +306,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):

            "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
            "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
            "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
        },
    }

    ############################
    # translations dict overflow
    ############################
@@ -330,14 +319,10 @@ def import_jax_weights_(model, npz_path, version="model_1"):

    )

    ems_blocks = model.extra_msa_stack.stack.blocks
    ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])

    evo_blocks = model.evoformer.blocks
    evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])

    translations = {
        "evoformer": {
@@ -346,101 +331,108 @@ def import_jax_weights_(model, npz_path, version="model_1"):

            "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
            "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
            "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
            "prev_msa_first_row_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_m
            ),
            "prev_pair_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_z
            ),
            "pair_activiations": LinearParams(
                model.input_embedder.linear_relpos
            ),
            "template_embedding": {
                "single_template_embedding": {
                    "embedding2d": LinearParams(
                        model.template_pair_embedder.linear
                    ),
                    "template_pair_stack": {
                        "__layer_stack_no_state": tps_blocks_params,
                    },
                    "output_layer_norm": LayerNormParams(
                        model.template_pair_stack.layer_norm
                    ),
                },
                "attention": AttentionParams(model.template_pointwise_att.mha),
            },
            "extra_msa_activations": LinearParams(
                model.extra_msa_embedder.linear
            ),
            "extra_msa_stack": ems_blocks_params,
            "template_single_embedding": LinearParams(
                model.template_angle_embedder.linear_1
            ),
            "template_projection": LinearParams(
                model.template_angle_embedder.linear_2
            ),
            "evoformer_iteration": evo_blocks_params,
            "single_activations": LinearParams(model.evoformer.linear),
        },
        "structure_module": {
            "single_layer_norm": LayerNormParams(
                model.structure_module.layer_norm_s
            ),
            "initial_projection": LinearParams(
                model.structure_module.linear_in
            ),
            "pair_layer_norm": LayerNormParams(
                model.structure_module.layer_norm_z
            ),
            "fold_iteration": FoldIterationParams(model.structure_module),
        },
        "predicted_lddt_head": {
            "input_layer_norm": LayerNormParams(
                model.aux_heads.plddt.layer_norm
            ),
            "act_0": LinearParams(model.aux_heads.plddt.linear_1),
            "act_1": LinearParams(model.aux_heads.plddt.linear_2),
            "logits": LinearParams(model.aux_heads.plddt.linear_3),
        },
        "distogram_head": {
            "half_logits": LinearParams(model.aux_heads.distogram.linear),
        },
        "experimentally_resolved_head": {
            "logits": LinearParams(
                model.aux_heads.experimentally_resolved.linear
            ),
        },
        "masked_msa_head": {
            "logits": LinearParams(model.aux_heads.masked_msa.linear),
        },
    }

    no_templ = [
        "model_3",
        "model_4",
        "model_5",
        "model_3_ptm",
        "model_4_ptm",
        "model_5_ptm",
    ]
    if version in no_templ:
        evo_dict = translations["evoformer"]
        keys = list(evo_dict.keys())
        for k in keys:
            if "template_" in k:
                evo_dict.pop(k)

    if "_ptm" in version:
        translations["predicted_aligned_error_head"] = {
            "logits": LinearParams(model.aux_heads.tm.linear)
        }

    # Flatten keys and insert missing key prefixes
    flat = _process_translations_dict(translations)

    # Sanity check
    keys = list(data.keys())
    flat_keys = list(flat.keys())
    incorrect = [k for k in flat_keys if k not in keys]
    missing = [k for k in keys if k not in flat_keys]
    # print(f"Incorrect: {incorrect}")
    # print(f"Missing: {missing}")
    assert len(incorrect) == 0
    # assert(sorted(list(flat.keys())) == sorted(list(data.keys())))

    # Set weights
    assign(flat, data)
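End to end, the function is meant to be called once on a freshly constructed model. A hypothetical usage sketch (the config helper, model class location, and the .npz filename are assumptions, not shown in this diff):

```python
from openfold.config import model_config          # assumed helper
from openfold.model.model import AlphaFold        # assumed location
from openfold.utils.import_weights import import_jax_weights_

config = model_config("model_1")
model = AlphaFold(config)

# Mutates the model's parameters in place, as the trailing underscore suggests
import_jax_weights_(model, "params_model_1.npz", version="model_1")
```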
openfold/utils/loss.py (view file @ 07e64267)
@@ -25,8 +25,8 @@ from openfold.np import residue_constants

from openfold.utils import feats
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
    masked_mean,
    permute_final_dims,
    batched_gather,
@@ -49,9 +49,9 @@ def sigmoid_cross_entropy(logits, labels):

def torsion_angle_loss(
    a,  # [*, N, 7, 2]
    a_gt,  # [*, N, 7, 2]
    a_alt_gt,  # [*, N, 7, 2]
):
    # [*, N, 7]
    norm = torch.norm(a, dim=-1)
@@ -81,7 +81,7 @@ def compute_fape(

    positions_mask: torch.Tensor,
    length_scale: float,
    l1_clamp_distance: Optional[float] = None,
    eps=1e-8,
) -> torch.Tensor:
    # [*, N_frames, N_pts, 3]
    local_pred_pos = pred_frames.invert()[..., None].apply(
@@ -91,10 +91,10 @@ def compute_fape(

        target_positions[..., None, :, :],
    )

    error_dist = torch.sqrt(
        torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
    )

    if l1_clamp_distance is not None:
        error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)

    normed_error = error_dist / length_scale
@@ -111,7 +111,9 @@ def compute_fape(

    #
    # ("roughly" because eps is necessarily duplicated in the latter
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = (
        normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
    )
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
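The two-stage division above is a masked mean computed in an overflow-friendly order; with everything unmasked it reduces to a plain mean, as this toy restatement (shapes and values invented) checks:

```python
import torch

x = torch.rand(2, 3, 4)            # [*, N_frames, N_pts] of scaled errors
frames_mask = torch.ones(2, 3)
positions_mask = torch.ones(2, 4)
eps = 1e-8

out = torch.sum(x, dim=-1)
out = out / (eps + torch.sum(frames_mask, dim=-1))[..., None]
out = torch.sum(out, dim=-1)
out = out / (eps + torch.sum(positions_mask, dim=-1))
# Up to eps, this is the plain mean over frames and points
assert torch.allclose(out, x.mean(dim=(-2, -1)), atol=1e-5)
```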
@@ -126,14 +128,14 @@ def backbone_loss(

    backbone_affine_mask: torch.Tensor,
    traj: torch.Tensor,
    use_clamped_fape: Optional[torch.Tensor] = None,
    clamp_distance: float = 10.0,
    loss_unit_distance: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    pred_aff = T.from_tensor(traj)
    gt_aff = T.from_tensor(backbone_affine_tensor)

    fape_loss = compute_fape(
        pred_aff,
        gt_aff[..., None, :],
@@ -145,7 +147,7 @@ def backbone_loss(

        length_scale=loss_unit_distance,
        eps=eps,
    )
    if use_clamped_fape is not None:
        unclamped_fape_loss = compute_fape(
            pred_aff,
            gt_aff[..., None, :],
@@ -158,9 +160,8 @@ def backbone_loss(

            eps=eps,
        )

        fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
            1 - use_clamped_fape
        )

    # Take the mean over the layer dimension
@@ -177,42 +178,31 @@ def sidechain_loss(

    renamed_atom14_gt_positions: torch.Tensor,
    renamed_atom14_gt_exists: torch.Tensor,
    alt_naming_is_better: torch.Tensor,
    clamp_distance: float = 10.0,
    length_scale: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    renamed_gt_frames = (
        1.0 - alt_naming_is_better[..., None, None, None]
    ) * rigidgroups_gt_frames + alt_naming_is_better[
        ..., None, None, None
    ] * rigidgroups_alt_gt_frames

    # Steamroll the inputs
    sidechain_frames = sidechain_frames[-1]
    batch_dims = sidechain_frames.shape[:-4]
    sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
    sidechain_frames = T.from_4x4(sidechain_frames)
    renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
    renamed_gt_frames = T.from_4x4(renamed_gt_frames)
    rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
    sidechain_atom_pos = sidechain_atom_pos[-1]
    sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
    renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
        *batch_dims, -1, 3
    )
    renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)

    fape = compute_fape(
        sidechain_frames,
@@ -235,19 +225,17 @@ def fape_loss(

    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    bb_loss = backbone_loss(
        traj=out["sm"]["frames"],
        **{**batch, **config.backbone},
    )

    sc_loss = sidechain_loss(
        out["sm"]["sidechain_frames"],
        out["sm"]["positions"],
        **{**batch, **config.sidechain},
    )

    return config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss


def supervised_chi_loss(
@@ -264,10 +252,11 @@ def supervised_chi_loss(

) -> torch.Tensor:
    pred_angles = angles_sin_cos[..., 3:, :]
    residue_type_one_hot = torch.nn.functional.one_hot(
        aatype,
        residue_constants.restype_num + 1,
    )
    chi_pi_periodic = torch.einsum(
        "...ij,jk->ik",
        residue_type_one_hot.type(angles_sin_cos.dtype),
        angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
    )
@@ -276,11 +265,9 @@ def supervised_chi_loss(

    shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
    true_chi_shifted = shifted_mask * true_chi
    sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
    sq_chi_error_shifted = torch.sum(
        (true_chi_shifted - pred_angles) ** 2, dim=-1
    )
    sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)

    # The ol' switcheroo
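The shifted copy exists because some chi angles are pi-periodic: chi and chi + pi describe the same physical conformation, and since sin and cos both flip sign under a pi shift, the shifted target is just the negated (sin, cos) pair. A toy illustration (values invented):

```python
import torch

true_chi = torch.tensor([0.8, -0.6])    # (sin chi, cos chi) of the ground truth
pred = torch.tensor([-0.79, 0.61])      # close to the chi + pi conformation
sq_err = torch.sum((true_chi - pred) ** 2)
sq_err_shifted = torch.sum((-true_chi - pred) ** 2)
# The loss keeps whichever of the two equivalent targets is nearer
loss = torch.minimum(sq_err, sq_err_shifted)
```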
@@ -295,14 +282,14 @@ def supervised_chi_loss(

    loss = loss + chi_weight * sq_chi_loss

    angle_norm = torch.sqrt(
        torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
    )
    norm_error = torch.abs(angle_norm - 1.0)
    norm_error = norm_error.permute(
        *range(len(norm_error.shape))[1:-2], 0, -2, -1
    )
    angle_norm_loss = masked_mean(
        seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
    )

    loss = loss + angle_norm_weight * angle_norm_loss
@@ -312,14 +299,13 @@ def supervised_chi_loss(

def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
    num_bins = logits.shape[-1]
    bin_width = 1.0 / num_bins
    bounds = torch.arange(
        start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device
    )
    probs = torch.nn.functional.softmax(logits, dim=-1)
    pred_lddt_ca = torch.sum(
        probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
        dim=-1,
    )
    return pred_lddt_ca * 100
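To see the expectation at work: with 50 bins the centers sit at 0.01, 0.03, ..., 0.99, so perfectly uniform logits decode to a pLDDT of about 50. A quick check (shapes illustrative):

```python
import torch

logits = torch.zeros(1, 10, 50)   # [*, N_res, no_bins]; uniform after softmax
plddt = compute_plddt(logits)     # every entry is ~50.0
```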
@@ -331,7 +317,7 @@ def lddt_loss(

    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    resolution: torch.Tensor,
    cutoff: float = 15.0,
    no_bins: int = 50,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
@@ -339,55 +325,57 @@ def lddt_loss(

    **kwargs,
) -> torch.Tensor:
    n = all_atom_mask.shape[-2]

    ca_pos = residue_constants.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim

    dmat_true = torch.sqrt(
        eps
        + torch.sum(
            (
                all_atom_positions[..., None, :]
                - all_atom_positions[..., None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    dmat_pred = torch.sqrt(
        eps
        + torch.sum(
            (
                all_atom_pred_pos[..., None, :]
                - all_atom_pred_pos[..., None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )
    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * permute_final_dims(all_atom_mask, (1, 0))
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    dist_l1 = torch.abs(dmat_true - dmat_pred)

    score = (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )
    score = score * 0.25

    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=-1))
    score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
    score = score.detach()

    bin_index = torch.floor(score * no_bins).long()
    bin_index = torch.clamp(bin_index, max=(no_bins - 1))
    lddt_ca_one_hot = torch.nn.functional.one_hot(
@@ -396,40 +384,39 @@ def lddt_loss(

    errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
    all_atom_mask = all_atom_mask.squeeze(-1)
    loss = torch.sum(errors * all_atom_mask, dim=-1) / (
        eps + torch.sum(all_atom_mask, dim=-1)
    )

    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    return loss


def distogram_loss(
    logits,
    pseudo_beta,
    pseudo_beta_mask,
    min_bin=2.3125,
    max_bin=21.6875,
    no_bins=64,
    eps=1e-6,
    **kwargs,
):
    boundaries = torch.linspace(
        min_bin,
        max_bin,
        no_bins - 1,
        device=logits.device,
    )
    boundaries = boundaries ** 2

    dists = torch.sum(
        (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
        dim=-1,
        keepdims=True,
    )

    true_bins = torch.sum(dists > boundaries, dim=-1)
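Note how the comparison avoids a square root: the bin boundaries are squared once, so each squared distance (kept with a trailing singleton dim) broadcasts against them, and the count of exceeded boundaries is the bin index. A toy version (values invented):

```python
import torch

boundaries = torch.linspace(2.3125, 21.6875, 63) ** 2
sq_dists = torch.tensor([[1.0], [100.0]])        # squared distances, kept dims
true_bins = torch.sum(sq_dists > boundaries, dim=-1)
# bin 0 for the very close pair, an intermediate bin for the distant one
```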
@@ -442,7 +429,7 @@ def distogram_loss(

    square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]

    # FP16-friendly sum. Equivalent to:
    # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
    #         (eps + torch.sum(square_mask, dim=(-1, -2))))
    denom = eps + torch.sum(square_mask, dim=(-1, -2))
    mean = errors * square_mask
@@ -450,7 +437,7 @@ def distogram_loss(

    mean = mean / denom[..., None]
    mean = torch.sum(mean, dim=-1)

    return mean


def _calculate_bin_centers(boundaries: torch.Tensor):
@@ -469,7 +456,7 @@ def _calculate_expected_aligned_error(

    bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
    return (
        torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
        bin_centers[-1],
    )
@@ -480,7 +467,7 @@ def compute_predicted_aligned_error(

    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.

    Args:
        logits: [*, num_res, num_res, num_bins] the logits output from
            PredictedAlignedErrorHead.
@@ -494,18 +481,16 @@ def compute_predicted_aligned_error(

        max_predicted_aligned_error: [*] the maximum predicted error possible.
    """
    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )
    aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
    (
        predicted_aligned_error,
        max_predicted_aligned_error,
    ) = _calculate_expected_aligned_error(
        alignment_confidence_breaks=boundaries,
        aligned_distance_error_probs=aligned_confidence_probs,
    )

    return {
@@ -523,14 +508,11 @@ def compute_tm(

    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    if residue_weights is None:
        residue_weights = logits.new_ones(logits.shape[-2])

    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    bin_centers = _calculate_bin_centers(boundaries)
@@ -538,11 +520,11 @@ def compute_tm(

    n = logits.shape[-2]
    clipped_n = max(n, 19)

    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8

    probs = torch.nn.functional.softmax(logits, dim=-1)

    tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
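The d0 term is the length-dependent normalization from the original TM-score (Zhang & Skolnick, 2004), d0(N) = 1.24 * (N - 15)^(1/3) - 1.8; clipping N at 19 keeps the cube-root argument, and hence d0, positive for short chains. A quick numeric check:

```python
# Shortest effective chain length after clipping: max(n, 19) = 19
d0 = 1.24 * (19 - 15) ** (1.0 / 3) - 1.8   # ~0.17, still > 0
```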
@@ -554,12 +536,12 @@ def compute_tm(

def tm_loss(
    logits,
    final_affine_tensor,
    backbone_affine_tensor,
    backbone_affine_mask,
    resolution,
    max_bin=31,
    no_bins=64,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
    eps=1e-8,
@@ -573,25 +555,18 @@ def tm_loss(

        return affine.invert()[..., None].apply(pts)

    sq_diff = torch.sum(
        (_points(pred_affine) - _points(backbone_affine)) ** 2, dim=-1
    )

    sq_diff = sq_diff.detach()

    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )
    boundaries = boundaries ** 2
    true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)

    errors = softmax_cross_entropy(
        logits, torch.nn.functional.one_hot(true_bins, no_bins)
    )

    square_mask = (
@@ -599,15 +574,14 @@ def tm_loss(
...
@@ -599,15 +574,14 @@ def tm_loss(
)
)
loss
=
torch
.
sum
(
errors
*
square_mask
,
dim
=-
1
)
loss
=
torch
.
sum
(
errors
*
square_mask
,
dim
=-
1
)
scale
=
0.5
# hack to help FP16 training along
scale
=
0.5
# hack to help FP16 training along
denom
=
eps
+
torch
.
sum
(
scale
*
square_mask
,
dim
=
(
-
1
,
-
2
))
denom
=
eps
+
torch
.
sum
(
scale
*
square_mask
,
dim
=
(
-
1
,
-
2
))
loss
=
loss
/
denom
[...,
None
]
loss
=
loss
/
denom
[...,
None
]
loss
=
torch
.
sum
(
loss
,
dim
=-
1
)
loss
=
torch
.
sum
(
loss
,
dim
=-
1
)
loss
=
loss
*
scale
loss
=
loss
*
scale
loss
=
loss
*
(
loss
=
loss
*
(
(
resolution
>=
min_resolution
)
&
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
(
resolution
<=
max_resolution
)
)
)
return
loss
return
loss
@@ -623,11 +597,11 @@ def between_residue_bond_loss(

    eps=1e-6,
) -> Dict[str, torch.Tensor]:
    """Flat-bottom loss to penalize structural violations between residues.

    This is a loss penalizing any violation of the geometry around the peptide
    bond between consecutive amino acids. This loss corresponds to
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.

    Args:
        pred_atom_positions: Atom positions in atom37/14 representation
        pred_atom_mask: Atom mask in atom37/14 representation
@@ -638,7 +612,7 @@ def between_residue_bond_loss(

            of pdb distributions
        tolerance_factor_hard: hard tolerance factor measured in standard deviations
            of pdb distributions

    Returns:
        Dict containing:
            * 'c_n_loss_mean': Loss for peptide bond length violations
@@ -659,126 +633,116 @@ def between_residue_bond_loss(

    next_n_mask = pred_atom_mask[..., 1:, 0]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]
    has_no_gap_mask = (
        residue_index[..., 1:] - residue_index[..., :-1]
    ) == 1.0

    # Compute loss for the C--N bond.
    c_n_bond_length = torch.sqrt(
        eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
    )

    # The C-N bond to proline has slightly different length because of the ring.
    next_is_proline = (
        aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
    )
    gt_length = (
        ~next_is_proline
    ) * residue_constants.between_res_bond_length_c_n[
        0
    ] + next_is_proline * residue_constants.between_res_bond_length_c_n[
        1
    ]
    gt_stddev = (
        ~next_is_proline
    ) * residue_constants.between_res_bond_length_stddev_c_n[
        0
    ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
        1
    ]
    c_n_bond_length_error = torch.sqrt(
        eps + (c_n_bond_length - gt_length) ** 2
    )
    c_n_loss_per_residue = torch.nn.functional.relu(
        c_n_bond_length_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * has_no_gap_mask
    c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_violation_mask = mask * (
        c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
    )

    # Compute loss for the angles.
    ca_c_bond_length = torch.sqrt(
        eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
    )
    n_ca_bond_length = torch.sqrt(
        eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
    )

    c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
    c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
    n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]

    ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
    gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
    ca_c_n_cos_angle_error = torch.sqrt(
        eps + (ca_c_n_cos_angle - gt_angle) ** 2
    )
    ca_c_n_loss_per_residue = torch.nn.functional.relu(
        ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
    ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    ca_c_n_violation_mask = mask * (
        ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
    gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
    c_n_ca_cos_angle_error = torch.sqrt(
        eps + torch.square(c_n_ca_cos_angle - gt_angle)
    )
    c_n_ca_loss_per_residue = torch.nn.functional.relu(
        c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
    c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_ca_violation_mask = mask * (
        c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    # Compute a per residue loss (equally distribute the loss to both
    # neighbouring residues).
    per_residue_loss_sum = (
        c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
    )
    per_residue_loss_sum = 0.5 * (
        torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
        + torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
    )

    # Compute hard violations.
    violation_mask = torch.max(
        torch.stack(
            [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
            dim=-2,
        ),
        dim=-2,
    )[0]
    violation_mask = torch.maximum(
        torch.nn.functional.pad(violation_mask, (0, 1)),
        torch.nn.functional.pad(violation_mask, (1, 0)),
    )

    return {
        "c_n_loss_mean": c_n_loss,
        "ca_c_n_loss_mean": ca_c_n_loss,
        "c_n_ca_loss_mean": c_n_ca_loss,
        "per_residue_loss_sum": per_residue_loss_sum,
        "per_residue_violation_mask": violation_mask,
    }
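All three penalties above share the same flat-bottom shape: zero inside a tolerance band of a few standard deviations around the ideal value, growing linearly outside it. A compact restatement with invented numbers:

```python
import torch

error = torch.tensor([0.000, 0.010, 0.050])   # |measured - ideal| per residue
gt_stddev, tolerance_factor = 0.012, 2.0      # invented band parameters
penalty = torch.nn.functional.relu(error - tolerance_factor * gt_stddev)
# tensor([0.0000, 0.0000, 0.0260]): flat inside the band, linear outside
```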
@@ -792,12 +756,12 @@ def between_residue_clash_loss(

    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """Loss to penalize steric clashes between residues.

    This is a loss penalizing any steric clashes due to non bonded atoms in
    different peptides coming too close. This loss corresponds to the part with
    different residues of
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

    Args:
        atom14_pred_positions: Predicted positions of atoms in
            global prediction frame
@@ -807,7 +771,7 @@ def between_residue_clash_loss(

        residue_index: Residue index for given amino acid.
        overlap_tolerance_soft: Soft tolerance factor.
        overlap_tolerance_hard: Hard tolerance factor.

    Returns:
        Dict containing:
            * 'mean_loss': average clash loss
@@ -816,33 +780,36 @@ def between_residue_clash_loss(

            shape (N, 14)
    """
    fp_type = atom14_pred_positions.dtype

    # Create the distance matrix.
    # (N, N, 14, 14)
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, None, :, None, :]
                - atom14_pred_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Create the mask for valid distances.
    # shape (N, N, 14, 14)
    dists_mask = (
        atom14_atom_exists[..., :, None, :, None]
        * atom14_atom_exists[..., None, :, None, :]
    ).type(fp_type)

    # Mask out all the duplicate entries in the lower triangular matrix.
    # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
    # are handled separately.
    dists_mask = dists_mask * (
        residue_index[..., :, None, None, None]
        < residue_index[..., None, :, None, None]
    )

    # Backbone C--N bond between subsequent residues is no clash.
    c_one_hot = torch.nn.functional.one_hot(
        residue_index.new_tensor(2), num_classes=14
...
@@ -860,74 +827,69 @@ def between_residue_clash_loss(
...
@@ -860,74 +827,69 @@ def between_residue_clash_loss(
n_one_hot
=
n_one_hot
.
type
(
fp_type
)
n_one_hot
=
n_one_hot
.
type
(
fp_type
)
neighbour_mask
=
(
neighbour_mask
=
(
(
residue_index
[...,
:,
None
,
None
,
None
]
+
1
)
==
residue_index
[...,
:,
None
,
None
,
None
]
+
1
residue_index
[...,
None
,
:,
None
,
None
]
)
==
residue_index
[...,
None
,
:,
None
,
None
]
)
c_n_bonds
=
(
c_n_bonds
=
(
neighbour_mask
*
neighbour_mask
c_one_hot
[...,
None
,
None
,
:,
None
]
*
*
c_one_hot
[...,
None
,
None
,
:,
None
]
n_one_hot
[...,
None
,
None
,
None
,
:]
*
n_one_hot
[...,
None
,
None
,
None
,
:]
)
)
dists_mask
=
dists_mask
*
(
1.
-
c_n_bonds
)
dists_mask
=
dists_mask
*
(
1.
0
-
c_n_bonds
)
# Disulfide bridge between two cysteines is no clash.
# Disulfide bridge between two cysteines is no clash.
cys
=
residue_constants
.
restype_name_to_atom14_names
[
"CYS"
]
cys
=
residue_constants
.
restype_name_to_atom14_names
[
"CYS"
]
cys_sg_idx
=
cys
.
index
(
'
SG
'
)
cys_sg_idx
=
cys
.
index
(
"
SG
"
)
cys_sg_idx
=
residue_index
.
new_tensor
(
cys_sg_idx
)
cys_sg_idx
=
residue_index
.
new_tensor
(
cys_sg_idx
)
cys_sg_idx
=
cys_sg_idx
.
reshape
(
cys_sg_idx
=
cys_sg_idx
.
reshape
(
*
((
1
,)
*
len
(
residue_index
.
shape
[:
-
1
])),
1
*
((
1
,)
*
len
(
residue_index
.
shape
[:
-
1
])),
1
).
squeeze
(
-
1
)
).
squeeze
(
-
1
)
cys_sg_one_hot
=
torch
.
nn
.
functional
.
one_hot
(
cys_sg_one_hot
=
torch
.
nn
.
functional
.
one_hot
(
cys_sg_idx
,
num_classes
=
14
)
cys_sg_idx
,
num_classes
=
14
)
disulfide_bonds
=
(
disulfide_bonds
=
(
cys_sg_one_hot
[...,
None
,
None
,
:,
None
]
*
cys_sg_one_hot
[...,
None
,
None
,
:,
None
]
cys_sg_one_hot
[...,
None
,
None
,
None
,
:])
*
cys_sg_one_hot
[...,
None
,
None
,
None
,
:]
dists_mask
=
dists_mask
*
(
1.
-
disulfide_bonds
)
)
dists_mask
=
dists_mask
*
(
1.0
-
disulfide_bonds
)
# Compute the lower bound for the allowed distances.
# Compute the lower bound for the allowed distances.
# shape (N, N, 14, 14)
# shape (N, N, 14, 14)
dists_lower_bound
=
dists_mask
*
(
dists_lower_bound
=
dists_mask
*
(
atom14_atom_radius
[...,
:,
None
,
:,
None
]
+
atom14_atom_radius
[...,
:,
None
,
:,
None
]
atom14_atom_radius
[...,
None
,
:,
None
,
:]
+
atom14_atom_radius
[...,
None
,
:,
None
,
:]
)
)
# Compute the error.
# Compute the error.
# shape (N, N, 14, 14)
# shape (N, N, 14, 14)
dists_to_low_error
=
dists_mask
*
torch
.
nn
.
functional
.
relu
(
dists_to_low_error
=
dists_mask
*
torch
.
nn
.
functional
.
relu
(
dists_lower_bound
-
overlap_tolerance_soft
-
dists
dists_lower_bound
-
overlap_tolerance_soft
-
dists
)
)
# Compute the mean loss.
# Compute the mean loss.
# shape ()
# shape ()
mean_loss
=
(
mean_loss
=
torch
.
sum
(
dists_to_low_error
)
/
(
1e-6
+
torch
.
sum
(
dists_mask
))
torch
.
sum
(
dists_to_low_error
)
/
(
1e-6
+
torch
.
sum
(
dists_mask
))
)
# Compute the per atom loss sum.
# Compute the per atom loss sum.
# shape (N, 14)
# shape (N, 14)
per_atom_loss_sum
=
(
per_atom_loss_sum
=
torch
.
sum
(
dists_to_low_error
,
dim
=
(
-
4
,
-
2
))
+
torch
.
sum
(
torch
.
sum
(
dists_to_low_error
,
dim
=
(
-
4
,
-
2
))
+
dists_to_low_error
,
axis
=
(
-
3
,
-
1
)
torch
.
sum
(
dists_to_low_error
,
axis
=
(
-
3
,
-
1
))
)
)
# Compute the hard clash mask.
# Compute the hard clash mask.
# shape (N, N, 14, 14)
# shape (N, N, 14, 14)
clash_mask
=
dists_mask
*
(
clash_mask
=
dists_mask
*
(
dists
<
(
dists_lower_bound
-
overlap_tolerance_hard
)
dists
<
(
dists_lower_bound
-
overlap_tolerance_hard
)
)
)
# Compute the per atom clash.
# Compute the per atom clash.
# shape (N, 14)
# shape (N, 14)
per_atom_clash_mask
=
torch
.
maximum
(
per_atom_clash_mask
=
torch
.
maximum
(
torch
.
amax
(
clash_mask
,
axis
=
(
-
4
,
-
2
)),
torch
.
amax
(
clash_mask
,
axis
=
(
-
4
,
-
2
)),
torch
.
amax
(
clash_mask
,
axis
=
(
-
3
,
-
1
)),
torch
.
amax
(
clash_mask
,
axis
=
(
-
3
,
-
1
)),
)
)
return
{
return
{
'
mean_loss
'
:
mean_loss
,
# shape ()
"
mean_loss
"
:
mean_loss
,
# shape ()
'
per_atom_loss_sum
'
:
per_atom_loss_sum
,
# shape (N, 14)
"
per_atom_loss_sum
"
:
per_atom_loss_sum
,
# shape (N, 14)
'
per_atom_clash_mask
'
:
per_atom_clash_mask
# shape (N, 14)
"
per_atom_clash_mask
"
:
per_atom_clash_mask
,
# shape (N, 14)
}
}
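For the expected shapes, a minimal sketch (not part of the commit) that calls the loss on random atom14 data; the radii and tolerances below are illustrative stand-ins, not the project's real per-atom values:

    import torch
    from openfold.utils.loss import between_residue_clash_loss

    N = 8  # residues
    atom14_pred_positions = torch.rand(N, 14, 3) * 10.0  # global-frame coords
    atom14_atom_exists = torch.ones(N, 14)                # pretend all atoms exist
    atom14_atom_radius = torch.full((N, 14), 1.5)         # placeholder vdW radii
    residue_index = torch.arange(N)                       # int64, as one_hot requires

    out = between_residue_clash_loss(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=atom14_atom_exists,
        atom14_atom_radius=atom14_atom_radius,
        residue_index=residue_index,
        overlap_tolerance_soft=1.5,
        overlap_tolerance_hard=1.5,
    )
    print(out["mean_loss"].shape)            # ()
    print(out["per_atom_loss_sum"].shape)    # (N, 14)
    print(out["per_atom_clash_mask"].shape)  # (N, 14)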
...
@@ -940,54 +902,53 @@ def within_residue_violations(
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """Loss to penalize steric clashes within residues.

    This is a loss penalizing any steric violations or clashes of non-bonded atoms
    in a given peptide. This loss corresponds to the part with
    the same residues of
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

    Args:
        atom14_pred_positions ([*, N, 14, 3]):
            Predicted positions of atoms in global prediction frame.
        atom14_atom_exists ([*, N, 14]):
            Mask denoting whether atom at positions exists for given
            amino acid type
        atom14_dists_lower_bound ([*, N, 14]):
            Lower bound on allowed distances.
        atom14_dists_upper_bound ([*, N, 14]):
            Upper bound on allowed distances
        tighten_bounds_for_loss ([*, N]):
            Extra factor to tighten loss

    Returns:
        Dict containing:
            * 'per_atom_loss_sum' ([*, N, 14]):
                sum of all clash losses per atom, shape
            * 'per_atom_clash_mask' ([*, N, 14]):
                mask whether atom clashes with any other atom shape
    """
    # Compute the mask for each residue.
    dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
    dists_masks = dists_masks.reshape(
        *((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
    )
    dists_masks = (
        atom14_atom_exists[..., :, :, None]
        * atom14_atom_exists[..., :, None, :]
        * dists_masks
    )

    # Distance matrix
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, :, None, :]
                - atom14_pred_positions[..., :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )
...
@@ -999,34 +960,26 @@ def within_residue_violations(
        dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
    )
    loss = dists_masks * (dists_to_low_error + dists_to_high_error)

    # Compute the per atom loss sum.
    per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)

    # Compute the violations mask.
    violations = dists_masks * (
        (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
    )

    # Compute the per atom violations.
    per_atom_violations = torch.maximum(
        torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]
    )

    return {
        "per_atom_loss_sum": per_atom_loss_sum,
        "per_atom_violations": per_atom_violations,
    }
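A companion sketch for the within-residue variant; note that in practice the bounds are pairwise per residue (shape [*, N, 14, 14], as produced by residue_constants.make_atom14_dists_bounds and indexed by aatype further below), so the flat values here are illustrative only:

    import torch
    from openfold.utils.loss import within_residue_violations

    N = 8
    atom14_pred_positions = torch.rand(N, 14, 3) * 10.0
    atom14_atom_exists = torch.ones(N, 14)
    atom14_dists_lower_bound = torch.full((N, 14, 14), 1.0)   # illustrative
    atom14_dists_upper_bound = torch.full((N, 14, 14), 10.0)  # illustrative

    out = within_residue_violations(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=atom14_atom_exists,
        atom14_dists_lower_bound=atom14_dists_lower_bound,
        atom14_dists_upper_bound=atom14_dists_upper_bound,
        tighten_bounds_for_loss=0.0,
    )
    print(out["per_atom_loss_sum"].shape)    # (N, 14)
    print(out["per_atom_violations"].shape)  # (N, 14)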
def find_structural_violations(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
...
@@ -1035,7 +988,7 @@ def find_structural_violations(
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes several checks for structural violations."""
    # Compute between residue backbone violations of bonds and angles.
    connection_violations = between_residue_bond_loss(
        pred_atom_positions=atom14_pred_positions,
...
@@ -1043,9 +996,9 @@ def find_structural_violations(
        residue_index=batch["residue_index"],
        aatype=batch["aatype"],
        tolerance_factor_soft=violation_tolerance_factor,
        tolerance_factor_hard=violation_tolerance_factor,
    )

    # Compute the Van der Waals radius for every atom
    # (the first letter of the atom name is the element type).
    # Shape: (N, 14).
...
@@ -1053,14 +1006,12 @@ def find_structural_violations(
        residue_constants.van_der_waals_radius[name[0]]
        for name in residue_constants.atom_types
    ]
    atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
    atom14_atom_radius = (
        batch["atom14_atom_exists"]
        * atomtype_radius[batch["residx_atom14_to_atom37"]]
    )

    # Compute the between residue clash loss.
    between_residue_clashes = between_residue_clash_loss(
        atom14_pred_positions=atom14_pred_positions,
...
@@ -1068,32 +1019,28 @@ def find_structural_violations(
        atom14_atom_radius=atom14_atom_radius,
        residue_index=batch["residue_index"],
        overlap_tolerance_soft=clash_overlap_tolerance,
        overlap_tolerance_hard=clash_overlap_tolerance,
    )

    # Compute all within-residue violations (clashes,
    # bond length and angle violations).
    restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
        overlap_tolerance=clash_overlap_tolerance,
        bond_length_tolerance_factor=violation_tolerance_factor,
    )
    atom14_atom_exists = batch["atom14_atom_exists"]
    atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
        restype_atom14_bounds["lower_bound"]
    )[batch["aatype"]]
    atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
        restype_atom14_bounds["upper_bound"]
    )[batch["aatype"]]
    residue_violations = within_residue_violations(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=batch["atom14_atom_exists"],
        atom14_dists_lower_bound=atom14_dists_lower_bound,
        atom14_dists_upper_bound=atom14_dists_upper_bound,
        tighten_bounds_for_loss=0.0,
    )

    # Combine them to a single per-residue violation mask (used later for LDDT).
...
@@ -1104,49 +1051,52 @@ def find_structural_violations(
                torch.max(
                    between_residue_clashes["per_atom_clash_mask"], dim=-1
                )[0],
                torch.max(
                    residue_violations["per_atom_violations"], dim=-1
                )[0],
            ],
            dim=-1,
        ),
        dim=-1,
    )[0]

    return {
        "between_residues": {
            "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"],  # ()
            "angles_ca_c_n_loss_mean": connection_violations[
                "ca_c_n_loss_mean"
            ],  # ()
            "angles_c_n_ca_loss_mean": connection_violations[
                "c_n_ca_loss_mean"
            ],  # ()
            "connections_per_residue_loss_sum": connection_violations[
                "per_residue_loss_sum"
            ],  # (N)
            "connections_per_residue_violation_mask": connection_violations[
                "per_residue_violation_mask"
            ],  # (N)
            "clashes_mean_loss": between_residue_clashes["mean_loss"],  # ()
            "clashes_per_atom_loss_sum": between_residue_clashes[
                "per_atom_loss_sum"
            ],  # (N, 14)
            "clashes_per_atom_clash_mask": between_residue_clashes[
                "per_atom_clash_mask"
            ],  # (N, 14)
        },
        "within_residues": {
            "per_atom_loss_sum": residue_violations[
                "per_atom_loss_sum"
            ],  # (N, 14)
            "per_atom_violations": residue_violations[
                "per_atom_violations"
            ],  # (N, 14),
        },
        "total_per_residue_violations_mask": per_residue_violations_mask,  # (N)
    }
def find_structural_violations_np(
    batch: Dict[str, np.ndarray],
    atom14_pred_positions: np.ndarray,
    config: ml_collections.ConfigDict,
) -> Dict[str, np.ndarray]:
    to_tensor = lambda x: torch.tensor(x)
    batch = tree_map(to_tensor, batch, np.ndarray)
...
@@ -1161,17 +1111,17 @@ def find_structural_violations_np(

def extreme_ca_ca_distance_violations(
    pred_atom_positions: torch.Tensor,  # (N, 37(14), 3)
    pred_atom_mask: torch.Tensor,  # (N, 37(14))
    residue_index: torch.Tensor,  # (N)
    max_angstrom_tolerance=1.5,
    eps=1e-6,
) -> torch.Tensor:
    """Counts residues whose Ca is a large distance from its neighbour.

    Measures the fraction of CA-CA pairs between consecutive amino acids that are
    more than 'max_angstrom_tolerance' apart.

    Args:
        pred_atom_positions: Atom positions in atom37/14 representation
        pred_atom_mask: Atom mask in atom37/14 representation
...
@@ -1185,13 +1135,13 @@ def extreme_ca_ca_distance_violations(
    this_ca_mask = pred_atom_mask[..., :-1, 1]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]
    has_no_gap_mask = (
        residue_index[..., 1:] - residue_index[..., :-1]
    ) == 1.0
    ca_ca_distance = torch.sqrt(
        eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)
    )
    violations = (
        ca_ca_distance - residue_constants.ca_ca
    ) > max_angstrom_tolerance
    mask = this_ca_mask * next_ca_mask * has_no_gap_mask
    mean = masked_mean(mask, violations, -1)
    return mean
...
@@ -1202,18 +1152,18 @@ def compute_violation_metrics(
    atom14_pred_positions: torch.Tensor,  # (N, 14, 3)
    violations: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """Compute several metrics to assess the structural violations."""
    ret = {}
    extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
        pred_atom_positions=atom14_pred_positions,
        pred_atom_mask=batch["atom14_atom_exists"],
        residue_index=batch["residue_index"],
    )
    ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
    ret["violations_between_residue_bond"] = masked_mean(
        batch["seq_mask"],
        violations["between_residues"][
            "connections_per_residue_violation_mask"
        ],
        dim=-1,
    )
...
@@ -1221,7 +1171,7 @@ def compute_violation_metrics(
        mask=batch["seq_mask"],
        value=torch.max(
            violations["between_residues"]["clashes_per_atom_clash_mask"],
            dim=-1,
        )[0],
        dim=-1,
    )
...
@@ -1250,7 +1200,6 @@ def compute_violation_metrics_np(
    atom14_pred_positions = to_tensor(atom14_pred_positions)
    violations = tree_map(to_tensor, violations, np.ndarray)

    out = compute_violation_metrics(batch, atom14_pred_positions, violations)

    to_np = lambda x: np.array(x)
...
@@ -1265,15 +1214,15 @@ def violation_loss(
) -> torch.Tensor:
    num_atoms = torch.sum(atom14_atom_exists)
    l_clash = torch.sum(
        violations["between_residues"]["clashes_per_atom_loss_sum"]
        + violations["within_residues"]["per_atom_loss_sum"]
    )
    l_clash = l_clash / (eps + num_atoms)
    loss = (
        violations["between_residues"]["bonds_c_n_loss_mean"]
        + violations["between_residues"]["angles_ca_c_n_loss_mean"]
        + violations["between_residues"]["angles_c_n_ca_loss_mean"]
        + l_clash
    )

    return loss
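Putting the pieces together, a minimal end-to-end sketch of the violation term, assuming the signatures above and that these four batch keys are all find_structural_violations reads; the all-alanine batch, the zero atom37 mapping, and the tolerance values are stand-in assumptions:

    import torch
    from openfold.utils.loss import find_structural_violations, violation_loss

    N = 8
    batch = {
        "aatype": torch.zeros(N, dtype=torch.long),           # all-ALA chain
        "residue_index": torch.arange(N),
        "atom14_atom_exists": torch.ones(N, 14),
        "residx_atom14_to_atom37": torch.zeros(N, 14, dtype=torch.long),
    }
    atom14_pred_positions = torch.rand(N, 14, 3) * 10.0

    violations = find_structural_violations(
        batch,
        atom14_pred_positions,
        violation_tolerance_factor=12.0,  # assumed default
        clash_overlap_tolerance=1.5,      # assumed default
    )
    # Scalar: C-N bond term + two angle terms + normalized clash term.
    loss = violation_loss(violations, **batch)
    print(loss)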
...
@@ -1286,12 +1235,12 @@ def compute_renamed_ground_truth(
) -> Dict[str, torch.Tensor]:
    """
    Find optimal renaming of ground truth based on the predicted positions.

    Alg. 26 "renameSymmetricGroundTruthAtoms"

    This renamed ground truth is then used for all losses,
    such that each loss moves the atoms in the same direction.

    Args:
        batch: Dictionary containing:
            * atom14_gt_positions: Ground truth positions.
...
@@ -1313,50 +1262,53 @@ def compute_renamed_ground_truth(
    """
    pred_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., None, :, None, :]
                - atom14_pred_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    atom14_gt_positions = batch["atom14_gt_positions"]
    gt_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_gt_positions[..., None, :, None, :]
                - atom14_gt_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
    alt_gt_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_alt_gt_positions[..., None, :, None, :]
                - atom14_alt_gt_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
    alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)

    atom14_gt_exists = batch["atom14_gt_exists"]
    atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
    mask = (
        atom14_gt_exists[..., None, :, None]
        * atom14_atom_is_ambiguous[..., None, :, None]
        * atom14_gt_exists[..., None, :, None, :]
        * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
    )

    per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
...
@@ -1366,16 +1318,16 @@ def compute_renamed_ground_truth(
    alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)

    renamed_atom14_gt_positions = (
        1.0 - alt_naming_is_better[..., None, None]
    ) * atom14_gt_positions + alt_naming_is_better[
        ..., None, None
    ] * atom14_alt_gt_positions

    renamed_atom14_gt_mask = (
        1.0 - alt_naming_is_better[..., None]
    ) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
        "atom14_alt_gt_exists"
    ]

    return {
        "alt_naming_is_better": alt_naming_is_better,
...
@@ -1398,10 +1350,9 @@ def experimentally_resolved_loss(
    loss = torch.sum(errors * atom37_atom_exists, dim=-1)
    loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
    loss = torch.sum(loss, dim=-1)

    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    return loss
...
@@ -1409,10 +1360,9 @@ def experimentally_resolved_loss(
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
    errors = softmax_cross_entropy(
        logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
    )

    # FP16-friendly averaging. Equivalent to:
    # loss = (
    #     torch.sum(errors * bert_mask, dim=(-1, -2)) /
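The comment above is cut off by the hunk; it spells out the plain masked mean that the FP16-friendly version is equivalent to. A minimal sketch of the idea, assuming nothing about the exact continuation: reducing one dimension at a time keeps partial sums small enough for half-precision accumulation while computing the same quantity.

    import torch

    errors = torch.rand(13, 11)                       # [N_seq, N_res] CE values
    bert_mask = torch.randint(0, 2, (13, 11)).float()
    eps = 1e-8

    # One-shot masked mean (the form the truncated comment refers to).
    loss_naive = torch.sum(errors * bert_mask, dim=(-1, -2)) / (
        eps + torch.sum(bert_mask, dim=(-1, -2))
    )

    # Staged reduction: sum rows first, normalize, then sum the rest.
    per_row = torch.sum(errors * bert_mask, dim=-1)
    denom = eps + torch.sum(bert_mask, dim=(-1, -2))
    loss_staged = torch.sum(per_row / denom, dim=-1)

    print(torch.allclose(loss_naive, loss_staged))  # True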
...
@@ -1435,7 +1385,7 @@ def compute_drmsd(structure_1, structure_2):
    d1 = d1 ** 2
    d2 = d2 ** 2

    d1 = torch.sqrt(torch.sum(d1, dim=-1))
    d2 = torch.sqrt(torch.sum(d2, dim=-1))
...
@@ -1450,81 +1400,74 @@ def compute_drmsd(structure_1, structure_2):

class AlphaFoldLoss(nn.Module):
    """Aggregation of the various losses described in the supplement"""

    def __init__(self, config):
        super(AlphaFoldLoss, self).__init__()
        self.config = config

    def forward(self, out, batch):
        if "violation" not in out.keys() and self.config.violation.weight:
            out["violation"] = find_structural_violations(
                batch,
                out["sm"]["positions"][-1],
                **self.config.violation,
            )

        if "renamed_atom14_gt_positions" not in out.keys():
            batch.update(
                compute_renamed_ground_truth(
                    batch,
                    out["sm"]["positions"][-1],
                )
            )

        loss_fns = {
            "distogram": lambda: distogram_loss(
                logits=out["distogram_logits"],
                **{**batch, **self.config.distogram},
            ),
            "experimentally_resolved": lambda: experimentally_resolved_loss(
                logits=out["experimentally_resolved_logits"],
                **{**batch, **self.config.experimentally_resolved},
            ),
            "fape": lambda: fape_loss(
                out,
                batch,
                self.config.fape,
            ),
            "lddt": lambda: lddt_loss(
                logits=out["lddt_logits"],
                all_atom_pred_pos=out["final_atom_positions"],
                **{**batch, **self.config.lddt},
            ),
            "masked_msa": lambda: masked_msa_loss(
                logits=out["masked_msa_logits"],
                **{**batch, **self.config.masked_msa},
            ),
            "supervised_chi": lambda: supervised_chi_loss(
                out["sm"]["angles"],
                out["sm"]["unnormalized_angles"],
                **{**batch, **self.config.supervised_chi},
            ),
            "violation": lambda: violation_loss(
                out["violation"],
                **batch,
            ),
            "tm": lambda: tm_loss(
                logits=out["tm_logits"],
                **{**batch, **out, **self.config.tm},
            ),
        }

        cum_loss = 0
        for k, loss_fn in loss_fns.items():
            weight = self.config[k].weight
            if weight:
                # print(k)
                loss = loss_fn()
                # print(weight * loss)
                cum_loss = cum_loss + weight * loss
                # print(cum_loss)

        return cum_loss
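The module is driven entirely by the per-term weights in its config: a term with weight 0 is never even evaluated, and the violation/renaming preprocessing only runs when needed. A minimal usage sketch, assuming the model config exposes a loss section (the forward call is left commented since `out` and `batch` come from a full training step):

    from openfold.config import model_config
    from openfold.utils.loss import AlphaFoldLoss

    config = model_config("model_1_ptm")       # as used in tests below
    loss_fn = AlphaFoldLoss(config.loss)       # assumed config layout

    # During training, `out` is the model's output dict and `batch` the
    # feature dict for the same protein:
    # loss = loss_fn(out, batch)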
openfold/utils/tensor_utils.py View file @ 07e64267
...
@@ -49,11 +49,11 @@ def dict_multimap(fn, dicts):
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)

    return new_dict
...
@@ -83,7 +83,7 @@ def batched_gather(data, inds, dim=0, no_batch_dims=0):
def dict_map(fn, dic, leaf_type):
    new_dict = {}
    for k, v in dic.items():
        if type(v) is dict:
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)
...
@@ -92,76 +92,77 @@ def dict_map(fn, dic, leaf_type):
def tree_map(fn, tree, leaf_type):
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple([tree_map(fn, x, leaf_type) for x in tree])
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        print(type(tree))
        raise ValueError("Not supported")


tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)


def chunk_layer(
    layer: Callable,
    inputs: Dict[str, Any],
    chunk_size: int,
    no_batch_dims: int,
) -> Any:
    """
    Implements the "chunking" procedure described in section 1.11.8.

    Layer outputs and inputs are assumed to be simple "pytrees,"
    consisting only of (arbitrarily nested) lists, tuples, and dicts with
    torch.Tensor leaves.

    Args:
        layer:
            The layer to be applied chunk-wise
        inputs:
            A (non-nested) dictionary of keyworded inputs. All leaves must
            be tensors and must share the same batch dimensions.
        chunk_size:
            The number of sub-batches per chunk. If multiple batch
            dimensions are specified, a "sub-batch" is defined as a single
            indexing of all batch dimensions simultaneously (s.t. the
            number of sub-batches is the product of the batch dimensions).
        no_batch_dims:
            How many of the initial dimensions of each input tensor can
            be considered batch dimensions.

    Returns:
        The reassembled output of the layer on the inputs.
    """
    if not (len(inputs) > 0):
        raise ValueError("Must provide at least one input")

    def fetch_dims(tree):
        shapes = []
        tree_type = type(tree)
        if tree_type is dict:
            for v in tree.values():
                shapes.extend(fetch_dims(v))
        elif tree_type is list or tree_type is tuple:
            for t in tree:
                shapes.extend(fetch_dims(t))
        elif tree_type is torch.Tensor:
            shapes.append(tree.shape)
        else:
            raise ValueError("Not supported")

        return shapes

    initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

    def prep_inputs(t):
        # TODO: make this more memory efficient. This sucks
        if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
            t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
        t = t.reshape(-1, *t.shape[no_batch_dims:])
        return t
...
@@ -172,40 +173,42 @@ def chunk_layer(
    for d in orig_batch_dims:
        flat_batch_dim *= d
    no_chunks = flat_batch_dim // chunk_size + (
        flat_batch_dim % chunk_size != 0
    )

    i = 0
    out = None
    for _ in range(no_chunks):
        # Chunk the input
        select_chunk = lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
        chunks = tensor_tree_map(select_chunk, flattened_inputs)

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate space for the output
        if out is None:
            allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
            out = tensor_tree_map(allocate, output_chunk)

        # Put the chunk in its pre-allocated space
        out_type = type(output_chunk)
        if out_type is dict:

            def assign(d1, d2):
                for k, v in d1.items():
                    if type(v) is dict:
                        assign(v, d2[k])
                    else:
                        v[i : i + chunk_size] = d2[k]

            assign(out, output_chunk)
        elif out_type is tuple:
            for x1, x2 in zip(out, output_chunk):
                x1[i : i + chunk_size] = x2
        elif out_type is torch.Tensor:
            out[i : i + chunk_size] = output_chunk
        else:
            raise ValueError("Not supported")
...
@@ -214,4 +217,4 @@ def chunk_layer(
    reshape = lambda t: t.reshape(orig_batch_dims + t.shape[1:])
    out = tensor_tree_map(reshape, out)

    return out
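A minimal sketch of chunk_layer in use, with a toy layer standing in for a real attention module: the two leading dims (4 x 6 = 24 sub-batches) are flattened, the layer runs on 8 sub-batches at a time, and the chunks are reassembled into the original batch shape.

    import torch
    from openfold.utils.tensor_utils import chunk_layer

    def layer(x, y):
        # Toy layer: any function of keyworded tensor inputs works.
        return {"sum": x + y, "prod": x * y}

    x = torch.rand(4, 6, 8)
    y = torch.rand(4, 6, 8)

    out = chunk_layer(layer, {"x": x, "y": y}, chunk_size=8, no_batch_dims=2)
    print(out["sum"].shape)  # torch.Size([4, 6, 8])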
tests/compare_utils.py View file @ 07e64267
...
@@ -15,7 +15,7 @@ from openfold.utils.import_weights import import_jax_weights_
from tests.config import consts

# Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also
# forces it to proactively free memory that it allocates)
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_PLATFORM_NAME"] = "gpu"
...
@@ -30,17 +30,15 @@ def skip_unless_alphafold_installed():
def import_alphafold():
    """
    If AlphaFold is installed using the provided setuptools script, this
    is necessary to expose all of AlphaFold's precious insides
    """
    if "alphafold" in sys.modules:
        return sys.modules["alphafold"]

    module = importlib.import_module("alphafold")

    # Forcefully import alphafold's submodules
    submodules = pkgutil.walk_packages(
        module.__path__, prefix=("alphafold.")
    )
    for submodule_info in submodules:
        importlib.import_module(submodule_info.name)

    sys.modules["alphafold"] = module
...
@@ -57,16 +55,18 @@ def get_alphafold_config():
_param_path = "openfold/resources/params/params_model_1_ptm.npz"
_model = None


def get_global_pretrained_openfold():
    global _model
    if _model is None:
        _model = AlphaFold(model_config("model_1_ptm").model)
        _model = _model.eval()
        if not os.path.exists(_param_path):
            raise FileNotFoundError(
                """Cannot load pretrained parameters. Make sure to run the
                installation script before running tests."""
            )
        import_jax_weights_(_model, _param_path, version="model_1_ptm")
        _model = _model.cuda()
...
@@ -74,9 +74,11 @@ def get_global_pretrained_openfold():
_orig_weights = None


def _get_orig_weights():
    global _orig_weights
    if _orig_weights is None:
        _orig_weights = np.load(_param_path)

    return _orig_weights
...
@@ -84,22 +86,19 @@ def _get_orig_weights():
def _remove_key_prefix(d, prefix):
    for k, v in list(d.items()):
        if k.startswith(prefix):
            d.pop(k)
            d[k[len(prefix) :]] = v


def fetch_alphafold_module_weights(weight_path):
    orig_weights = _get_orig_weights()
    params = {
        k: v for k, v in orig_weights.items() if weight_path in k
    }
    if "/" in weight_path:
        spl = weight_path.split("/")
        spl = spl if len(spl[-1]) != 0 else spl[:-1]
        module_name = spl[-1]
        prefix = "/".join(spl[:-1]) + "/"
        _remove_key_prefix(params, prefix)

    params = alphafold.model.utils.flat_params_to_haiku(params)

    return params
tests/config.py View file @ 07e64267

import ml_collections as mlc

consts = mlc.ConfigDict(
    {
        "batch_size": 2,
        "n_res": 11,
        "n_seq": 13,
        "n_templ": 3,
        "n_extra": 17,
        "eps": 5e-4,
        # For compatibility with DeepMind's pretrained weights, it's easiest for
        # everyone if these take their real values.
        "c_m": 256,
        "c_z": 128,
        "c_s": 384,
        "c_t": 64,
        "c_e": 64,
    }
)
tests/data_utils.py View file @ 07e64267
...
@@ -18,7 +18,7 @@ from scipy.spatial.transform import Rotation
def random_template_feats(n_templ, n, batch_size=None):
    b = []
    if batch_size is not None:
        b.append(batch_size)
    batch = {
        "template_mask": np.random.randint(0, 2, (*b, n_templ)),
...
@@ -28,28 +28,31 @@ def random_template_feats(n_templ, n, batch_size=None):
        "template_all_atom_masks": np.random.randint(
            0, 2, (*b, n_templ, n, 37)
        ),
        "template_all_atom_positions": np.random.rand(*b, n_templ, n, 37, 3)
        * 10,
    }
    batch = {k: v.astype(np.float32) for k, v in batch.items()}
    batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
    return batch


def random_extra_msa_feats(n_extra, n, batch_size=None):
    b = []
    if batch_size is not None:
        b.append(batch_size)
    batch = {
        "extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(
            np.int64
        ),
        "extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(
            np.float32
        ),
        "extra_deletion_value": np.random.rand(*b, n_extra, n).astype(
            np.float32
        ),
        "extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(
            np.float32
        ),
    }
    return batch
...
@@ -63,7 +66,9 @@ def random_affines_vector(dim):
    for i in range(prod_dim):
        affines[i, :4] = Rotation.random(random_state=42).as_quat()
        affines[i, 4:] = np.random.rand(
            3,
        ).astype(np.float32)

    return affines.reshape(*dim, 7)
...
@@ -77,9 +82,10 @@ def random_affines_4x4(dim):
    for i in range(prod_dim):
        affines[i, :3, :3] = Rotation.random(random_state=42).as_matrix()
        affines[i, :3, 3] = np.random.rand(
            3,
        ).astype(np.float32)

    affines[:, 3, 3] = 1

    return affines.reshape(*dim, 4, 4)
tests/test_embedders.py View file @ 07e64267
...
@@ -24,30 +24,30 @@ from openfold.model.embedders import (
class TestInputEmbedder(unittest.TestCase):
    def test_shape(self):
        tf_dim = 2
        msa_dim = 3
        c_z = 5
        c_m = 7
        relpos_k = 11
        b = 13
        n_res = 17
        n_clust = 19

        tf = torch.rand((b, n_res, tf_dim))
        ri = torch.rand((b, n_res))
        msa = torch.rand((b, n_clust, n_res, msa_dim))

        ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)

        msa_emb, pair_emb = ie(tf, ri, msa)
        self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m))
        self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))


class TestRecyclingEmbedder(unittest.TestCase):
    def test_shape(self):
        batch_size = 2
        n = 3
        c_z = 5
...
@@ -66,7 +66,7 @@ class TestRecyclingEmbedder(unittest.TestCase):
        self.assertTrue(z.shape == (batch_size, n, n, c_z))
        self.assertTrue(m_1.shape == (batch_size, n, c_m))


class TestTemplateAngleEmbedder(unittest.TestCase):
    def test_shape(self):
...
@@ -80,13 +80,11 @@ class TestTemplateAngleEmbedder(unittest.TestCase):
            template_angle_dim,
            c_m,
        )
        x = torch.rand((batch_size, n_templ, n_res, template_angle_dim))
        x = tae(x)
        self.assertTrue(
            x.shape == (batch_size, n_templ, n_res, c_m)
        )


class TestTemplatePairEmbedder(unittest.TestCase):
...
@@ -96,20 +94,17 @@ class TestTemplatePairEmbedder(unittest.TestCase):
        n_res = 5
        template_pair_dim = 7
        c_t = 11
        tpe = TemplatePairEmbedder(
            template_pair_dim,
            c_t,
        )
        x = torch.rand((batch_size, n_templ, n_res, n_res, template_pair_dim))
        x = tpe(x)
        self.assertTrue(
            x.shape == (batch_size, n_templ, n_res, n_res, c_t)
        )


if __name__ == "__main__":
    unittest.main()
tests/test_evoformer.py
View file @ 07e64267
...
@@ -24,14 +24,14 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestEvoformerStack(unittest.TestCase):
    def test_shape(self):
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        n_res = consts.n_res
...
@@ -91,56 +91,54 @@ class TestEvoformerStack(unittest.TestCase):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            ei = alphafold.model.modules.EvoformerIteration(
                c_e, config.model.global_config, is_extra_msa=False
            )
            return ei(activations, masks, is_training=False)

        f = hk.transform(run_ei)

        n_res = consts.n_res
        n_seq = consts.n_seq

        activations = {
            "msa": np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }

        masks = {
            "msa": np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
            "pair": np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
        }

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
        key = jax.random.PRNGKey(42)
        out_gt = f.apply(params, key, activations, masks)
        jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt_msa = torch.as_tensor(np.array(out_gt["msa"]))
        out_gt_pair = torch.as_tensor(np.array(out_gt["pair"]))

        model = compare_utils.get_global_pretrained_openfold()
        out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
            torch.as_tensor(activations["msa"]).cuda(),
            torch.as_tensor(activations["pair"]).cuda(),
            torch.as_tensor(masks["msa"]).cuda(),
            torch.as_tensor(masks["pair"]).cuda(),
            _mask_trans=False,
        )
        out_repro_msa = out_repro_msa.cpu()
        out_repro_pair = out_repro_pair.cpu()

        assert torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps)
        assert torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps)


class TestExtraMSAStack(unittest.TestCase):
    def test_shape(self):
        batch_size = 2
        s_t = 23
        n_res = 5
...
@@ -180,8 +178,24 @@ class TestExtraMSAStack(unittest.TestCase):
        m = torch.rand((batch_size, s_t, n_res, c_m))
        z = torch.rand((batch_size, n_res, n_res, c_z))
        msa_mask = torch.randint(
            0,
            2,
            size=(
                batch_size,
                s_t,
                n_res,
            ),
        )
        pair_mask = torch.randint(
            0,
            2,
            size=(
                batch_size,
                n_res,
                n_res,
            ),
        )

        shape_z_before = z.shape
...
@@ -191,7 +205,7 @@ class TestExtraMSAStack(unittest.TestCase):


class TestMSATransition(unittest.TestCase):
    def test_shape(self):
        batch_size = 2
        s_t = 3
        n_r = 5
...
@@ -214,39 +228,43 @@ class TestMSATransition(unittest.TestCase):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            msa_trans = alphafold.model.modules.Transition(
                c_e.msa_transition,
                config.model.global_config,
                name="msa_transition",
            )
            act = msa_trans(act=msa_act, mask=msa_mask)
            return act

        f = hk.transform(run_msa_transition)

        n_res = consts.n_res
        n_seq = consts.n_seq

        msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
        msa_mask = np.ones((n_seq, n_res)).astype(np.float32)  # no mask here either

        # Fetch pretrained parameters (but only from one block)]
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
            + "msa_transition"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = (
            model.evoformer.blocks[0]
            .msa_transition(
                torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
                mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
            )
            .cpu()
        )

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
...
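The compare tests in this file all share one harness: wrap the AlphaFold call in hk.transform, fetch pretrained params, run f.apply for the JAX ground truth, run the corresponding OpenFold block, and check the two outputs agree within eps. A minimal sketch of that round trip, assuming jax, haiku, numpy, and torch are installed; the doubling function is a stand-in for a real module pair, not code from this commit:

import haiku as hk
import jax
import numpy as np
import torch


def run_fn(x):
    return x * 2.0  # stand-in for an AlphaFold module call


f = hk.transform(run_fn)
x = np.random.rand(4, 8).astype(np.float32)

# Ground truth on the JAX side; params are empty for this stateless stand-in
out_gt = f.apply({}, jax.random.PRNGKey(42), x)
out_gt = torch.as_tensor(np.array(out_gt))  # JAX array -> torch.Tensor

# "Reproduction" on the PyTorch side
out_repro = torch.as_tensor(x) * 2.0

assert torch.max(torch.abs(out_gt - out_repro)) < 1e-5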
tests/test_feats.py
View file @ 07e64267
...
@@ -26,14 +26,14 @@ from openfold.np.residue_constants import (
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
...
@@ -48,21 +48,21 @@ class TestFeats(unittest.TestCase):
                all_atom_pos,
                all_atom_mask,
            )

        f = hk.transform(test_pbf)

        n_res = consts.n_res

        aatype = np.random.randint(0, 22, (n_res,))
        all_atom_pos = np.random.rand(n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_res, 37))

        out_gt_pos, out_gt_mask = f.apply(
            {}, None, aatype, all_atom_pos, all_atom_mask
        )
        out_gt_pos = torch.tensor(np.array(out_gt_pos.block_until_ready()))
        out_gt_mask = torch.tensor(np.array(out_gt_mask.block_until_ready()))

        out_repro_pos, out_repro_mask = feats.pseudo_beta_fn(
            torch.tensor(aatype).cuda(),
            torch.tensor(all_atom_pos).cuda(),
...
@@ -70,7 +70,7 @@ class TestFeats(unittest.TestCase):
        )
        out_repro_pos = out_repro_pos.cpu()
        out_repro_mask = out_repro_mask.cpu()

        self.assertTrue(
            torch.max(torch.abs(out_gt_pos - out_repro_pos)) < consts.eps
        )
...
@@ -82,26 +82,26 @@ class TestFeats(unittest.TestCase):
    def test_atom37_to_torsion_angles_compare(self):
        def run_test(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_torsion_angles(
                aatype,
                all_atom_pos,
                all_atom_mask,
                placeholder_for_undefined=False,
            )

        f = hk.transform(run_test)

        n_templ = 7
        n_res = 13

        aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
        all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(
            0, 2, (n_templ, n_res, 37)
        ).astype(np.float32)

        out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
        out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)

        out_repro = feats.atom37_to_torsion_angles(
            torch.as_tensor(aatype).cuda(),
            torch.as_tensor(all_atom_pos).cuda(),
...
@@ -110,20 +110,21 @@ class TestFeats(unittest.TestCase):
        tasc = out_repro["torsion_angles_sin_cos"].cpu()
        atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
        tam = out_repro["torsion_angles_mask"].cpu()

        # This function is extremely sensitive to floating point imprecisions,
        # so it is given much greater latitude in comparison tests.
        self.assertTrue(
            torch.mean(
                torch.abs(out_gt["torsion_angles_sin_cos"] - tasc)
            )
            < 0.01
        )
        self.assertTrue(
            torch.mean(
                torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc)
            )
            < 0.01
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam))
            < consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_frames_compare(self):
...
@@ -131,48 +132,50 @@ class TestFeats(unittest.TestCase):
            return alphafold.model.all_atom.atom37_to_frames(
                aatype, all_atom_positions, all_atom_mask
            )

        f = hk.transform(run_atom37_to_frames)

        n_res = consts.n_res

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        out_gt = f.apply({}, None, **batch)
        to_tensor = lambda t: torch.tensor(np.array(t))
        out_gt = {k: to_tensor(v) for k, v in out_gt.items()}

        def flat12_to_4x4(flat12):
            rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
            trans = flat12[..., 9:]
            four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
            four_by_four[..., :3, :3] = rot
            four_by_four[..., :3, 3] = trans
            four_by_four[..., 3, 3] = 1
            return four_by_four

        out_gt["rigidgroups_gt_frames"] = flat12_to_4x4(
            out_gt["rigidgroups_gt_frames"]
        )
        out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4(
            out_gt["rigidgroups_alt_gt_frames"]
        )

        to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)

        out_repro = data_transforms.atom37_to_frames(batch)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k, v in out_gt.items():
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )
...
@@ -190,56 +193,50 @@ class TestFeats(unittest.TestCase):
        aas = torch.stack([aas for _ in range(batch_size)])

        frames = feats.torsion_angles_to_frames(
            ts,
            angles,
            aas,
            torch.tensor(restype_rigid_group_default_frame),
        )

        self.assertTrue(frames.shape == (batch_size, n, 8))

    @compare_utils.skip_unless_alphafold_installed()
    def test_torsion_angles_to_frames_compare(self):
        def run_torsion_angles_to_frames(
            aatype, backb_to_global, torsion_angles_sin_cos
        ):
            return alphafold.model.all_atom.torsion_angles_to_frames(
                aatype,
                backb_to_global,
                torsion_angles_sin_cos,
            )

        f = hk.transform(run_torsion_angles_to_frames)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))
        affines = random_affines_4x4((n_res,))
        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
        transformations = T.from_4x4(torch.as_tensor(affines).float())
        torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)

        out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)
        jax.tree_map(lambda x: x.block_until_ready(), out_gt)

        out = feats.torsion_angles_to_frames(
            transformations.cuda(),
            torch.as_tensor(torsion_angles_sin_cos).cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
        )

        # Convert the Rigids to 4x4 transformation tensors
        rots_gt = list(
            map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot)
        )
        trans_gt = list(
            map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
        )
...
@@ -250,9 +247,9 @@ class TestFeats(unittest.TestCase):
        bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
        bottom_row[..., 3] = 1
        transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)

        transforms_repro = out.to_4x4().cpu()
        self.assertTrue(
            torch.max(torch.abs(transforms_gt - transforms_repro) < consts.eps)
        )
...
@@ -275,7 +272,7 @@ class TestFeats(unittest.TestCase):
            torch.tensor(restype_atom14_mask),
            torch.tensor(restype_atom14_rigid_group_positions),
        )

        self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))

    @compare_utils.skip_unless_alphafold_installed()
...
@@ -285,34 +282,32 @@ class TestFeats(unittest.TestCase):
            return am.all_atom.frames_and_literature_positions_to_atom14_pos(
                aatype, affines
            )

        f = hk.transform(run_f)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))
        affines = random_affines_4x4((n_res, 8))
        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
        transformations = T.from_4x4(torch.as_tensor(affines).float())

        out_gt = f.apply({}, None, aatype, rigids)
        jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = torch.stack(
            [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
        )

        out_repro = feats.frames_and_literature_positions_to_atom14_pos(
            transformations.cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
            torch.tensor(restype_atom14_to_rigid_group).cuda(),
            torch.tensor(restype_atom14_mask).cuda(),
            torch.tensor(restype_atom14_rigid_group_positions).cuda(),
        ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
...
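The flat12_to_4x4 helper in test_atom37_to_frames_compare is worth calling out: AlphaFold encodes a rigid frame as 12 flat numbers (a row-major 3x3 rotation followed by a translation vector), and the helper repacks them into a homogeneous 4x4 transform. Pulled out of the test as a standalone function, with a quick sanity check on an identity frame:

import torch


def flat12_to_4x4(flat12: torch.Tensor) -> torch.Tensor:
    # First 9 values: row-major 3x3 rotation; last 3: translation
    rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
    trans = flat12[..., 9:]
    four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
    four_by_four[..., :3, :3] = rot
    four_by_four[..., :3, 3] = trans
    four_by_four[..., 3, 3] = 1
    return four_by_four


# Identity rotation with a unit translation along x
frame = torch.cat([torch.eye(3).reshape(9), torch.tensor([1.0, 0.0, 0.0])])
assert torch.equal(flat12_to_4x4(frame)[:3, 3], torch.tensor([1.0, 0.0, 0.0]))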
tests/test_import_weights.py
View file @ 07e64267
...
@@ -24,13 +24,14 @@ from openfold.utils.import_weights import import_jax_weights_

class TestImportWeights(unittest.TestCase):
    def test_import_jax_weights_(self):
        npz_path = "openfold/resources/params/params_model_1_ptm.npz"

        c = model_config("model_1_ptm")
        c.globals.blocks_per_ckpt = None
        model = AlphaFold(c.model)

        import_jax_weights_(
            model,
            npz_path,
        )

        data = np.load(npz_path)
...
@@ -38,23 +39,34 @@ class TestImportWeights(unittest.TestCase):
        test_pairs = [
            # Normal linear weight
            (
                torch.as_tensor(
                    data[prefix + "structure_module/initial_projection//weights"]
                ).transpose(-1, -2),
                model.structure_module.linear_in.weight,
            ),
            # Normal layer norm param
            (
                torch.as_tensor(
                    data[prefix + "evoformer/prev_pair_norm//offset"],
                ),
                model.recycling_embedder.layer_norm_z.bias,
            ),
            # From a stack
            (
                torch.as_tensor(
                    data[
                        prefix
                        + (
                            "evoformer/evoformer_iteration/outer_product_mean/"
                            "left_projection//weights"
                        )
                    ][1].transpose(-1, -2)
                ),
                model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
            ),
        ]

        for w_alpha, w_repro in test_pairs:
...
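The .transpose(-1, -2) calls above reflect a layout difference between the two frameworks: Haiku stores Linear weights as (in_features, out_features), while torch.nn.Linear.weight is (out_features, in_features). A minimal sketch of the convention check, assuming only numpy and torch (the shapes and names here are illustrative):

import numpy as np
import torch

jax_w = np.random.rand(7, 11).astype(np.float32)  # (in, out), Haiku-style
pt = torch.nn.Linear(7, 11, bias=False)
with torch.no_grad():
    # The same transpose the importer applies to every linear weight
    pt.weight.copy_(torch.as_tensor(jax_w).transpose(-1, -2))

x = np.random.rand(3, 7).astype(np.float32)
np.testing.assert_allclose(
    x @ jax_w,                       # Haiku forward: x @ W
    pt(torch.as_tensor(x)).numpy(),  # torch forward: x @ W.T
    rtol=1e-4,
)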
tests/test_loss.py
View file @
07e64267
...
@@ -41,15 +41,15 @@ from openfold.utils.loss import (
...
@@ -41,15 +41,15 @@ from openfold.utils.loss import (
tm_loss
,
tm_loss
,
)
)
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
tree_map
,
tree_map
,
tensor_tree_map
,
tensor_tree_map
,
dict_multimap
,
dict_multimap
,
)
)
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
from
tests.data_utils
import
random_affines_vector
,
random_affines_4x4
from
tests.data_utils
import
random_affines_vector
,
random_affines_4x4
if
(
compare_utils
.
alphafold_is_installed
()
)
:
if
compare_utils
.
alphafold_is_installed
():
alphafold
=
compare_utils
.
import_alphafold
()
alphafold
=
compare_utils
.
import_alphafold
()
import
jax
import
jax
import
haiku
as
hk
import
haiku
as
hk
...
@@ -99,12 +99,19 @@ class TestLoss(unittest.TestCase):
...
@@ -99,12 +99,19 @@ class TestLoss(unittest.TestCase):
pred_pos
=
torch
.
rand
(
bs
,
n
,
14
,
3
)
pred_pos
=
torch
.
rand
(
bs
,
n
,
14
,
3
)
pred_atom_mask
=
torch
.
randint
(
0
,
2
,
(
bs
,
n
,
14
))
pred_atom_mask
=
torch
.
randint
(
0
,
2
,
(
bs
,
n
,
14
))
residue_index
=
torch
.
arange
(
n
).
unsqueeze
(
0
)
residue_index
=
torch
.
arange
(
n
).
unsqueeze
(
0
)
aatype
=
torch
.
randint
(
0
,
22
,
(
bs
,
n
,))
aatype
=
torch
.
randint
(
0
,
22
,
(
bs
,
n
,
),
)
between_residue_bond_loss
(
between_residue_bond_loss
(
pred_pos
,
pred_pos
,
pred_atom_mask
,
pred_atom_mask
,
residue_index
,
residue_index
,
aatype
,
aatype
,
)
)
...
@@ -117,27 +124,26 @@ class TestLoss(unittest.TestCase):
...
@@ -117,27 +124,26 @@ class TestLoss(unittest.TestCase):
residue_index
,
residue_index
,
aatype
,
aatype
,
)
)
f
=
hk
.
transform
(
run_brbl
)
f
=
hk
.
transform
(
run_brbl
)
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
pred_atom_mask
=
np
.
random
.
randint
(
pred_atom_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)).
astype
(
np
.
float32
)
0
,
2
,
(
n_res
,
14
)
).
astype
(
np
.
float32
)
residue_index
=
np
.
arange
(
n_res
)
residue_index
=
np
.
arange
(
n_res
)
aatype
=
np
.
random
.
randint
(
0
,
22
,
(
n_res
,))
aatype
=
np
.
random
.
randint
(
0
,
22
,
(
n_res
,))
out_gt
=
f
.
apply
(
out_gt
=
f
.
apply
(
{},
None
,
{},
pred_pos
,
None
,
pred_atom_mask
,
pred_pos
,
pred_atom_mask
,
residue_index
,
residue_index
,
aatype
,
aatype
,
)
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
copy
(
x
)),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
copy
(
x
)),
out_gt
)
out_repro
=
between_residue_bond_loss
(
out_repro
=
between_residue_bond_loss
(
torch
.
tensor
(
pred_pos
).
cuda
(),
torch
.
tensor
(
pred_pos
).
cuda
(),
torch
.
tensor
(
pred_atom_mask
).
cuda
(),
torch
.
tensor
(
pred_atom_mask
).
cuda
(),
...
@@ -145,13 +151,12 @@ class TestLoss(unittest.TestCase):
...
@@ -145,13 +151,12 @@ class TestLoss(unittest.TestCase):
torch
.
tensor
(
aatype
).
cuda
(),
torch
.
tensor
(
aatype
).
cuda
(),
)
)
out_repro
=
tensor_tree_map
(
lambda
x
:
x
.
cpu
(),
out_repro
)
out_repro
=
tensor_tree_map
(
lambda
x
:
x
.
cpu
(),
out_repro
)
for
k
in
out_gt
.
keys
():
for
k
in
out_gt
.
keys
():
self
.
assertTrue
(
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
[
k
]
-
out_repro
[
k
]))
<
consts
.
eps
torch
.
max
(
torch
.
abs
(
out_gt
[
k
]
-
out_repro
[
k
]))
<
consts
.
eps
)
)
def
test_run_between_residue_clash_loss
(
self
):
def
test_run_between_residue_clash_loss
(
self
):
bs
=
consts
.
batch_size
bs
=
consts
.
batch_size
n
=
consts
.
n_res
n
=
consts
.
n_res
...
@@ -164,7 +169,7 @@ class TestLoss(unittest.TestCase):
...
@@ -164,7 +169,7 @@ class TestLoss(unittest.TestCase):
loss
=
between_residue_clash_loss
(
loss
=
between_residue_clash_loss
(
pred_pos
,
pred_pos
,
pred_atom_mask
,
pred_atom_mask
,
atom14_atom_radius
,
atom14_atom_radius
,
residue_index
,
residue_index
,
)
)
...
@@ -185,10 +190,13 @@ class TestLoss(unittest.TestCase):
...
@@ -185,10 +190,13 @@ class TestLoss(unittest.TestCase):
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
atom_exists
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)).
astype
(
np
.
float32
)
atom_exists
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)).
astype
(
np
.
float32
)
atom_radius
=
np
.
random
.
rand
(
n_res
,
14
).
astype
(
np
.
float32
)
atom_radius
=
np
.
random
.
rand
(
n_res
,
14
).
astype
(
np
.
float32
)
res_ind
=
np
.
arange
(
n_res
,)
res_ind
=
np
.
arange
(
n_res
,
)
out_gt
=
f
.
apply
(
out_gt
=
f
.
apply
(
{},
None
,
{},
None
,
pred_pos
,
pred_pos
,
atom_exists
,
atom_exists
,
atom_radius
,
atom_radius
,
...
@@ -196,7 +204,7 @@ class TestLoss(unittest.TestCase):
...
@@ -196,7 +204,7 @@ class TestLoss(unittest.TestCase):
)
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
copy
(
x
)),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
copy
(
x
)),
out_gt
)
out_repro
=
between_residue_clash_loss
(
out_repro
=
between_residue_clash_loss
(
torch
.
tensor
(
pred_pos
).
cuda
(),
torch
.
tensor
(
pred_pos
).
cuda
(),
torch
.
tensor
(
atom_exists
).
cuda
(),
torch
.
tensor
(
atom_exists
).
cuda
(),
...
@@ -204,7 +212,7 @@ class TestLoss(unittest.TestCase):
...
@@ -204,7 +212,7 @@ class TestLoss(unittest.TestCase):
torch
.
tensor
(
res_ind
).
cuda
(),
torch
.
tensor
(
res_ind
).
cuda
(),
)
)
out_repro
=
tensor_tree_map
(
lambda
x
:
x
.
cpu
(),
out_repro
)
out_repro
=
tensor_tree_map
(
lambda
x
:
x
.
cpu
(),
out_repro
)
for
k
in
out_gt
.
keys
():
for
k
in
out_gt
.
keys
():
self
.
assertTrue
(
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
[
k
]
-
out_repro
[
k
]))
<
consts
.
eps
torch
.
max
(
torch
.
abs
(
out_gt
[
k
]
-
out_repro
[
k
]))
<
consts
.
eps
...
@@ -221,7 +229,7 @@ class TestLoss(unittest.TestCase):
...
@@ -221,7 +229,7 @@ class TestLoss(unittest.TestCase):
}
}
pred_pos
=
torch
.
rand
(
n
,
14
,
3
)
pred_pos
=
torch
.
rand
(
n
,
14
,
3
)
config
=
{
config
=
{
"clash_overlap_tolerance"
:
1.5
,
"clash_overlap_tolerance"
:
1.5
,
"violation_tolerance_factor"
:
12.0
,
"violation_tolerance_factor"
:
12.0
,
...
@@ -242,50 +250,44 @@ class TestLoss(unittest.TestCase):
...
@@ -242,50 +250,44 @@ class TestLoss(unittest.TestCase):
os
.
chdir
(
cwd
)
os
.
chdir
(
cwd
)
return
loss
return
loss
f
=
hk
.
transform
(
run_fsv
)
f
=
hk
.
transform
(
run_fsv
)
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
batch
=
{
batch
=
{
"atom14_atom_exists"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)),
"atom14_atom_exists"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)),
"residue_index"
:
np
.
arange
(
n_res
),
"residue_index"
:
np
.
arange
(
n_res
),
"aatype"
:
np
.
random
.
randint
(
0
,
20
,
(
n_res
,)),
"aatype"
:
np
.
random
.
randint
(
0
,
20
,
(
n_res
,)),
"residx_atom14_to_atom37"
:
"residx_atom14_to_atom37"
:
np
.
random
.
randint
(
np
.
random
.
randint
(
0
,
37
,
(
n_res
,
14
)).
astype
(
np
.
int64
),
0
,
37
,
(
n_res
,
14
)
).
astype
(
np
.
int64
),
}
}
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
config
=
mlc
.
ConfigDict
({
config
=
mlc
.
ConfigDict
(
"clash_overlap_tolerance"
:
1.5
,
{
"violation_tolerance_factor"
:
12.0
,
"clash_overlap_tolerance"
:
1.5
,
})
"violation_tolerance_factor"
:
12.0
,
}
out_gt
=
f
.
apply
(
{},
None
,
batch
,
pred_pos
,
config
)
)
out_gt
=
f
.
apply
({},
None
,
batch
,
pred_pos
,
config
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
copy
(
x
)),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
copy
(
x
)),
out_gt
)
batch
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
batch
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
out_repro
=
find_structural_violations
(
out_repro
=
find_structural_violations
(
batch
,
batch
,
torch
.
tensor
(
pred_pos
).
cuda
(),
torch
.
tensor
(
pred_pos
).
cuda
(),
**
config
,
**
config
,
)
)
out_repro
=
tensor_tree_map
(
lambda
x
:
x
.
cpu
(),
out_repro
)
out_repro
=
tensor_tree_map
(
lambda
x
:
x
.
cpu
(),
out_repro
)
def
compare
(
out
):
def
compare
(
out
):
gt
,
repro
=
out
gt
,
repro
=
out
assert
(
torch
.
max
(
torch
.
abs
(
gt
-
repro
))
<
consts
.
eps
)
assert
torch
.
max
(
torch
.
abs
(
gt
-
repro
))
<
consts
.
eps
dict_multimap
(
compare
,
[
out_gt
,
out_repro
])
dict_multimap
(
compare
,
[
out_gt
,
out_repro
])
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
...
@@ -295,44 +297,45 @@ class TestLoss(unittest.TestCase):
...
@@ -295,44 +297,45 @@ class TestLoss(unittest.TestCase):
batch
,
batch
,
atom14_pred_pos
,
atom14_pred_pos
,
)
)
f
=
hk
.
transform
(
run_crgt
)
f
=
hk
.
transform
(
run_crgt
)
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
batch
=
{
batch
=
{
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"aatype"
:
np
.
random
.
randint
(
0
,
20
,
(
n_res
,)),
"aatype"
:
np
.
random
.
randint
(
0
,
20
,
(
n_res
,)),
"atom14_gt_positions"
:
np
.
random
.
rand
(
n_res
,
14
,
3
),
"atom14_gt_positions"
:
np
.
random
.
rand
(
n_res
,
14
,
3
),
"atom14_gt_exists"
:
"atom14_gt_exists"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)).
astype
(
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)).
astype
(
np
.
float32
),
np
.
float32
"all_atom_mask"
:
),
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
37
)).
astype
(
np
.
float32
),
"all_atom_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
37
)).
astype
(
"all_atom_positions"
:
np
.
float32
np
.
random
.
rand
(
n_res
,
37
,
3
).
astype
(
np
.
float32
),
),
"all_atom_positions"
:
np
.
random
.
rand
(
n_res
,
37
,
3
).
astype
(
np
.
float32
),
}
}
def
_build_extra_feats_np
():
def
_build_extra_feats_np
():
b
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
),
batch
,
np
.
ndarray
)
b
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
),
batch
,
np
.
ndarray
)
b
=
data_transforms
.
make_atom14_masks
(
b
)
b
=
data_transforms
.
make_atom14_masks
(
b
)
b
=
data_transforms
.
make_atom14_positions
(
b
)
b
=
data_transforms
.
make_atom14_positions
(
b
)
return
tensor_tree_map
(
lambda
t
:
np
.
array
(
t
),
b
)
return
tensor_tree_map
(
lambda
t
:
np
.
array
(
t
),
b
)
batch
=
_build_extra_feats_np
()
batch
=
_build_extra_feats_np
()
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
out_gt
=
f
.
apply
({},
None
,
batch
,
atom14_pred_pos
)
out_gt
=
f
.
apply
({},
None
,
batch
,
atom14_pred_pos
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
array
(
x
)),
out_gt
)
out_gt
=
jax
.
tree_map
(
lambda
x
:
torch
.
tensor
(
np
.
array
(
x
)),
out_gt
)
batch
=
tree_map
(
batch
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
atom14_pred_pos
=
torch
.
tensor
(
atom14_pred_pos
).
cuda
()
atom14_pred_pos
=
torch
.
tensor
(
atom14_pred_pos
).
cuda
()
out_repro
=
compute_renamed_ground_truth
(
batch
,
atom14_pred_pos
)
out_repro
=
compute_renamed_ground_truth
(
batch
,
atom14_pred_pos
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
for
k
in
out_repro
:
for
k
in
out_repro
:
self
.
assertTrue
(
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
[
k
]
-
out_repro
[
k
]))
<
consts
.
eps
torch
.
max
(
torch
.
abs
(
out_gt
[
k
]
-
out_repro
[
k
]))
<
consts
.
eps
...
@@ -346,84 +349,76 @@ class TestLoss(unittest.TestCase):
...
@@ -346,84 +349,76 @@ class TestLoss(unittest.TestCase):
config
.
model
.
heads
.
masked_msa
,
config
.
model
.
global_config
config
.
model
.
heads
.
masked_msa
,
config
.
model
.
global_config
)
)
return
msa_head
.
loss
(
value
,
batch
)
return
msa_head
.
loss
(
value
,
batch
)
f
=
hk
.
transform
(
run_msa_loss
)
f
=
hk
.
transform
(
run_msa_loss
)
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
n_seq
=
consts
.
n_seq
n_seq
=
consts
.
n_seq
value
=
{
value
=
{
"logits"
:
np
.
random
.
rand
(
n_res
,
n_seq
,
23
).
astype
(
np
.
float32
),
"logits"
:
np
.
random
.
rand
(
n_res
,
n_seq
,
23
).
astype
(
np
.
float32
),
}
}
batch
=
{
batch
=
{
"true_msa"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,
n_seq
)),
"true_msa"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,
n_seq
)),
"bert_mask"
:
"bert_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_seq
)).
astype
(
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_seq
)).
astype
(
np
.
float32
),
np
.
float32
),
}
}
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)[
"loss"
]
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)[
"loss"
]
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
))
value
=
tree_map
(
value
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
value
,
np
.
ndarray
)
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
value
,
np
.
ndarray
batch
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
)
batch
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
out_repro
=
masked_msa_loss
(
out_repro
=
masked_msa_loss
(
value
[
"logits"
],
value
[
"logits"
],
**
batch
,
**
batch
,
)
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_distogram_loss_compare
(
self
):
def
test_distogram_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_distogram
=
config
.
model
.
heads
.
distogram
c_distogram
=
config
.
model
.
heads
.
distogram
def
run_distogram_loss
(
value
,
batch
):
def
run_distogram_loss
(
value
,
batch
):
dist_head
=
alphafold
.
model
.
modules
.
DistogramHead
(
dist_head
=
alphafold
.
model
.
modules
.
DistogramHead
(
c_distogram
,
config
.
model
.
global_config
c_distogram
,
config
.
model
.
global_config
)
)
return
dist_head
.
loss
(
value
,
batch
)
return
dist_head
.
loss
(
value
,
batch
)
f
=
hk
.
transform
(
run_distogram_loss
)
f
=
hk
.
transform
(
run_distogram_loss
)
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
value
=
{
value
=
{
"logits"
:
np
.
random
.
rand
(
"logits"
:
np
.
random
.
rand
(
n_res
,
n_res
,
c_distogram
.
num_bins
).
astype
(
n_res
,
np
.
float32
n_res
,
),
c_distogram
.
num_bins
).
astype
(
np
.
float32
),
"bin_edges"
:
np
.
linspace
(
"bin_edges"
:
np
.
linspace
(
c_distogram
.
first_break
,
c_distogram
.
first_break
,
c_distogram
.
last_break
,
c_distogram
.
last_break
,
c_distogram
.
num_bins
,
c_distogram
.
num_bins
,
)
)
,
}
}
batch
=
{
batch
=
{
"pseudo_beta"
:
np
.
random
.
rand
(
n_res
,
3
).
astype
(
np
.
float32
),
"pseudo_beta"
:
np
.
random
.
rand
(
n_res
,
3
).
astype
(
np
.
float32
),
"pseudo_beta_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,))
"pseudo_beta_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,))
,
}
}
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)[
"loss"
]
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)[
"loss"
]
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
))
value
=
tree_map
(
value
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
value
,
np
.
ndarray
)
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
value
,
np
.
ndarray
)
batch
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
batch
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
out_repro
=
distogram_loss
(
out_repro
=
distogram_loss
(
logits
=
value
[
"logits"
],
logits
=
value
[
"logits"
],
...
@@ -431,66 +426,64 @@ class TestLoss(unittest.TestCase):
...
@@ -431,66 +426,64 @@ class TestLoss(unittest.TestCase):
max_bin
=
c_distogram
.
last_break
,
max_bin
=
c_distogram
.
last_break
,
no_bins
=
c_distogram
.
num_bins
,
no_bins
=
c_distogram
.
num_bins
,
**
batch
,
**
batch
,
)
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_experimentally_resolved_loss_compare
(
self
):
def
test_experimentally_resolved_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_experimentally_resolved
=
config
.
model
.
heads
.
experimentally_resolved
c_experimentally_resolved
=
config
.
model
.
heads
.
experimentally_resolved
def
run_experimentally_resolved_loss
(
value
,
batch
):
def
run_experimentally_resolved_loss
(
value
,
batch
):
er_head
=
alphafold
.
model
.
modules
.
ExperimentallyResolvedHead
(
er_head
=
alphafold
.
model
.
modules
.
ExperimentallyResolvedHead
(
c_experimentally_resolved
,
config
.
model
.
global_config
c_experimentally_resolved
,
config
.
model
.
global_config
)
)
return
er_head
.
loss
(
value
,
batch
)
return
er_head
.
loss
(
value
,
batch
)
f
=
hk
.
transform
(
run_experimentally_resolved_loss
)
f
=
hk
.
transform
(
run_experimentally_resolved_loss
)
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
value
=
{
value
=
{
"logits"
:
np
.
random
.
rand
(
n_res
,
37
).
astype
(
np
.
float32
),
"logits"
:
np
.
random
.
rand
(
n_res
,
37
).
astype
(
np
.
float32
),
}
}
batch
=
{
batch
=
{
"all_atom_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
37
)),
"all_atom_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
37
)),
"atom37_atom_exists"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
37
)),
"atom37_atom_exists"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
37
)),
"resolution"
:
np
.
array
(
1.0
)
"resolution"
:
np
.
array
(
1.0
)
,
}
}
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)[
"loss"
]
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)[
"loss"
]
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
))
value
=
tree_map
(
value
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
value
,
np
.
ndarray
)
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
value
,
np
.
ndarray
)
batch
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
batch
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
out_repro
=
experimentally_resolved_loss
(
out_repro
=
experimentally_resolved_loss
(
logits
=
value
[
"logits"
],
logits
=
value
[
"logits"
],
min_resolution
=
c_experimentally_resolved
.
min_resolution
,
min_resolution
=
c_experimentally_resolved
.
min_resolution
,
max_resolution
=
c_experimentally_resolved
.
max_resolution
,
max_resolution
=
c_experimentally_resolved
.
max_resolution
,
**
batch
,
**
batch
,
)
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_supervised_chi_loss_compare
(
self
):
def
test_supervised_chi_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_chi_loss
=
config
.
model
.
heads
.
structure_module
c_chi_loss
=
config
.
model
.
heads
.
structure_module
def
run_supervised_chi_loss
(
value
,
batch
):
def
run_supervised_chi_loss
(
value
,
batch
):
ret
=
{
ret
=
{
"loss"
:
jax
.
numpy
.
array
(
0.
),
"loss"
:
jax
.
numpy
.
array
(
0.
0
),
}
}
alphafold
.
model
.
folding
.
supervised_chi_loss
(
alphafold
.
model
.
folding
.
supervised_chi_loss
(
ret
,
batch
,
value
,
c_chi_loss
ret
,
batch
,
value
,
c_chi_loss
...
@@ -503,10 +496,12 @@ class TestLoss(unittest.TestCase):
...
@@ -503,10 +496,12 @@ class TestLoss(unittest.TestCase):
value
=
{
value
=
{
"sidechains"
:
{
"sidechains"
:
{
"angles_sin_cos"
:
"angles_sin_cos"
:
np
.
random
.
rand
(
8
,
n_res
,
7
,
2
).
astype
(
np
.
random
.
rand
(
8
,
n_res
,
7
,
2
).
astype
(
np
.
float32
),
np
.
float32
"unnormalized_angles_sin_cos"
:
),
np
.
random
.
rand
(
8
,
n_res
,
7
,
2
).
astype
(
np
.
float32
),
"unnormalized_angles_sin_cos"
:
np
.
random
.
rand
(
8
,
n_res
,
7
,
2
).
astype
(
np
.
float32
),
}
}
}
}
...
@@ -519,13 +514,9 @@ class TestLoss(unittest.TestCase):
...
@@ -519,13 +514,9 @@ class TestLoss(unittest.TestCase):
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
value
=
tree_map
(
value
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
value
,
np
.
ndarray
)
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
value
,
np
.
ndarray
)
batch
=
tree_map
(
batch
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
batch
[
"chi_angles_sin_cos"
]
=
torch
.
stack
(
batch
[
"chi_angles_sin_cos"
]
=
torch
.
stack
(
[
[
...
@@ -539,9 +530,9 @@ class TestLoss(unittest.TestCase):
...
@@ -539,9 +530,9 @@ class TestLoss(unittest.TestCase):
out_repro
=
supervised_chi_loss
(
out_repro
=
supervised_chi_loss
(
chi_weight
=
c_chi_loss
.
chi_weight
,
chi_weight
=
c_chi_loss
.
chi_weight
,
angle_norm_weight
=
c_chi_loss
.
angle_norm_weight
,
angle_norm_weight
=
c_chi_loss
.
angle_norm_weight
,
**
{
**
batch
,
**
value
[
"sidechains"
]}
**
{
**
batch
,
**
value
[
"sidechains"
]}
,
)
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
...
@@ -550,111 +541,119 @@ class TestLoss(unittest.TestCase):
...
@@ -550,111 +541,119 @@ class TestLoss(unittest.TestCase):
def
test_violation_loss_compare
(
self
):
def
test_violation_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_viol
=
config
.
model
.
heads
.
structure_module
c_viol
=
config
.
model
.
heads
.
structure_module
def
run_viol_loss
(
batch
,
atom14_pred_pos
):
def
run_viol_loss
(
batch
,
atom14_pred_pos
):
ret
=
{
ret
=
{
"loss"
:
np
.
array
(
0.
).
astype
(
np
.
float32
),
"loss"
:
np
.
array
(
0.
0
).
astype
(
np
.
float32
),
}
}
value
=
{}
value
=
{}
value
[
"violations"
]
=
(
value
[
alphafold
.
model
.
folding
.
find_structural_
violations
(
"
violations
"
batch
,
]
=
alphafold
.
model
.
folding
.
find_structural_violations
(
atom14_pred_pos
,
batch
,
c_viol
,
atom14_pred_pos
,
)
c_viol
,
)
)
alphafold
.
model
.
folding
.
structural_violation_loss
(
alphafold
.
model
.
folding
.
structural_violation_loss
(
ret
,
batch
,
value
,
c_viol
,
ret
,
batch
,
value
,
c_viol
,
)
)
return
ret
[
"loss"
]
return
ret
[
"loss"
]
f
=
hk
.
transform
(
run_viol_loss
)
f
=
hk
.
transform
(
run_viol_loss
)
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
batch
=
{
batch
=
{
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"residue_index"
:
np
.
arange
(
n_res
),
"residue_index"
:
np
.
arange
(
n_res
),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,)),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,)),
}
}
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
batch
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
batch
.
items
()}
batch
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
batch
.
items
()}
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
out_gt
=
f
.
apply
({},
None
,
batch
,
atom14_pred_pos
)
out_gt
=
f
.
apply
({},
None
,
batch
,
atom14_pred_pos
)
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
batch
=
tree_map
(
batch
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
).
cuda
(),
batch
,
np
.
ndarray
)
lambda
n
:
torch
.
tensor
(
n
).
cuda
(),
batch
,
np
.
ndarray
)
atom14_pred_pos
=
torch
.
tensor
(
atom14_pred_pos
).
cuda
()
atom14_pred_pos
=
torch
.
tensor
(
atom14_pred_pos
).
cuda
()
batch
=
data_transforms
.
make_atom14_masks
(
batch
)
batch
=
data_transforms
.
make_atom14_masks
(
batch
)
out_repro
=
violation_loss
(
out_repro
=
violation_loss
(
find_structural_violations
(
batch
,
atom14_pred_pos
,
**
c_viol
),
find_structural_violations
(
batch
,
atom14_pred_pos
,
**
c_viol
),
**
batch
,
**
batch
,
)
)
out_repro
=
out_repro
.
cpu
()
out_repro
=
out_repro
.
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_lddt_loss_compare
(
self
):
def
test_lddt_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_plddt
=
config
.
model
.
heads
.
predicted_lddt
c_plddt
=
config
.
model
.
heads
.
predicted_lddt
def
run_plddt_loss
(
value
,
batch
):
def
run_plddt_loss
(
value
,
batch
):
head
=
alphafold
.
model
.
modules
.
PredictedLDDTHead
(
head
=
alphafold
.
model
.
modules
.
PredictedLDDTHead
(
c_plddt
,
config
.
model
.
global_config
c_plddt
,
config
.
model
.
global_config
)
)
return
head
.
loss
(
value
,
batch
)
return
head
.
loss
(
value
,
batch
)
f
=
hk
.
transform
(
run_plddt_loss
)
f
=
hk
.
transform
(
run_plddt_loss
)
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
value
=
{
value
=
{
"predicted_lddt"
:
{
"predicted_lddt"
:
{
"logits"
:
"logits"
:
np
.
random
.
rand
(
n_res
,
c_plddt
.
num_bins
).
astype
(
np
.
random
.
rand
(
n_res
,
c_plddt
.
num_bins
).
astype
(
np
.
float32
),
np
.
float32
),
},
},
"structure_module"
:
{
"structure_module"
:
{
"final_atom_positions"
:
"final_atom_positions"
:
np
.
random
.
rand
(
n_res
,
37
,
3
).
astype
(
np
.
random
.
rand
(
n_res
,
37
,
3
).
astype
(
np
.
float32
),
np
.
float32
}
),
},
}
}
batch
=
{
batch
=
{
"all_atom_positions"
:
"all_atom_positions"
:
np
.
random
.
rand
(
n_res
,
37
,
3
).
astype
(
np
.
random
.
rand
(
n_res
,
37
,
3
).
astype
(
np
.
float32
),
np
.
float32
"all_atom_mask"
:
),
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
37
)).
astype
(
np
.
float32
),
"all_atom_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
37
)).
astype
(
"resolution"
:
np
.
array
(
1.
).
astype
(
np
.
float32
),
np
.
float32
),
"resolution"
:
np
.
array
(
1.0
).
astype
(
np
.
float32
),
}
}
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
[
"loss"
]))
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
[
"loss"
]))
to_tensor
=
lambda
t
:
torch
.
tensor
(
t
).
cuda
()
to_tensor
=
lambda
t
:
torch
.
tensor
(
t
).
cuda
()
value
=
tree_map
(
to_tensor
,
value
,
np
.
ndarray
)
value
=
tree_map
(
to_tensor
,
value
,
np
.
ndarray
)
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
out_repro
=
lddt_loss
(
out_repro
=
lddt_loss
(
logits
=
value
[
"predicted_lddt"
][
"logits"
],
logits
=
value
[
"predicted_lddt"
][
"logits"
],
all_atom_pred_pos
=
value
[
"structure_module"
][
"final_atom_positions"
],
all_atom_pred_pos
=
value
[
"structure_module"
][
"final_atom_positions"
],
**
{
**
batch
,
**
c_plddt
},
**
{
**
batch
,
**
c_plddt
},
)
)
out_repro
=
out_repro
.
cpu
()
out_repro
=
out_repro
.
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_backbone_loss
(
self
):
def
test_backbone_loss
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_sm
=
config
.
model
.
heads
.
structure_module
c_sm
=
config
.
model
.
heads
.
structure_module
def
run_bb_loss
(
batch
,
value
):
def
run_bb_loss
(
batch
,
value
):
ret
=
{
ret
=
{
"loss"
:
np
.
array
(
0.
),
"loss"
:
np
.
array
(
0.
0
),
}
}
alphafold
.
model
.
folding
.
backbone_loss
(
ret
,
batch
,
value
,
c_sm
)
alphafold
.
model
.
folding
.
backbone_loss
(
ret
,
batch
,
value
,
c_sm
)
return
ret
[
"loss"
]
return
ret
[
"loss"
]
...
@@ -665,13 +664,19 @@ class TestLoss(unittest.TestCase):
...
@@ -665,13 +664,19 @@ class TestLoss(unittest.TestCase):
batch
=
{
batch
=
{
"backbone_affine_tensor"
:
random_affines_vector
((
n_res
,)),
"backbone_affine_tensor"
:
random_affines_vector
((
n_res
,)),
"backbone_affine_mask"
:
"backbone_affine_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
np
.
float32
"use_clamped_fape"
:
np
.
array
(
0.
),
),
"use_clamped_fape"
:
np
.
array
(
0.0
),
}
}
value
=
{
value
=
{
"traj"
:
random_affines_vector
((
c_sm
.
num_layer
,
n_res
,)),
"traj"
:
random_affines_vector
(
(
c_sm
.
num_layer
,
n_res
,
)
),
}
}
out_gt
=
f
.
apply
({},
None
,
batch
,
value
)
out_gt
=
f
.
apply
({},
None
,
batch
,
value
)
...
@@ -695,6 +700,7 @@ class TestLoss(unittest.TestCase):
...
@@ -695,6 +700,7 @@ class TestLoss(unittest.TestCase):
def
test_sidechain_loss_compare
(
self
):
def
test_sidechain_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_sm
=
config
.
model
.
heads
.
structure_module
c_sm
=
config
.
model
.
heads
.
structure_module
def
run_sidechain_loss
(
batch
,
value
,
atom14_pred_positions
):
def
run_sidechain_loss
(
batch
,
value
,
atom14_pred_positions
):
batch
=
{
batch
=
{
**
batch
,
**
batch
,
...

@@ -702,88 +708,94 @@ class TestLoss(unittest.TestCase):

                batch["aatype"],
                batch["all_atom_positions"],
                batch["all_atom_mask"],
                ),
            }
            v = {}
            v["sidechains"] = {}
            v["sidechains"][
                "frames"
            ] = alphafold.model.r3.rigids_from_tensor4x4(
                value["sidechains"]["frames"]
            )
            v["sidechains"]["atom_pos"] = alphafold.model.r3.vecs_from_tensor(
                value["sidechains"]["atom_pos"]
            )
            v.update(
                alphafold.model.folding.compute_renamed_ground_truth(
                    batch,
                    atom14_pred_positions,
                )
            )
            value = v
            ret = alphafold.model.folding.sidechain_loss(batch, value, c_sm)
            return ret["loss"]

        f = hk.transform(run_sidechain_loss)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3).astype(
                np.float32
            ),
            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
                np.float32
            ),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        def _build_extra_feats_np():
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b = data_transforms.make_atom14_masks(b)
            b = data_transforms.make_atom14_positions(b)
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        value = {
            "sidechains": {
                "frames": random_affines_4x4((c_sm.num_layer, n_res, 8)),
                "atom_pos": np.random.rand(
                    c_sm.num_layer, n_res, 14, 3
                ).astype(np.float32),
            }
        }

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        out_gt = f.apply({}, None, batch, value, atom14_pred_pos)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda t: torch.tensor(t).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)
        atom14_pred_pos = to_tensor(atom14_pred_pos)

        batch = data_transforms.atom37_to_frames(batch)
        batch.update(compute_renamed_ground_truth(batch, atom14_pred_pos))

        out_repro = sidechain_loss(
            sidechain_frames=value["sidechains"]["frames"],
            sidechain_atom_pos=value["sidechains"]["atom_pos"],
            **{**batch, **c_sm},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
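Some context on the atom14/atom37 juggling above: OpenFold uses two atom layouts. atom37 reserves one fixed slot per heavy-atom type for every residue, while atom14 packs at most 14 atoms per residue type; `make_atom14_masks` builds the per-residue index maps between them. A shape-only sketch with zero tensors standing in for real data (the true gather map depends on `aatype`):

import torch

n_res = 8
atom14_positions = torch.zeros(n_res, 14, 3)  # compact atom14 layout
residx_atom37_to_atom14 = torch.zeros(n_res, 37, dtype=torch.long)  # index map
# Expanding atom14 back to the atom37 layout is a gather along the atom dim
atom37_positions = torch.gather(
    atom14_positions,
    1,
    residx_atom37_to_atom14.unsqueeze(-1).expand(-1, -1, 3),
)
assert atom37_positions.shape == (n_res, 37, 3)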
    @compare_utils.skip_unless_alphafold_installed()
    def test_tm_loss_compare(self):
        config = compare_utils.get_alphafold_config()
        c_tm = config.model.heads.predicted_aligned_error

        def run_tm_loss(representations, batch, value):
            head = alphafold.model.modules.PredictedAlignedErrorHead(
                c_tm, config.model.global_config

...

@@ -792,58 +804,58 @@ class TestLoss(unittest.TestCase):
            v.update(value)
            v["predicted_aligned_error"] = head(representations, batch, False)
            return head.loss(v, batch)["loss"]

        f = hk.transform(run_tm_loss)

        n_res = consts.n_res

        representations = {
            "pair":
                np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }
        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                np.float32
            ),
-            "resolution": np.array(1.).astype(np.float32),
+            "resolution": np.array(1.0).astype(np.float32),
        }
        value = {
            "structure_module": {
                "final_affines": random_affines_vector((n_res,)),
            }
        }

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/predicted_aligned_error_head"
        )

        out_gt = f.apply(params, None, representations, batch, value)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda n: torch.tensor(n).cuda()
        representations = tree_map(to_tensor, representations, np.ndarray)
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)

-        batch["backbone_affine_tensor"] = (
-            affine_vector_to_4x4(
-                batch["backbone_affine_tensor"]
-            )
-        )
+        batch["backbone_affine_tensor"] = affine_vector_to_4x4(
+            batch["backbone_affine_tensor"]
+        )
-        value["structure_module"]["final_affines"] = (
-            affine_vector_to_4x4(
-                value["structure_module"]["final_affines"]
-            )
-        )
+        value["structure_module"]["final_affines"] = affine_vector_to_4x4(
+            value["structure_module"]["final_affines"]
+        )

        model = compare_utils.get_global_pretrained_openfold()
        logits = model.aux_heads.tm(representations["pair"])

        out_repro = tm_loss(
            logits=logits,
            final_affine_tensor=value["structure_module"]["final_affines"],
            **{**batch, **c_tm},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
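The `**{**batch, **c_tm}` splat above is how these tests feed an OpenFold loss: the losses take both features and config entries as keyword arguments, so merging the two dicts and unpacking them covers every parameter in one call, with unused keys absorbed by **kwargs. A toy version of the idiom (the function and keys here are illustrative only):

def toy_loss(logits, resolution, max_bin=31, **unused_kwargs):
    # Real losses tolerate extra keys the same way, via **kwargs
    return logits * resolution + max_bin

batch = {"resolution": 2.0}
config = {"max_bin": 63}
# Equivalent to toy_loss(logits=1.0, resolution=2.0, max_bin=63)
out = toy_loss(logits=1.0, **{**batch, **config})
assert out == 65.0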
...

tests/test_model.py
View file @ 07e64267

...

@@ -29,7 +29,7 @@ from tests.data_utils import (
    random_extra_msa_feats,
)

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk

...

@@ -43,36 +43,29 @@ class TestModel(unittest.TestCase):

        n_extra_seq = consts.n_extra

        c = model_config("model_1").model
        c.no_cycles = 2
        c.evoformer_stack.no_blocks = 4  # no need to go overboard here
        c.evoformer_stack.blocks_per_ckpt = None  # don't want to set up
        # deepspeed for this test
        model = AlphaFold(c)

        batch = {}
        tf = torch.randint(
            c.input_embedder.tf_dim - 1, size=(n_res,)
        )
        batch["target_feat"] = nn.functional.one_hot(
            tf, c.input_embedder.tf_dim).float()
        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
        batch["residue_index"] = torch.arange(n_res)
-        batch["msa_feat"] = torch.rand(
-            (n_seq, n_res, c.input_embedder.msa_dim)
-        )
+        batch["msa_feat"] = torch.rand((n_seq, n_res, c.input_embedder.msa_dim))
        t_feats = random_template_feats(n_templ, n_res)
        batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
        extra_feats = random_extra_msa_feats(
            n_extra_seq, n_res
        )
        batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
        batch["msa_mask"] = torch.randint(
            low=0, high=2, size=(n_seq, n_res)
        ).float()
        batch["seq_mask"] = torch.randint(
            low=0, high=2, size=(n_res,)
        ).float()
        batch.update(make_atom14_masks(batch))

        add_recycling_dims = lambda t: (

...

@@ -80,7 +73,7 @@ class TestModel(unittest.TestCase):

        )
        batch = tensor_tree_map(add_recycling_dims, batch)

        with torch.no_grad():
            out = model(batch)

    @compare_utils.skip_unless_alphafold_installed()
...

@@ -89,12 +82,14 @@ class TestModel(unittest.TestCase):

        config = compare_utils.get_alphafold_config()
        model = alphafold.model.modules.AlphaFold(config.model)
        return model(
            batch=batch,
            is_training=False,
            return_representations=True,
        )

        f = hk.transform(run_alphafold)

-        params = compare_utils.fetch_alphafold_module_weights('')
+        params = compare_utils.fetch_alphafold_module_weights("")

        with open("tests/test_data/sample_feats.pickle", "rb") as fp:
            batch = pickle.load(fp)
...

@@ -107,14 +102,14 @@ class TestModel(unittest.TestCase):

        batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]

        out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
        out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))

-        batch = {
-            k: torch.as_tensor(v).cuda() for k, v in batch.items()}
+        batch = {
+            k: torch.as_tensor(v).cuda() for k, v in batch.items()
+        }
        batch["aatype"] = batch["aatype"].long()
        batch["template_aatype"] = batch["template_aatype"].long()
        batch["extra_msa"] = batch["extra_msa"].long()
        batch["residx_atom37_to_atom14"] = batch["residx_atom37_to_atom14"].long()

        # Move the recycling dimension to the end
        move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
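In isolation, `move_dim` cycles the leading recycling dimension to the end: AlphaFold stacks per-recycling-iteration outputs first, while the OpenFold outputs being compared keep that axis last. A quick standalone check of the permute:

import torch

move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)

t = torch.zeros(4, 7, 9)               # (n_recycle, n_res, c)
assert move_dim(t).shape == (7, 9, 4)  # recycling axis now trails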
...

@@ -130,4 +125,3 @@ class TestModel(unittest.TestCase):

        out_repro = out_repro.squeeze(0)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 1e-3))
tests/test_msa.py
View file @ 07e64267

...

@@ -24,14 +24,14 @@ from openfold.utils.tensor_utils import tree_map

import tests.compare_utils as compare_utils
from tests.config import consts

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestMSARowAttentionWithPairBias(unittest.TestCase):
    def test_shape(self):
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        n_res = consts.n_res

...

@@ -39,7 +39,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):

        c_z = consts.c_z
        c = 52
        no_heads = 4
        chunk_size = None

        mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads, chunk_size)

...

@@ -58,29 +58,26 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            msa_row = alphafold.model.modules.MSARowAttentionWithPairBias(
                c_e.msa_row_attention_with_pair_bias,
                config.model.global_config
            )
            act = msa_row(msa_act=msa_act, msa_mask=msa_mask, pair_act=pair_act)
            return act

        f = hk.transform(run_msa_row_att)

        n_res = consts.n_res
        n_seq = consts.n_seq

        msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
        msa_mask = np.random.randint(
            low=0, high=2, size=(n_seq, n_res)
        ).astype(np.float32)
        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)

        # Fetch pretrained parameters (but only from one block)]
        params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
-            "msa_row_attention"
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + "msa_row_attention"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

...

@@ -90,17 +87,21 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):

        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.evoformer.blocks[0].msa_att_row(
-            torch.as_tensor(msa_act).cuda(),
-            torch.as_tensor(pair_act).cuda(),
-            torch.as_tensor(msa_mask).cuda(),
-        ).cpu()
+        out_repro = (
+            model.evoformer.blocks[0]
+            .msa_att_row(
+                torch.as_tensor(msa_act).cuda(),
+                torch.as_tensor(pair_act).cuda(),
+                torch.as_tensor(msa_mask).cuda(),
+            )
+            .cpu()
+        )

        self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
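The `tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)` step reflects how AlphaFold stores repeated blocks: the evoformer iteration's weights are stacked along a leading per-block axis, so indexing each leaf with `[0]` keeps only the first block's parameters, matching `model.evoformer.blocks[0]` on the OpenFold side. An illustrative miniature (the name and shapes are invented):

import numpy as np
from openfold.utils.tensor_utils import tree_map

stacked = {"attention": {"query_w": np.zeros((48, 256, 256))}}  # 48 blocks
first_block = tree_map(lambda n: n[0], stacked, np.ndarray)
assert first_block["attention"]["query_w"].shape == (256, 256)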
class TestMSAColumnAttention(unittest.TestCase):
    def test_shape(self):
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        n_res = consts.n_res

...

@@ -124,47 +125,46 @@ class TestMSAColumnAttention(unittest.TestCase):

            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            msa_col = alphafold.model.modules.MSAColumnAttention(
                c_e.msa_column_attention,
                config.model.global_config
            )
            act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
            return act

        f = hk.transform(run_msa_col_att)

        n_res = consts.n_res
        n_seq = consts.n_seq

        msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
        msa_mask = np.random.randint(
            low=0, high=2, size=(n_seq, n_res)
        ).astype(np.float32)

        # Fetch pretrained parameters (but only from one block)]
        params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
-            "msa_column_attention"
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + "msa_column_attention"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(
            params, None, msa_act, msa_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.evoformer.blocks[0].msa_att_col(
-            torch.as_tensor(msa_act).cuda(),
-            torch.as_tensor(msa_mask).cuda(),
-        ).cpu()
+        out_repro = (
+            model.evoformer.blocks[0]
+            .msa_att_col(
+                torch.as_tensor(msa_act).cuda(),
+                torch.as_tensor(msa_mask).cuda(),
+            )
+            .cpu()
+        )

        self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))


class TestMSAColumnGlobalAttention(unittest.TestCase):
    def test_shape(self):
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        n_res = consts.n_res

...

@@ -188,40 +188,42 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            msa_col = alphafold.model.modules.MSAColumnGlobalAttention(
                c_e.msa_column_attention,
                config.model.global_config,
-                name="msa_column_global_attention"
+                name="msa_column_global_attention",
            )
            act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
            return act

        f = hk.transform(run_msa_col_global_att)

        n_res = consts.n_res
        n_seq = consts.n_seq
        c_e = consts.c_e

        msa_act = np.random.rand(n_seq, n_res, c_e)
        msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res))

        # Fetch pretrained parameters (but only from one block)]
        params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/extra_msa_stack/" +
-            "msa_column_global_attention"
+            "alphafold/alphafold_iteration/evoformer/extra_msa_stack/"
+            + "msa_column_global_attention"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(
            params, None, msa_act, msa_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))

        model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.extra_msa_stack.stack.blocks[0].msa_att_col(
-            torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
-            mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
-        ).cpu()
+        out_repro = (
+            model.extra_msa_stack.stack.blocks[0]
+            .msa_att_col(
+                torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
+                mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
+            )
+            .cpu()
+        )

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))

...
tests/test_outer_product_mean.py
View file @ 07e64267

...

@@ -19,7 +19,8 @@ from openfold.model.outer_product_mean import OuterProductMean

from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
...

@@ -40,51 +41,54 @@ class TestOuterProductMean(unittest.TestCase):

        m = opm(m, mask)

        self.assertTrue(
            m.shape == (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_opm_compare(self):
        def run_opm(msa_act, msa_mask):
            config = compare_utils.get_alphafold_config()
            c_evo = config.model.embeddings_and_evoformer.evoformer
            opm = alphafold.model.modules.OuterProductMean(
                c_evo.outer_product_mean,
                config.model.global_config,
                consts.c_z,
            )
            act = opm(act=msa_act, mask=msa_mask)
            return act

        f = hk.transform(run_opm)

        n_res = consts.n_res
        n_seq = consts.n_seq
        c_m = consts.c_m

        msa_act = np.random.rand(n_seq, n_res, c_m).astype(np.float32) * 100
        msa_mask = np.random.randint(
            low=0, high=2, size=(n_seq, n_res)
        ).astype(np.float32)

        # Fetch pretrained parameters (but only from one block)]
        params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/" +
-            "evoformer_iteration/outer_product_mean"
+            "alphafold/alphafold_iteration/evoformer/"
+            + "evoformer_iteration/outer_product_mean"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(
            params, None, msa_act, msa_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.evoformer.blocks[0].outer_product_mean(
-            torch.as_tensor(msa_act).cuda(),
-            mask=torch.as_tensor(msa_mask).cuda(),
-        ).cpu()
+        out_repro = (
+            model.evoformer.blocks[0]
+            .outer_product_mean(
+                torch.as_tensor(msa_act).cuda(),
+                mask=torch.as_tensor(msa_mask).cuda(),
+            )
+            .cpu()
+        )

        # Even when correct, OPM has large, precision-related errors. It gets
        # a special pass from consts.eps.
        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 5e-4))
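The looser 5e-4 bound is plausible on numerical grounds: the outer-product mean reduces over the whole sequence dimension, and with float32 operands scaled by 100 (as above), long accumulations lose several digits relative to consts.eps. A standalone illustration of the effect, using a naive left-to-right float32 running sum:

import numpy as np

x = np.random.rand(10_000).astype(np.float32) * 100
acc = np.float32(0.0)
for v in x:  # naive float32 accumulation, one addend at a time
    acc += v
exact = np.sum(x.astype(np.float64))
# The relative error typically lands far above float32's ~1.2e-7 epsilon
print(abs(float(acc) - exact) / exact)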
...

tests/test_pair_transition.py
View file @ 07e64267

...

@@ -20,14 +20,14 @@ from openfold.utils.tensor_utils import tree_map

import tests.compare_utils as compare_utils
from tests.config import consts

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestPairTransition(unittest.TestCase):
    def test_shape(self):
        c_z = consts.c_z
        n = 4

...

@@ -50,42 +50,42 @@ class TestPairTransition(unittest.TestCase):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            pt = alphafold.model.modules.Transition(
                c_e.pair_transition,
                config.model.global_config,
-                name="pair_transition"
+                name="pair_transition",
            )
            act = pt(act=pair_act, mask=pair_mask)
            return act

        f = hk.transform(run_pair_transition)

        n_res = consts.n_res

        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        pair_mask = np.ones((n_res, n_res)).astype(np.float32)  # no mask

        # Fetch pretrained parameters (but only from one block)]
        params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
-            "pair_transition"
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + "pair_transition"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(
            params, None, pair_act, pair_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))

        model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.evoformer.blocks[0].pair_transition(
-            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
-            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
-        ).cpu()
+        out_repro = (
+            model.evoformer.blocks[0]
+            .pair_transition(
+                torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
+                mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
+            )
+            .cpu()
+        )

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))


if __name__ == "__main__":
    unittest.main()
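Every comparison test in these files reduces to the same harness: build random inputs, run the Haiku ground truth, run the OpenFold module on CUDA, and check elementwise agreement within a tolerance. The final check, distilled into a standalone helper (names illustrative):

import numpy as np

def outputs_match(out_gt, out_repro, eps=5e-5):
    # Elementwise agreement within eps, mirroring the assertTrue checks above
    return float(np.max(np.abs(out_gt - out_repro))) < eps

assert outputs_match(np.ones(3), np.ones(3) + 1e-6)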
tests/test_structure_module.py
View file @ 07e64267

...

@@ -23,7 +23,7 @@ from openfold.np.residue_constants import (

    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
    restype_atom37_mask,
)
from openfold.model.structure_module import (
    StructureModule,
    StructureModuleTransition,

...

@@ -39,7 +39,7 @@ from tests.data_utils import (

    random_affines_4x4,
)

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk

...

@@ -89,9 +89,7 @@ class TestStructureModule(unittest.TestCase):

        out = sm(s, z, f)

        self.assertTrue(
            out["frames"].shape == (no_layers, batch_size, n, 4, 4)
        )
        self.assertTrue(
            out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
        )

...

@@ -121,78 +119,70 @@ class TestStructureModule(unittest.TestCase):
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module
        c_global = config.model.global_config

        def run_sm(representations, batch):
            sm = alphafold.model.folding.StructureModule(c_sm, c_global)
            representations = {
                k: jax.lax.stop_gradient(v) for k, v in representations.items()
            }
-            batch = {
-                k: jax.lax.stop_gradient(v) for k, v in batch.items()
-            }
+            batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
            return sm(representations, batch, is_training=False)

        f = hk.transform(run_sm)

        n_res = 200

        representations = {
-            'single': np.random.rand(n_res, consts.c_s).astype(np.float32),
-            'pair':
+            "single": np.random.rand(n_res, consts.c_s).astype(np.float32),
+            "pair":
                np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }
        batch = {
-            'seq_mask': np.random.randint(0, 2, (n_res,)).astype(np.float32),
-            'aatype': np.random.randint(0, 21, (n_res,)),
+            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
+            "aatype": np.random.randint(0, 21, (n_res,)),
        }
-        batch['atom14_atom_exists'] = np.take(
+        batch["atom14_atom_exists"] = np.take(
            restype_atom14_mask,
-            batch['aatype'], axis=0
+            batch["aatype"], axis=0
        )
-        batch['atom37_atom_exists'] = np.take(
+        batch["atom37_atom_exists"] = np.take(
            restype_atom37_mask,
-            batch['aatype'], axis=0
+            batch["aatype"], axis=0
        )
        batch.update(make_atom14_masks_np(batch))

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/structure_module"
        )

        key = jax.random.PRNGKey(42)
        out_gt = f.apply(params, key, representations, batch)
        out_gt = torch.as_tensor(
            np.array(out_gt["final_atom14_positions"].block_until_ready())
        )

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = model.structure_module(
            torch.as_tensor(representations["single"]).cuda(),
            torch.as_tensor(representations["pair"]).cuda(),
            torch.as_tensor(batch["aatype"]).cuda(),
            mask=torch.as_tensor(batch["seq_mask"]).cuda(),
        )

        out_repro = out_repro["positions"][-1].cpu()

        # The structure module, thanks to angle normalization, is very volatile
        # We only assess the mean here. Heuristically speaking, it seems to
        # have lower error in general on real rather than synthetic data.
        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.01)


class TestBackboneUpdate(unittest.TestCase):
    def test_shape(self):
        batch_size = 2
        n_res = 3
        c_in = 5

        bu = BackboneUpdate(c_in)

        s = torch.rand((batch_size, n_res, c_in))

...

@@ -237,25 +227,25 @@ class TestInvariantPointAttention(unittest.TestCase):
    @compare_utils.skip_unless_alphafold_installed()
    def test_ipa_compare(self):
        def run_ipa(act, static_feat_2d, mask, affine):
            config = compare_utils.get_alphafold_config()
            ipa = alphafold.model.folding.InvariantPointAttention(
                config.model.heads.structure_module,
                config.model.global_config,
            )
            attn = ipa(
                inputs_1d=act,
                inputs_2d=static_feat_2d,
                mask=mask,
-                affine=affine
+                affine=affine,
            )
            return attn

        f = hk.transform(run_ipa)

        n_res = consts.n_res
        c_s = consts.c_s
        c_z = consts.c_z

        sample_act = np.random.rand(n_res, c_s)
        sample_2d = np.random.rand(n_res, n_res, c_z)
        sample_mask = np.ones((n_res, 1))

...

@@ -263,15 +253,13 @@ class TestInvariantPointAttention(unittest.TestCase):
        affines = random_affines_4x4((n_res,))
        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
        quats = alphafold.model.r3.rigids_to_quataffine(rigids)
-        transformations = T.from_4x4(
-            torch.as_tensor(affines).float().cuda())
+        transformations = T.from_4x4(
+            torch.as_tensor(affines).float().cuda()
+        )
        sample_affine = quats

        ipa_params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/structure_module/" +
-            "fold_iteration/invariant_point_attention"
+            "alphafold/alphafold_iteration/structure_module/"
+            + "fold_iteration/invariant_point_attention"
        )

        out_gt = f.apply(

...

@@ -282,17 +270,17 @@ class TestInvariantPointAttention(unittest.TestCase):
        with torch.no_grad():
            model = compare_utils.get_global_pretrained_openfold()
            out_repro = model.structure_module.ipa(
                torch.as_tensor(sample_act).float().cuda(),
                torch.as_tensor(sample_2d).float().cuda(),
                transformations,
                torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
            ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


class TestAngleResnet(unittest.TestCase):
    def test_shape(self):
        batch_size = 2
        n = 3
        c_s = 13

...

@@ -300,7 +288,7 @@ class TestAngleResnet(unittest.TestCase):
        no_layers = 5
        no_angles = 7
        epsilon = 1e-12

        ar = AngleResnet(c_s, c_hidden, no_layers, no_angles, epsilon)

        a = torch.rand((batch_size, n, c_s))
        a_initial = torch.rand((batch_size, n, c_s))

...

tests/test_template.py
View file @ 07e64267

...

@@ -24,14 +24,14 @@ import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestTemplatePointwiseAttention(unittest.TestCase):
    def test_shape(self):
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        c_t = consts.c_t

...

@@ -40,7 +40,7 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
        no_heads = 13
        n_res = consts.n_res
        inf = 1e7

        tpa = TemplatePointwiseAttention(
            c_t, c_z, c, no_heads, chunk_size=4, inf=inf
        )

...

@@ -67,8 +67,8 @@ class TestTemplatePairStack(unittest.TestCase):

        n_res = consts.n_res
        blocks_per_ckpt = None
        chunk_size = 4
        inf = 1e7
        eps = 1e-7

        tpe = TemplatePairStack(
            c_t,

...

@@ -98,45 +98,47 @@ class TestTemplatePairStack(unittest.TestCase):
            config = compare_utils.get_alphafold_config()
            c_ee = config.model.embeddings_and_evoformer
            tps = alphafold.model.modules.TemplatePairStack(
                c_ee.template.template_pair_stack,
                config.model.global_config,
-                name="template_pair_stack"
+                name="template_pair_stack",
            )
            act = tps(pair_act, pair_mask, is_training=False)
            ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
            act = ln(act)
            return act

        f = hk.transform(run_template_pair_stack)

        n_res = consts.n_res

        pair_act = np.random.rand(n_res, n_res, consts.c_t).astype(np.float32)
        pair_mask = np.random.randint(
            low=0, high=2, size=(n_res, n_res)
        ).astype(np.float32)

        params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/template_embedding/" +
-            "single_template_embedding/template_pair_stack"
+            "alphafold/alphafold_iteration/evoformer/template_embedding/"
+            + "single_template_embedding/template_pair_stack"
        )
-        params.update(compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/template_embedding/" +
-            "single_template_embedding/output_layer_norm"
-        ))
+        params.update(
+            compare_utils.fetch_alphafold_module_weights(
+                "alphafold/alphafold_iteration/evoformer/template_embedding/"
+                + "single_template_embedding/output_layer_norm"
+            )
+        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, pair_mask
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = model.template_pair_stack(
            torch.as_tensor(pair_act).cuda(),
            torch.as_tensor(pair_mask).cuda(),
            _mask_trans=False,
        ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
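A small but load-bearing detail in all of these comparisons is the JAX-to-PyTorch bridge: JAX dispatches work asynchronously, so `block_until_ready()` forces the device computation to finish, `np.array(...)` copies the result to host memory, and `torch.as_tensor(...)` wraps it for the PyTorch side. The same bridge on a toy computation, assuming only that jax and torch are installed:

import jax.numpy as jnp
import numpy as np
import torch

out = jnp.ones((2, 2)) * 3.0            # dispatched asynchronously by JAX
out = out.block_until_ready()           # wait for the device computation
out_t = torch.as_tensor(np.array(out))  # host copy, then a torch tensor
assert out_t.shape == (2, 2) and float(out_t[0, 0]) == 3.0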
...

@@ -146,46 +148,46 @@ class Template(unittest.TestCase):
        def test_template_embedding(pair, batch, mask_2d):
            config = compare_utils.get_alphafold_config()
            te = alphafold.model.modules.TemplateEmbedding(
                config.model.embeddings_and_evoformer.template,
-                config.model.global_config
+                config.model.global_config,
            )
            act = te(pair, batch, mask_2d, is_training=False)
            return act

        f = hk.transform(test_template_embedding)

        n_res = consts.n_res
        n_templ = consts.n_templ

        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        batch = random_template_feats(n_templ, n_res)
        pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)

        # Fetch pretrained parameters (but only from one block)]
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/template_embedding"
        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, batch, pair_mask
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        inds = np.random.randint(0, 21, (n_res,))
        batch["target_feat"] = np.eye(22)[inds]

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = model.embed_templates(
            {k: torch.as_tensor(v).cuda() for k, v in batch.items()},
            torch.as_tensor(pair_act).cuda(),
            torch.as_tensor(pair_mask).cuda(),
            templ_dim=0,
        )

        out_repro = out_repro["template_pair_embedding"]
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
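The `np.eye(22)[inds]` line above is a compact one-hot encoder: row i of the identity matrix is the one-hot vector for class i, so fancy-indexing with `inds` yields one row per residue (22 classes here, matching the model's target_feat width). Standalone:

import numpy as np

inds = np.array([3, 0, 21])
one_hot = np.eye(22)[inds]
assert one_hot.shape == (3, 22)
assert one_hot[0, 3] == 1.0 and one_hot[0].sum() == 1.0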
if __name__ == "__main__":
    unittest.main()

tests/test_triangular_attention.py
View file @ 07e64267

...

@@ -21,7 +21,7 @@ from openfold.utils.tensor_utils import tree_map

import tests.compare_utils as compare_utils
from tests.config import consts

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk

...

@@ -34,12 +34,7 @@ class TestTriangularAttention(unittest.TestCase):
        no_heads = 4
        starting = True

        tan = TriangleAttention(c_z, c, no_heads, starting)

        batch_size = consts.batch_size
        n_res = consts.n_res

...

@@ -53,22 +48,24 @@ class TestTriangularAttention(unittest.TestCase):
    def _tri_att_compare(self, starting=False):
-        name = (
-            "triangle_attention_" +
-            ("starting" if starting else "ending") +
-            "_node"
-        )
+        name = (
+            "triangle_attention_"
+            + ("starting" if starting else "ending")
+            + "_node"
+        )

        def run_tri_att(pair_act, pair_mask):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            tri_att = alphafold.model.modules.TriangleAttention(
-                c_e.triangle_attention_starting_node if starting else
-                c_e.triangle_attention_ending_node,
+                c_e.triangle_attention_starting_node
+                if starting
+                else c_e.triangle_attention_ending_node,
                config.model.global_config,
                name=name,
            )
            act = tri_att(pair_act=pair_act, pair_mask=pair_mask)
            return act

        f = hk.transform(run_tri_att)

        n_res = consts.n_res

...

@@ -78,24 +75,23 @@ class TestTriangularAttention(unittest.TestCase):
        # Fetch pretrained parameters (but only from one block)]
        params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
-            name
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + name
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(
            params, None, pair_act, pair_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        module = (
-            model.evoformer.blocks[0].tri_att_start if starting else
-            model.evoformer.blocks[0].tri_att_end
+            model.evoformer.blocks[0].tri_att_start
+            if starting
+            else model.evoformer.blocks[0].tri_att_end
        )
        out_repro = module(
            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
        ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))

...

@@ -110,4 +106,4 @@ class TestTriangularAttention(unittest.TestCase):
if __name__ == "__main__":
    unittest.main()

tests/test_triangular_multiplicative_update.py
View file @ 07e64267

...

@@ -20,14 +20,14 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
class TestTriangularMultiplicativeUpdate(unittest.TestCase):
    def test_shape(self):
        c_z = consts.c_z
        c = 11
        outgoing = True
...
@@ -50,22 +50,23 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
        self.assertTrue(shape_before == shape_after)

    def _tri_mul_compare(self, incoming=False):
-        name = (
-            "triangle_multiplication_" +
-            ("incoming" if incoming else "outgoing")
-        )
+        name = "triangle_multiplication_" + (
+            "incoming" if incoming else "outgoing"
+        )

        def run_tri_mul(pair_act, pair_mask):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            tri_mul = alphafold.model.modules.TriangleMultiplication(
-                c_e.triangle_multiplication_incoming if incoming else
-                c_e.triangle_multiplication_outgoing,
+                c_e.triangle_multiplication_incoming
+                if incoming
+                else c_e.triangle_multiplication_outgoing,
                config.model.global_config,
                name=name,
            )
-            act = tri_mul(
-                act=pair_act, mask=pair_mask)
+            act = tri_mul(act=pair_act, mask=pair_mask)
            return act

        f = hk.transform(run_tri_mul)

        n_res = consts.n_res
...
@@ -76,24 +77,23 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
-            name
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + name
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(
            params, None, pair_act, pair_mask
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        module = (
-            model.evoformer.blocks[0].tri_mul_in if incoming else
-            model.evoformer.blocks[0].tri_mul_out
+            model.evoformer.blocks[0].tri_mul_in
+            if incoming
+            else model.evoformer.blocks[0].tri_mul_out
        )
        out_repro = module(
            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
        ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
...
@@ -109,4 +109,3 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
if __name__ == "__main__":
    unittest.main()
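One subtlety in the tolerance assertions above: torch.max over a boolean tensor reduces with logical OR, so torch.max(torch.abs(a - b) < eps) is true as soon as any single element is within tolerance. If the intent is that all elements agree, the reduction has to happen before the comparison (or use .all()). A small sketch of the difference:

import torch

a = torch.tensor([0.0, 0.0])
b = torch.tensor([0.0, 1.0])  # second element is far off
eps = 1e-4

any_close = torch.max(torch.abs(a - b) < eps)  # tensor(True): ANY element close
all_close = torch.max(torch.abs(a - b)) < eps  # tensor(False): max error checked
assert bool(any_close) and not bool(all_close)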
tests/test_utils.py
View file @
07e64267
...
@@ -20,17 +20,21 @@ from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.tensor_utils import chunk_layer

-X_90_ROT = torch.tensor([
-    [1, 0, 0],
-    [0, 0, -1],
-    [0, 1, 0],
-])
+X_90_ROT = torch.tensor(
+    [
+        [1, 0, 0],
+        [0, 0, -1],
+        [0, 1, 0],
+    ]
+)

-X_NEG_90_ROT = torch.tensor([
-    [1, 0, 0],
-    [0, 0, 1],
-    [0, -1, 0],
-])
+X_NEG_90_ROT = torch.tensor(
+    [
+        [1, 0, 0],
+        [0, 0, 1],
+        [0, -1, 0],
+    ]
+)
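These constants are the +90 and -90 degree rotations about the x-axis, so they are exact inverses of each other. A quick sanity check, written here as a standalone sketch rather than as part of the test file:

import torch

X_90_ROT = torch.tensor(
    [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]]
)
X_NEG_90_ROT = torch.tensor(
    [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, -1.0, 0.0]]
)

# Proper rotations: orthogonal with determinant +1.
assert torch.allclose(X_90_ROT @ X_90_ROT.T, torch.eye(3))
assert torch.isclose(torch.det(X_90_ROT), torch.tensor(1.0))

# Rotating by +90 and then -90 degrees about x is the identity.
assert torch.allclose(X_90_ROT @ X_NEG_90_ROT, torch.eye(3))

# +90 degrees about x sends the y-axis to the z-axis.
y, z = torch.tensor([0.0, 1.0, 0.0]), torch.tensor([0.0, 0.0, 1.0])
assert torch.allclose(X_90_ROT @ y, z)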
class TestAffineT(unittest.TestCase):
...
@@ -53,7 +57,7 @@ class TestAffineT(unittest.TestCase):
        batch_size = 2
        transf = [
            [1, 0, 0, 1],
            [0, 0, -1, 2],
            [0, 1, 0, 3],
            [0, 0, 0, 1],
        ]
...
@@ -62,10 +66,7 @@ class TestAffineT(unittest.TestCase):
        true_rot = transf[:3, :3]
        true_trans = transf[:3, 3]
-        transf = torch.stack(
-            [transf for _ in range(batch_size)], dim=0
-        )
+        transf = torch.stack([transf for _ in range(batch_size)], dim=0)

        t = T.from_4x4(transf)
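The slicing in this hunk relies on the standard layout of a 4x4 homogeneous transform: the upper-left 3x3 block is the rotation and the first three entries of the last column are the translation. A small illustrative sketch, independent of the T class:

import torch

transf = torch.tensor(
    [
        [1.0, 0.0, 0.0, 1.0],
        [0.0, 0.0, -1.0, 2.0],
        [0.0, 1.0, 0.0, 3.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)

rot = transf[:3, :3]   # rotation block
trans = transf[:3, 3]  # translation vector

# Applying the transform to a point x is rot @ x + trans.
x = torch.tensor([1.0, 2.0, 3.0])
y = rot @ x + trans

# Same result via the homogeneous form [x, 1].
x_h = torch.cat([x, torch.ones(1)])
assert torch.allclose((transf @ x_h)[:3], y)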
...
@@ -78,8 +79,7 @@ class TestAffineT(unittest.TestCase):
        batch_size = 2
        n = 5

        transf = T(
            torch.rand((batch_size, n, 3, 3)),
            torch.rand((batch_size, n, 3))
        )

        self.assertTrue(transf.shape == (batch_size, n))
...
@@ -88,12 +88,11 @@ class TestAffineT(unittest.TestCase):
        batch_size = 2
        n = 5

        transf = T(
            torch.rand((batch_size, n, 3, 3)),
            torch.rand((batch_size, n, 3))
        )

        transf_concat = T.concat([transf, transf], dim=0)
        self.assertTrue(transf_concat.rots.shape == (batch_size * 2, n, 3, 3))

        transf_concat = T.concat([transf, transf], dim=1)
...
@@ -124,7 +123,7 @@ class TestAffineT(unittest.TestCase):
        x = torch.arange(30)
        x = torch.stack([x, x], dim=0)
        x = x.view(2, -1, 3)  # [2, 10, 3]

        pts = t[..., None].apply(x)
...
@@ -165,4 +164,4 @@ class TestAffineT(unittest.TestCase):
        self.assertTrue(torch.all(chunked["out"] == unchunked["out"]))
        self.assertTrue(
            torch.all(chunked["inner"]["out"] == unchunked["inner"]["out"])
        )
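The chunked/unchunked comparison above exercises chunk_layer, which evaluates a layer over slices of its inputs and stitches the outputs back together so that, up to floating point, the result matches a single full-batch call. The exact OpenFold signature is not reproduced here; chunk_layer_sketch below is a simplified stand-in that conveys the idea:

import torch

def chunk_layer_sketch(layer, x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Run `layer` over chunks of the leading dimension and concatenate.
    outs = [
        layer(x[i : i + chunk_size]) for i in range(0, x.shape[0], chunk_size)
    ]
    return torch.cat(outs, dim=0)

layer = lambda t: t * 2 + 1
x = torch.arange(30, dtype=torch.float32).view(10, 3)

chunked = chunk_layer_sketch(layer, x, chunk_size=4)
unchunked = layer(x)
assert torch.all(chunked == unchunked)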