OpenDAS / OpenFold · Commits

Commit 07e64267, authored Oct 16, 2021 by Gustaf Ahdritz

    Standardize code style

Parent: de07730f
Changes: 60 files in total; showing 20 changed files with 1350 additions and 1389 deletions (+1350 / -1389).
openfold/utils/import_weights.py                 +105  -113
openfold/utils/loss.py                           +414  -471
openfold/utils/tensor_utils.py                    +55   -52
tests/compare_utils.py                            +21   -22
tests/config.py                                   +17   -15
tests/data_utils.py                               +23   -17
tests/test_embedders.py                           +12   -17
tests/test_evoformer.py                           +65   -47
tests/test_feats.py                               +77   -82
tests/test_import_weights.py                      +29   -17
tests/test_loss.py                               +255  -243
tests/test_model.py                               +23   -29
tests/test_msa.py                                 +60   -58
tests/test_outer_product_mean.py                  +25   -21
tests/test_pair_transition.py                     +23   -23
tests/test_structure_module.py                    +54   -66
tests/test_template.py                            +34   -32
tests/test_triangular_attention.py                +19   -23
tests/test_triangular_multiplicative_update.py    +17   -18
tests/test_utils.py                               +22   -23
File: openfold/utils/import_weights.py
@@ -35,19 +35,16 @@ class ParamType(Enum):
    LinearMHAOutputWeight = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
    LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1))
    LinearWeightOPM = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
    Other = partial(lambda w: w)

    def __init__(self, fn):
        self.transformation = fn


@dataclass
class Param:
    param: Union[torch.Tensor, List[torch.Tensor]]
@@ -58,16 +55,17 @@ class Param:
def _process_translations_dict(d, top_layer=True):
    flat = {}
    for k, v in d.items():
        if type(v) == dict:
            prefix = _NPZ_KEY_PREFIX if top_layer else ""
            sub_flat = {
                (prefix + "/".join([k, k_prime])): v_prime
                for k_prime, v_prime in _process_translations_dict(
                    v, top_layer=False
                ).items()
            }
            flat.update(sub_flat)
        else:
            k = "/" + k if not top_layer else k
            flat[k] = v

    return flat
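For orientation, a small self-contained sketch of what this flattening produces. The prefix value below is an assumption for illustration; the real `_NPZ_KEY_PREFIX` constant is defined elsewhere in this file. Note the deliberate double slash before parameter names, which matches the key format of the AlphaFold `.npz` parameter files:

```python
# Minimal sketch of the flattening performed by _process_translations_dict.
# The prefix value is assumed for illustration only.
_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"

def flatten(d, top_layer=True):
    flat = {}
    for k, v in d.items():
        if isinstance(v, dict):
            prefix = _NPZ_KEY_PREFIX if top_layer else ""
            for k_prime, v_prime in flatten(v, top_layer=False).items():
                flat[prefix + "/".join([k, k_prime])] = v_prime
        else:
            # Leaves below the top level get a leading "/", which yields the
            # "module//param" double-slash convention after joining.
            flat["/" + k if not top_layer else k] = v
    return flat

nested = {"evoformer": {"msa_transition": {"weights": 1, "bias": 2}}}
print(flatten(nested))
# {'alphafold/alphafold_iteration/evoformer/msa_transition//weights': 1,
#  'alphafold/alphafold_iteration/evoformer/msa_transition//bias': 2}
```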
@@ -82,19 +80,19 @@ def stacked(param_dict_list, out=None):
    "parallel" Params). There must be at least one dict
    in the list.
    """
    if out is None:
        out = {}
    template = param_dict_list[0]
    for k, _ in template.items():
        v = [d[k] for d in param_dict_list]
        if type(v[0]) is dict:
            out[k] = {}
            stacked(v, out=out[k])
        elif type(v[0]) is Param:
            stacked_param = Param(
                param=[param.param for param in v],
                param_type=v[0].param_type,
                stacked=True,
            )

            out[k] = stacked_param
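A rough usage sketch of the merging that `stacked` performs, with a minimal stand-in for the `Param` dataclass above (field defaults here are invented; this is an illustration, not a test):

```python
from dataclasses import dataclass
from typing import Any

# Stand-in for the real Param dataclass, for illustration only.
@dataclass
class Param:
    param: Any
    param_type: str = "LinearWeight"
    stacked: bool = False

def stacked(param_dict_list, out=None):
    # Re-statement of the logic above: walk parallel dicts and merge leaf
    # Params into one Param holding a list of the per-block tensors.
    if out is None:
        out = {}
    for k in param_dict_list[0]:
        v = [d[k] for d in param_dict_list]
        if type(v[0]) is dict:
            out[k] = {}
            stacked(v, out=out[k])
        elif type(v[0]) is Param:
            out[k] = Param(
                param=[p.param for p in v],
                param_type=v[0].param_type,
                stacked=True,
            )
    return out

blocks = [
    {"attn": {"weights": Param("w0"), "bias": Param("b0")}},
    {"attn": {"weights": Param("w1"), "bias": Param("b1")}},
]
merged = stacked(blocks)
print(merged["attn"]["weights"].param)  # ['w0', 'w1'] with stacked=True
```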
@@ -107,7 +105,7 @@ def assign(translation_dict, orig_weights):
    with torch.no_grad():
        weights = torch.as_tensor(orig_weights[k])
        ref, param_type = param.param, param.param_type
        if param.stacked:
            weights = torch.unbind(weights, 0)
        else:
            weights = [weights]
@@ -131,26 +129,15 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    #######################
    # Some templates
    #######################
    LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
    LinearBias = lambda l: (Param(l))
    LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
    LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
    LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))

    LinearParams = lambda l: {
        "weights": LinearWeight(l.weight),
@@ -167,7 +154,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
        "key_w": LinearWeightMHA(att.linear_k.weight),
        "value_w": LinearWeightMHA(att.linear_v.weight),
        "output_w": Param(
            att.linear_o.weight,
            param_type=ParamType.LinearMHAOutputWeight,
        ),
        "output_b": LinearBias(att.linear_o.bias),
    }
@@ -231,7 +219,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    MSAGlobalAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": GlobalAttentionParams(matt.global_attention),
    }

    MSAAttPairBiasParams = lambda matt: dict(
@@ -247,8 +235,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):
        "kv_scalar": LinearParams(ipa.linear_kv),
        "q_point_local": LinearParams(ipa.linear_q_points),
        "kv_point_local": LinearParams(ipa.linear_kv_points),
        "trainable_point_weights": Param(
            param=ipa.head_weights, param_type=ParamType.Other
        ),
        "attention_2d": LinearParams(ipa.linear_b),
        "output_projection": LinearParams(ipa.linear_out),
    }
@@ -276,7 +265,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    }

    def EvoformerBlockParams(b, is_extra_msa=False):
        if is_extra_msa:
            col_att_name = "msa_column_global_attention"
            msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
        else:
@@ -284,8 +273,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):
            msa_col_att_params = MSAAttParams(b.msa_att_col)

        d = {
            "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
                b.msa_att_row
            ),
            col_att_name: msa_col_att_params,
            "msa_transition": MSATransitionParams(b.msa_transition),
            "outer_product_mean": OuterProductMeanParams(b.outer_product_mean),
@@ -316,9 +306,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
            "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
            "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
            "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
        },
    }

    ############################
    # translations dict overflow
@@ -330,14 +319,10 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    )

    ems_blocks = model.extra_msa_stack.stack.blocks
    ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])

    evo_blocks = model.evoformer.blocks
    evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])

    translations = {
        "evoformer": {
@@ -346,64 +331,72 @@ def import_jax_weights_(model, npz_path, version="model_1"):
            "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
            "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
            "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
            "prev_msa_first_row_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_m
            ),
            "prev_pair_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_z
            ),
            "pair_activiations": LinearParams(
                model.input_embedder.linear_relpos
            ),
            "template_embedding": {
                "single_template_embedding": {
                    "embedding2d": LinearParams(
                        model.template_pair_embedder.linear
                    ),
                    "template_pair_stack": {
                        "__layer_stack_no_state": tps_blocks_params,
                    },
                    "output_layer_norm": LayerNormParams(
                        model.template_pair_stack.layer_norm
                    ),
                },
                "attention": AttentionParams(model.template_pointwise_att.mha),
            },
            "extra_msa_activations": LinearParams(
                model.extra_msa_embedder.linear
            ),
            "extra_msa_stack": ems_blocks_params,
            "template_single_embedding": LinearParams(
                model.template_angle_embedder.linear_1
            ),
            "template_projection": LinearParams(
                model.template_angle_embedder.linear_2
            ),
            "evoformer_iteration": evo_blocks_params,
            "single_activations": LinearParams(model.evoformer.linear),
        },
        "structure_module": {
            "single_layer_norm": LayerNormParams(
                model.structure_module.layer_norm_s
            ),
            "initial_projection": LinearParams(
                model.structure_module.linear_in
            ),
            "pair_layer_norm": LayerNormParams(
                model.structure_module.layer_norm_z
            ),
            "fold_iteration": FoldIterationParams(model.structure_module),
        },
        "predicted_lddt_head": {
            "input_layer_norm": LayerNormParams(
                model.aux_heads.plddt.layer_norm
            ),
            "act_0": LinearParams(model.aux_heads.plddt.linear_1),
            "act_1": LinearParams(model.aux_heads.plddt.linear_2),
            "logits": LinearParams(model.aux_heads.plddt.linear_3),
        },
        "distogram_head": {
            "half_logits": LinearParams(model.aux_heads.distogram.linear),
        },
        "experimentally_resolved_head": {
            "logits": LinearParams(
                model.aux_heads.experimentally_resolved.linear
            ),
        },
        "masked_msa_head": {
            "logits": LinearParams(model.aux_heads.masked_msa.linear),
        },
    }
@@ -415,17 +408,16 @@ def import_jax_weights_(model, npz_path, version="model_1"):
        "model_4_ptm",
        "model_5_ptm",
    ]
    if version in no_templ:
        evo_dict = translations["evoformer"]
        keys = list(evo_dict.keys())
        for k in keys:
            if "template_" in k:
                evo_dict.pop(k)

    if "_ptm" in version:
        translations["predicted_aligned_error_head"] = {
            "logits": LinearParams(model.aux_heads.tm.linear)
        }

    # Flatten keys and insert missing key prefixes
@@ -436,10 +428,10 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    flat_keys = list(flat.keys())
    incorrect = [k for k in flat_keys if k not in keys]
    missing = [k for k in keys if k not in flat_keys]
    # print(f"Incorrect: {incorrect}")
    # print(f"Missing: {missing}")

    assert len(incorrect) == 0
    # assert(sorted(list(flat.keys())) == sorted(list(data.keys())))

    # Set weights
File: openfold/utils/loss.py
@@ -81,7 +81,7 @@ def compute_fape(
    positions_mask: torch.Tensor,
    length_scale: float,
    l1_clamp_distance: Optional[float] = None,
    eps=1e-8,
) -> torch.Tensor:
    # [*, N_frames, N_pts, 3]
    local_pred_pos = pred_frames.invert()[..., None].apply(
@@ -91,10 +91,10 @@ def compute_fape(
        target_positions[..., None, :, :],
    )

    error_dist = torch.sqrt(
        torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
    )

    if l1_clamp_distance is not None:
        error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)

    normed_error = error_dist / length_scale
@@ -111,7 +111,9 @@ def compute_fape(
    #
    # ("roughly" because eps is necessarily duplicated in the latter
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = (
        normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
    )
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
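A quick shape-level sketch of the two-stage masked average above, with invented shapes (2 batch elements, 4 frames, 7 points), showing that the reduction yields one FAPE value per batch element:

```python
import torch

# Invented shapes for illustration: [*, N_frames, N_pts] error matrix plus
# frame and point masks, reduced exactly as in compute_fape above.
eps = 1e-8
normed_error = torch.rand(2, 4, 7)   # [*, N_frames, N_pts]
frames_mask = torch.ones(2, 4)       # [*, N_frames]
positions_mask = torch.ones(2, 7)    # [*, N_pts]

x = torch.sum(normed_error, dim=-1)
x = x / (eps + torch.sum(frames_mask, dim=-1))[..., None]
x = torch.sum(x, dim=-1)
x = x / (eps + torch.sum(positions_mask, dim=-1))
print(x.shape)  # torch.Size([2]) -- one FAPE value per batch element
```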
@@ -126,8 +128,8 @@ def backbone_loss(
    backbone_affine_mask: torch.Tensor,
    traj: torch.Tensor,
    use_clamped_fape: Optional[torch.Tensor] = None,
    clamp_distance: float = 10.0,
    loss_unit_distance: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
@@ -145,7 +147,7 @@ def backbone_loss(
        length_scale=loss_unit_distance,
        eps=eps,
    )
    if use_clamped_fape is not None:
        unclamped_fape_loss = compute_fape(
            pred_aff,
            gt_aff[..., None, :],
@@ -158,9 +160,8 @@ def backbone_loss(
            eps=eps,
        )

        fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
            1 - use_clamped_fape
        )

    # Take the mean over the layer dimension
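The blend above is a per-example convex combination: `use_clamped_fape` acts as a 0/1 (or fractional) gate between the clamped and unclamped losses. A scalar sketch with invented values:

```python
import torch

# Invented values: element 0 uses the clamped loss, element 1 the unclamped.
fape_clamped = torch.tensor([1.0, 3.0])
fape_unclamped = torch.tensor([2.0, 5.0])
use_clamped_fape = torch.tensor([1.0, 0.0])

blended = fape_clamped * use_clamped_fape + fape_unclamped * (
    1 - use_clamped_fape
)
print(blended)  # tensor([1., 5.])
```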
@@ -177,42 +178,31 @@ def sidechain_loss(
    renamed_atom14_gt_positions: torch.Tensor,
    renamed_atom14_gt_exists: torch.Tensor,
    alt_naming_is_better: torch.Tensor,
    clamp_distance: float = 10.0,
    length_scale: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    renamed_gt_frames = (
        1.0 - alt_naming_is_better[..., None, None, None]
    ) * rigidgroups_gt_frames + alt_naming_is_better[
        ..., None, None, None
    ] * rigidgroups_alt_gt_frames

    # Steamroll the inputs
    sidechain_frames = sidechain_frames[-1]
    batch_dims = sidechain_frames.shape[:-4]
    sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
    sidechain_frames = T.from_4x4(sidechain_frames)
    renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
    renamed_gt_frames = T.from_4x4(renamed_gt_frames)
    rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
    sidechain_atom_pos = sidechain_atom_pos[-1]
    sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
    renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
        *batch_dims, -1, 3
    )
    renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)

    fape = compute_fape(
        sidechain_frames,
@@ -235,19 +225,17 @@ def fape_loss(
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    bb_loss = backbone_loss(
        traj=out["sm"]["frames"],
        **{**batch, **config.backbone},
    )
    sc_loss = sidechain_loss(
        out["sm"]["sidechain_frames"],
        out["sm"]["positions"],
        **{**batch, **config.sidechain},
    )

    return config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss


def supervised_chi_loss(
@@ -264,7 +252,8 @@ def supervised_chi_loss(
) -> torch.Tensor:
    pred_angles = angles_sin_cos[..., 3:, :]
    residue_type_one_hot = torch.nn.functional.one_hot(
        aatype,
        residue_constants.restype_num + 1,
    )
    chi_pi_periodic = torch.einsum(
        "...ij,jk->ik",
@@ -276,11 +265,9 @@ def supervised_chi_loss(
    shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
    true_chi_shifted = shifted_mask * true_chi
    sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
    sq_chi_error_shifted = torch.sum(
        (true_chi_shifted - pred_angles) ** 2, dim=-1
    )
    sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)

    # The ol' switcheroo
@@ -295,9 +282,9 @@ def supervised_chi_loss(
    loss = loss + chi_weight * sq_chi_loss

    angle_norm = torch.sqrt(
        torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
    )
    norm_error = torch.abs(angle_norm - 1.0)
    norm_error = norm_error.permute(
        *range(len(norm_error.shape))[1:-2], 0, -2, -1
    )
@@ -312,14 +299,13 @@ def supervised_chi_loss(
def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
    num_bins = logits.shape[-1]
    bin_width = 1.0 / num_bins
    bounds = torch.arange(
        start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device
    )
    probs = torch.nn.functional.softmax(logits, dim=-1)
    pred_lddt_ca = torch.sum(
        probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
        dim=-1,
    )
    return pred_lddt_ca * 100
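A self-contained usage sketch of the function above on random logits. The shapes are invented; the function only requires the last dimension to be the bin dimension, and the result is the expectation of per-residue lDDT over uniform bins in [0, 1], scaled to [0, 100]:

```python
import torch

def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
    # Mirrors the function above; bounds broadcasting is simplified since
    # a 1-D bounds tensor broadcasts against probs directly.
    num_bins = logits.shape[-1]
    bin_width = 1.0 / num_bins
    bounds = torch.arange(start=0.5 * bin_width, end=1.0, step=bin_width)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return torch.sum(probs * bounds, dim=-1) * 100

logits = torch.randn(1, 11, 50)  # [batch, n_res, no_bins], shapes invented
plddt = compute_plddt(logits)
print(plddt.shape)  # torch.Size([1, 11]); values near 50 for random logits
```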
@@ -331,7 +317,7 @@ def lddt_loss(
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    resolution: torch.Tensor,
    cutoff: float = 15.0,
    no_bins: int = 50,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
@@ -343,47 +329,49 @@ def lddt_loss(
    ca_pos = residue_constants.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim

    dmat_true = torch.sqrt(
        eps
        + torch.sum(
            (
                all_atom_positions[..., None, :]
                - all_atom_positions[..., None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    dmat_pred = torch.sqrt(
        eps
        + torch.sum(
            (
                all_atom_pred_pos[..., None, :]
                - all_atom_pred_pos[..., None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )
    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * permute_final_dims(all_atom_mask, (1, 0))
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    dist_l1 = torch.abs(dmat_true - dmat_pred)

    score = (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )
    score = score * 0.25

    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=-1))
    score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))

    score = score.detach()
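The 0.25-weighted sum of threshold indicators above is the standard lDDT scoring: the average, over the four tolerances 0.5, 1, 2, and 4 Å, of the fraction of distances preserved within tolerance. A scalar illustration:

```python
import torch

# |d_true - d_pred| for four hypothetical atom pairs, in Angstroms.
dist_l1 = torch.tensor([0.3, 1.5, 3.0, 6.0])
score = sum((dist_l1 < t).float() for t in (0.5, 1.0, 2.0, 4.0)) * 0.25
print(score)  # tensor([1.0000, 0.5000, 0.2500, 0.0000])
```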
@@ -396,14 +384,12 @@ def lddt_loss(
    errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
    all_atom_mask = all_atom_mask.squeeze(-1)
    loss = torch.sum(errors * all_atom_mask, dim=-1) / (
        eps + torch.sum(all_atom_mask, dim=-1)
    )

    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    return loss
@@ -420,16 +406,17 @@ def distogram_loss(
    **kwargs,
):
    boundaries = torch.linspace(
        min_bin,
        max_bin,
        no_bins - 1,
        device=logits.device,
    )
    boundaries = boundaries ** 2

    dists = torch.sum(
        (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
        dim=-1,
        keepdims=True,
    )

    true_bins = torch.sum(dists > boundaries, dim=-1)
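Squaring the boundaries (rather than taking a square root of every pairwise distance) lets the binning compare squared distances directly, which saves a sqrt over the full [*, N, N] matrix. A sketch; the bin range below (2.3125 to 21.6875 Å over 64 bins) is the AlphaFold distogram default as far as I can tell, and is an assumption here:

```python
import torch

# Assumed distogram bin range; counting how many squared boundaries each
# squared distance exceeds gives its bin index without any sqrt.
min_bin, max_bin, no_bins = 2.3125, 21.6875, 64
boundaries = torch.linspace(min_bin, max_bin, no_bins - 1) ** 2

sq_dists = torch.tensor([[4.0], [100.0], [1000.0]])  # kept-dim, invented
true_bins = torch.sum(sq_dists > boundaries, dim=-1)
print(true_bins)  # tensor([ 0, ..., 63]) -- first below range, last above it
```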
@@ -469,7 +456,7 @@ def _calculate_expected_aligned_error(
    bin_centers = _calculate_bin_centers(alignment_confidence_breaks)

    return (
        torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
        bin_centers[-1],
    )
@@ -494,18 +481,16 @@ def compute_predicted_aligned_error(
        max_predicted_aligned_error: [*] the maximum predicted error possible.
    """
    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
    (
        predicted_aligned_error,
        max_predicted_aligned_error,
    ) = _calculate_expected_aligned_error(
        alignment_confidence_breaks=boundaries,
        aligned_distance_error_probs=aligned_confidence_probs,
    )

    return {
@@ -523,14 +508,11 @@ def compute_tm(
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    if residue_weights is None:
        residue_weights = logits.new_ones(logits.shape[-2])

    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    bin_centers = _calculate_bin_centers(boundaries)
@@ -538,11 +520,11 @@ def compute_tm(
    n = logits.shape[-2]
    clipped_n = max(n, 19)

    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8

    probs = torch.nn.functional.softmax(logits, dim=-1)

    tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
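The `d0` term above is the standard TM-score length normalization, d0(N) = 1.24·(N − 15)^(1/3) − 1.8 with N clipped to at least 19, so that the per-bin score 1 / (1 + d²/d0²) is roughly length-independent. A sketch with invented bin centers:

```python
import torch

def d0(n_res: int) -> float:
    # Standard TM-score normalization; N clipped to at least 19 as above.
    clipped_n = max(n_res, 19)
    return 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8

bin_centers = torch.linspace(0.25, 31.25, 63)  # illustrative error bins (A)
tm_per_bin = 1.0 / (1 + bin_centers ** 2 / d0(100) ** 2)
print(round(d0(100), 3))   # ~3.652 for a 100-residue protein
print(tm_per_bin[:3])      # near 1 for small errors, decaying with distance
```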
@@ -573,25 +555,18 @@ def tm_loss(
        return affine.invert()[..., None].apply(pts)

    sq_diff = torch.sum(
        (_points(pred_affine) - _points(backbone_affine)) ** 2, dim=-1
    )

    sq_diff = sq_diff.detach()

    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )
    boundaries = boundaries ** 2
    true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)

    errors = softmax_cross_entropy(
        logits, torch.nn.functional.one_hot(true_bins, no_bins)
    )

    square_mask = (
@@ -606,8 +581,7 @@ def tm_loss(
    loss = loss * scale

    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    return loss
@@ -659,42 +633,36 @@ def between_residue_bond_loss(
    next_n_mask = pred_atom_mask[..., 1:, 0]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]
    has_no_gap_mask = (
        residue_index[..., 1:] - residue_index[..., :-1]
    ) == 1.0

    # Compute loss for the C--N bond.
    c_n_bond_length = torch.sqrt(
        eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
    )

    # The C-N bond to proline has slightly different length because of the ring.
    next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
    gt_length = (
        ~next_is_proline
    ) * residue_constants.between_res_bond_length_c_n[
        0
    ] + next_is_proline * residue_constants.between_res_bond_length_c_n[1]
    gt_stddev = (
        ~next_is_proline
    ) * residue_constants.between_res_bond_length_stddev_c_n[
        0
    ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1]
    c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
    c_n_loss_per_residue = torch.nn.functional.relu(
        c_n_bond_length_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * has_no_gap_mask
    c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_violation_mask = mask * (
        c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
@@ -702,10 +670,10 @@ def between_residue_bond_loss(
    # Compute loss for the angles.
    ca_c_bond_length = torch.sqrt(
        eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
    )
    n_ca_bond_length = torch.sqrt(
        eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
    )

    c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
@@ -716,31 +684,31 @@ def between_residue_bond_loss(
    gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
    gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
    ca_c_n_cos_angle_error = torch.sqrt(
        eps + (ca_c_n_cos_angle - gt_angle) ** 2
    )
    ca_c_n_loss_per_residue = torch.nn.functional.relu(
        ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
    ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    ca_c_n_violation_mask = mask * (
        ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
    gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
    c_n_ca_cos_angle_error = torch.sqrt(
        eps + torch.square(c_n_ca_cos_angle - gt_angle)
    )
    c_n_ca_loss_per_residue = torch.nn.functional.relu(
        c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
    c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_ca_violation_mask = mask * (
        c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
@@ -748,37 +716,33 @@ def between_residue_bond_loss(
    # Compute a per residue loss (equally distribute the loss to both
    # neighbouring residues).
    per_residue_loss_sum = (
        c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
    )
    per_residue_loss_sum = 0.5 * (
        torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
        + torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
    )

    # Compute hard violations.
    violation_mask = torch.max(
        torch.stack(
            [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
            dim=-2,
        ),
        dim=-2,
    )[0]
    violation_mask = torch.maximum(
        torch.nn.functional.pad(violation_mask, (0, 1)),
        torch.nn.functional.pad(violation_mask, (1, 0)),
    )

    return {
        "c_n_loss_mean": c_n_loss,
        "ca_c_n_loss_mean": ca_c_n_loss,
        "c_n_ca_loss_mean": c_n_ca_loss,
        "per_residue_loss_sum": per_residue_loss_sum,
        "per_residue_violation_mask": violation_mask,
    }
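The pad trick above converts N−1 inter-residue loss terms into N per-residue terms, crediting each term half to its left and half to its right residue. A scalar illustration:

```python
import torch

# Terms for residue pairs (0,1), (1,2), (2,3); values invented.
loss_between = torch.tensor([1.0, 2.0, 4.0])
per_residue = 0.5 * (
    torch.nn.functional.pad(loss_between, (0, 1))   # credit the left residue
    + torch.nn.functional.pad(loss_between, (1, 0))  # credit the right residue
)
print(per_residue)  # tensor([0.5000, 1.5000, 3.0000, 2.0000]) for residues 0..3
```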
@@ -820,27 +784,30 @@ def between_residue_clash_loss(
    # Create the distance matrix.
    # (N, N, 14, 14)
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, None, :, None, :]
                - atom14_pred_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Create the mask for valid distances.
    # shape (N, N, 14, 14)
    dists_mask = (
        atom14_atom_exists[..., :, None, :, None]
        * atom14_atom_exists[..., None, :, None, :]
    ).type(fp_type)

    # Mask out all the duplicate entries in the lower triangular matrix.
    # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
    # are handled separately.
    dists_mask = dists_mask * (
        residue_index[..., :, None, None, None]
        < residue_index[..., None, :, None, None]
    )

    # Backbone C--N bond between subsequent residues is no clash.
@@ -860,36 +827,34 @@ def between_residue_clash_loss(
    n_one_hot = n_one_hot.type(fp_type)

    neighbour_mask = (
        residue_index[..., :, None, None, None] + 1
    ) == residue_index[..., None, :, None, None]
    c_n_bonds = (
        neighbour_mask
        * c_one_hot[..., None, None, :, None]
        * n_one_hot[..., None, None, None, :]
    )
    dists_mask = dists_mask * (1.0 - c_n_bonds)

    # Disulfide bridge between two cysteines is no clash.
    cys = residue_constants.restype_name_to_atom14_names["CYS"]
    cys_sg_idx = cys.index("SG")
    cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
    cys_sg_idx = cys_sg_idx.reshape(
        *((1,) * len(residue_index.shape[:-1])), 1
    ).squeeze(-1)

    cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
    disulfide_bonds = (
        cys_sg_one_hot[..., None, None, :, None]
        * cys_sg_one_hot[..., None, None, None, :]
    )
    dists_mask = dists_mask * (1.0 - disulfide_bonds)

    # Compute the lower bound for the allowed distances.
    # shape (N, N, 14, 14)
    dists_lower_bound = dists_mask * (
        atom14_atom_radius[..., :, None, :, None]
        + atom14_atom_radius[..., None, :, None, :]
    )

    # Compute the error.
@@ -900,15 +865,12 @@ def between_residue_clash_loss(
    # Compute the mean loss.
    # shape ()
    mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))

    # Compute the per atom loss sum.
    # shape (N, 14)
    per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
        dists_to_low_error, axis=(-3, -1)
    )

    # Compute the hard clash mask.
@@ -925,9 +887,9 @@ def between_residue_clash_loss(
    )

    return {
        "mean_loss": mean_loss,  # shape ()
        "per_atom_loss_sum": per_atom_loss_sum,  # shape (N, 14)
        "per_atom_clash_mask": per_atom_clash_mask,  # shape (N, 14)
    }
@@ -967,27 +929,26 @@ def within_residue_violations(
            mask whether atom clashes with any other atom shape
    """
    # Compute the mask for each residue.
    dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
    dists_masks = dists_masks.reshape(
        *((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
    )
    dists_masks = (
        atom14_atom_exists[..., :, :, None]
        * atom14_atom_exists[..., :, None, :]
        * dists_masks
    )

    # Distance matrix
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, :, None, :]
                - atom14_pred_positions[..., :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )
@@ -1001,18 +962,11 @@ def within_residue_violations(
    loss = dists_masks * (dists_to_low_error + dists_to_high_error)

    # Compute the per atom loss sum.
    per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)

    # Compute the violations mask.
    violations = dists_masks * (
        (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
    )

    # Compute the per atom violations.
@@ -1021,12 +975,11 @@ def within_residue_violations(
    )

    return {
        "per_atom_loss_sum": per_atom_loss_sum,
        "per_atom_violations": per_atom_violations,
    }


def find_structural_violations(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
@@ -1043,7 +996,7 @@ def find_structural_violations(
        residue_index=batch["residue_index"],
        aatype=batch["aatype"],
        tolerance_factor_soft=violation_tolerance_factor,
        tolerance_factor_hard=violation_tolerance_factor,
    )

    # Compute the Van der Waals radius for every atom
@@ -1053,12 +1006,10 @@ def find_structural_violations(
        residue_constants.van_der_waals_radius[name[0]]
        for name in residue_constants.atom_types
    ]
    atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
    atom14_atom_radius = (
        batch["atom14_atom_exists"]
        * atomtype_radius[batch["residx_atom14_to_atom37"]]
    )

    # Compute the between residue clash loss.
@@ -1068,32 +1019,28 @@ def find_structural_violations(
        atom14_atom_radius=atom14_atom_radius,
        residue_index=batch["residue_index"],
        overlap_tolerance_soft=clash_overlap_tolerance,
        overlap_tolerance_hard=clash_overlap_tolerance,
    )

    # Compute all within-residue violations (clashes,
    # bond length and angle violations).
    restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
        overlap_tolerance=clash_overlap_tolerance,
        bond_length_tolerance_factor=violation_tolerance_factor,
    )
    atom14_atom_exists = batch["atom14_atom_exists"]
    atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
        restype_atom14_bounds["lower_bound"]
    )[batch["aatype"]]
    atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
        restype_atom14_bounds["upper_bound"]
    )[batch["aatype"]]
    residue_violations = within_residue_violations(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=batch["atom14_atom_exists"],
        atom14_dists_lower_bound=atom14_dists_lower_bound,
        atom14_dists_upper_bound=atom14_dists_upper_bound,
        tighten_bounds_for_loss=0.0,
    )

    # Combine them to a single per-residue violation mask (used later for LDDT).
@@ -1104,9 +1051,7 @@ def find_structural_violations(
            torch.max(
                between_residue_clashes["per_atom_clash_mask"], dim=-1
            )[0],
            torch.max(residue_violations["per_atom_violations"], dim=-1)[0],
        ],
        dim=-1,
    ),
@@ -1114,39 +1059,44 @@ def find_structural_violations(
    )[0]

    return {
        "between_residues": {
            "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"],  # ()
            "angles_ca_c_n_loss_mean": connection_violations[
                "ca_c_n_loss_mean"
            ],  # ()
            "angles_c_n_ca_loss_mean": connection_violations[
                "c_n_ca_loss_mean"
            ],  # ()
            "connections_per_residue_loss_sum": connection_violations[
                "per_residue_loss_sum"
            ],  # (N)
            "connections_per_residue_violation_mask": connection_violations[
                "per_residue_violation_mask"
            ],  # (N)
            "clashes_mean_loss": between_residue_clashes["mean_loss"],  # ()
            "clashes_per_atom_loss_sum": between_residue_clashes[
                "per_atom_loss_sum"
            ],  # (N, 14)
            "clashes_per_atom_clash_mask": between_residue_clashes[
                "per_atom_clash_mask"
            ],  # (N, 14)
        },
        "within_residues": {
            "per_atom_loss_sum": residue_violations[
                "per_atom_loss_sum"
            ],  # (N, 14)
            "per_atom_violations": residue_violations[
                "per_atom_violations"
            ],  # (N, 14),
        },
        "total_per_residue_violations_mask": per_residue_violations_mask,  # (N)
    }


def find_structural_violations_np(
    batch: Dict[str, np.ndarray],
    atom14_pred_positions: np.ndarray,
    config: ml_collections.ConfigDict,
) -> Dict[str, np.ndarray]:
    to_tensor = lambda x: torch.tensor(x)
    batch = tree_map(to_tensor, batch, np.ndarray)
@@ -1185,13 +1135,13 @@ def extreme_ca_ca_distance_violations(
    this_ca_mask = pred_atom_mask[..., :-1, 1]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]
    has_no_gap_mask = (
        residue_index[..., 1:] - residue_index[..., :-1]
    ) == 1.0
    ca_ca_distance = torch.sqrt(
        eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)
    )
    violations = (
        ca_ca_distance - residue_constants.ca_ca
    ) > max_angstrom_tolerance
    mask = this_ca_mask * next_ca_mask * has_no_gap_mask
    mean = masked_mean(mask, violations, -1)
    return mean
@@ -1207,13 +1157,13 @@ def compute_violation_metrics(
    extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
        pred_atom_positions=atom14_pred_positions,
        pred_atom_mask=batch["atom14_atom_exists"],
        residue_index=batch["residue_index"],
    )
    ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
    ret["violations_between_residue_bond"] = masked_mean(
        batch["seq_mask"],
        violations["between_residues"][
            "connections_per_residue_violation_mask"
        ],
        dim=-1,
    )
@@ -1221,7 +1171,7 @@ def compute_violation_metrics(
        mask=batch["seq_mask"],
        value=torch.max(
            violations["between_residues"]["clashes_per_atom_clash_mask"],
            dim=-1,
        )[0],
        dim=-1,
    )
@@ -1250,7 +1200,6 @@ def compute_violation_metrics_np(
    atom14_pred_positions = to_tensor(atom14_pred_positions)
    violations = tree_map(to_tensor, violations, np.ndarray)
    out = compute_violation_metrics(batch, atom14_pred_positions, violations)
    to_np = lambda x: np.array(x)
@@ -1265,15 +1214,15 @@ def violation_loss(
) -> torch.Tensor:
    num_atoms = torch.sum(atom14_atom_exists)
    l_clash = torch.sum(
        violations["between_residues"]["clashes_per_atom_loss_sum"]
        + violations["within_residues"]["per_atom_loss_sum"]
    )
    l_clash = l_clash / (eps + num_atoms)
    loss = (
        violations["between_residues"]["bonds_c_n_loss_mean"]
        + violations["between_residues"]["angles_ca_c_n_loss_mean"]
        + violations["between_residues"]["angles_c_n_ca_loss_mean"]
        + l_clash
    )

    return loss
@@ -1313,50 +1262,53 @@ def compute_renamed_ground_truth(
    """
    pred_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., None, :, None, :]
                - atom14_pred_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    atom14_gt_positions = batch["atom14_gt_positions"]
    gt_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_gt_positions[..., None, :, None, :]
                - atom14_gt_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
    alt_gt_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_alt_gt_positions[..., None, :, None, :]
                - atom14_alt_gt_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
    alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)

    atom14_gt_exists = batch["atom14_gt_exists"]
    atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
    mask = (
        atom14_gt_exists[..., None, :, None]
        * atom14_atom_is_ambiguous[..., None, :, None]
        * atom14_gt_exists[..., None, :, None, :]
        * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
    )

    per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
@@ -1366,16 +1318,16 @@ def compute_renamed_ground_truth(
    alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)

    renamed_atom14_gt_positions = (
        1.0 - alt_naming_is_better[..., None, None]
    ) * atom14_gt_positions + alt_naming_is_better[
        ..., None, None
    ] * atom14_alt_gt_positions

    renamed_atom14_gt_mask = (
        1.0 - alt_naming_is_better[..., None]
    ) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
        "atom14_alt_gt_exists"
    ]

    return {
        "alt_naming_is_better": alt_naming_is_better,
@@ -1400,8 +1352,7 @@ def experimentally_resolved_loss(
    loss = torch.sum(loss, dim=-1)

    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    return loss
@@ -1409,8 +1360,7 @@ def experimentally_resolved_loss(
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
    errors = softmax_cross_entropy(
        logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
    )

    # FP16-friendly averaging. Equivalent to:
@@ -1450,81 +1400,74 @@ def compute_drmsd(structure_1, structure_2):
class AlphaFoldLoss(nn.Module):
    """Aggregation of the various losses described in the supplement"""

    def __init__(self, config):
        super(AlphaFoldLoss, self).__init__()
        self.config = config

    def forward(self, out, batch):
        if "violation" not in out.keys() and self.config.violation.weight:
            out["violation"] = find_structural_violations(
                batch,
                out["sm"]["positions"][-1],
                **self.config.violation,
            )

        if "renamed_atom14_gt_positions" not in out.keys():
            batch.update(
                compute_renamed_ground_truth(
                    batch,
                    out["sm"]["positions"][-1],
                )
            )

        loss_fns = {
            "distogram": lambda: distogram_loss(
                logits=out["distogram_logits"],
                **{**batch, **self.config.distogram},
            ),
            "experimentally_resolved": lambda: experimentally_resolved_loss(
                logits=out["experimentally_resolved_logits"],
                **{**batch, **self.config.experimentally_resolved},
            ),
            "fape": lambda: fape_loss(
                out,
                batch,
                self.config.fape,
            ),
            "lddt": lambda: lddt_loss(
                logits=out["lddt_logits"],
                all_atom_pred_pos=out["final_atom_positions"],
                **{**batch, **self.config.lddt},
            ),
            "masked_msa": lambda: masked_msa_loss(
                logits=out["masked_msa_logits"],
                **{**batch, **self.config.masked_msa},
            ),
            "supervised_chi": lambda: supervised_chi_loss(
                out["sm"]["angles"],
                out["sm"]["unnormalized_angles"],
                **{**batch, **self.config.supervised_chi},
            ),
            "violation": lambda: violation_loss(
                out["violation"],
                **batch,
            ),
            "tm": lambda: tm_loss(
                logits=out["tm_logits"],
                **{**batch, **out, **self.config.tm},
            ),
        }

        cum_loss = 0
        for k, loss_fn in loss_fns.items():
            weight = self.config[k].weight
            if weight:
                # print(k)
                loss = loss_fn()
                # print(weight * loss)
                cum_loss = cum_loss + weight * loss
                # print(cum_loss)

        return cum_loss
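Stripped of the lazily-evaluated loss closures, the aggregation at the end of `forward` reduces to a weighted sum over the enabled terms; zero-weight terms are never computed. A sketch with hypothetical weights (the real values come from the model config):

```python
# Hypothetical per-term values and config weights, for illustration only.
loss_values = {"fape": 2.1, "distogram": 0.8, "violation": 0.0}
weights = {"fape": 1.0, "distogram": 0.3, "violation": 0.0}

cum_loss = 0.0
for k, loss in loss_values.items():
    if weights[k]:  # zero-weight terms are skipped entirely
        cum_loss = cum_loss + weights[k] * loss
print(cum_loss)  # 2.34
```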
File: openfold/utils/tensor_utils.py
@@ -49,7 +49,7 @@ def dict_multimap(fn, dicts):
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)
@@ -83,7 +83,7 @@ def batched_gather(data, inds, dim=0, no_batch_dims=0):
def dict_map(fn, dic, leaf_type):
    new_dict = {}
    for k, v in dic.items():
        if type(v) is dict:
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)
@@ -92,18 +92,19 @@ def dict_map(fn, dic, leaf_type):
def tree_map(fn, tree, leaf_type):
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple([tree_map(fn, x, leaf_type) for x in tree])
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        print(type(tree))
        raise ValueError("Not supported")


tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
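A minimal re-statement of the tree-mapping logic above, plus a usage example: the function is applied to every `torch.Tensor` leaf while the dict/list/tuple structure is preserved. (The dict branch here is inlined for brevity; the real code dispatches through `dict_map`.)

```python
import torch
from functools import partial

def tree_map(fn, tree, leaf_type):
    # Recurse through containers; apply fn only to leaves of leaf_type.
    if isinstance(tree, dict):
        return {k: tree_map(fn, v, leaf_type) for k, v in tree.items()}
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple(tree_map(fn, x, leaf_type) for x in tree)
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        raise ValueError("Not supported")

tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)

tree = {"a": torch.zeros(2), "b": [torch.ones(3), (torch.ones(1),)]}
doubled = tensor_tree_map(lambda t: t * 2, tree)
print(doubled["b"][0])  # tensor([2., 2., 2.])
```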
@@ -137,19 +138,19 @@ def chunk_layer(
    Returns:
        The reassembled output of the layer on the inputs.
    """
    if not (len(inputs) > 0):
        raise ValueError("Must provide at least one input")

    def fetch_dims(tree):
        shapes = []
        tree_type = type(tree)
        if tree_type is dict:
            for v in tree.values():
                shapes.extend(fetch_dims(v))
        elif tree_type is list or tree_type is tuple:
            for t in tree:
                shapes.extend(fetch_dims(t))
        elif tree_type is torch.Tensor:
            shapes.append(tree.shape)
        else:
            raise ValueError("Not supported")
@@ -161,7 +162,7 @@ def chunk_layer(
    def prep_inputs(t):
        # TODO: make this more memory efficient. This sucks
        if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
            t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
        t = t.reshape(-1, *t.shape[no_batch_dims:])
        return t
@@ -172,40 +173,42 @@ def chunk_layer(
    for d in orig_batch_dims:
        flat_batch_dim *= d

    no_chunks = flat_batch_dim // chunk_size + (
        flat_batch_dim % chunk_size != 0
    )

    i = 0
    out = None
    for _ in range(no_chunks):
        # Chunk the input
        select_chunk = (
            lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
        )
        chunks = tensor_tree_map(select_chunk, flattened_inputs)

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate space for the output
        if out is None:
            allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
            out = tensor_tree_map(allocate, output_chunk)

        # Put the chunk in its pre-allocated space
        out_type = type(output_chunk)
        if out_type is dict:

            def assign(d1, d2):
                for k, v in d1.items():
                    if type(v) is dict:
                        assign(v, d2[k])
                    else:
                        v[i : i + chunk_size] = d2[k]

            assign(out, output_chunk)
        elif out_type is tuple:
            for x1, x2 in zip(out, output_chunk):
                x1[i : i + chunk_size] = x2
        elif out_type is torch.Tensor:
            out[i : i + chunk_size] = output_chunk
        else:
            raise ValueError("Not supported")
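Conceptually, `chunk_layer` applies a layer over a large flattened batch in slices of `chunk_size` to bound peak memory, writing each slice into pre-allocated output. A stripped-down sketch of the same idea on a plain tensor (the toy layer and shapes are invented):

```python
import torch

def toy_layer(x):
    # Stand-in for an arbitrary batched layer.
    return x * 2

x = torch.randn(1000, 8)
chunk_size = 128
out = torch.empty_like(x)
for i in range(0, x.shape[0], chunk_size):
    # Each slice is processed independently and written into its slot.
    out[i : i + chunk_size] = toy_layer(x[i : i + chunk_size])

assert torch.allclose(out, toy_layer(x))  # identical result, lower peak memory
```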
File: tests/compare_utils.py
@@ -34,13 +34,11 @@ def import_alphafold():
    If AlphaFold is installed using the provided setuptools script, this
    is necessary to expose all of AlphaFold's precious insides
    """
    if "alphafold" in sys.modules:
        return sys.modules["alphafold"]

    module = importlib.import_module("alphafold")

    # Forcefully import alphafold's submodules
    submodules = pkgutil.walk_packages(module.__path__, prefix=("alphafold."))
    for submodule_info in submodules:
        importlib.import_module(submodule_info.name)

    sys.modules["alphafold"] = module
@@ -57,12 +55,14 @@ def get_alphafold_config():
_param_path = "openfold/resources/params/params_model_1_ptm.npz"
_model = None


def get_global_pretrained_openfold():
    global _model
    if _model is None:
        _model = AlphaFold(model_config("model_1_ptm").model)
        _model = _model.eval()
        if not os.path.exists(_param_path):
            raise FileNotFoundError(
                """Cannot load pretrained parameters. Make sure to run the
                installation script before running tests."""
...
...
@@ -74,9 +74,11 @@ def get_global_pretrained_openfold():
 _orig_weights = None


 def _get_orig_weights():
     global _orig_weights
-    if (_orig_weights is None):
+    if _orig_weights is None:
         _orig_weights = np.load(_param_path)
     return _orig_weights
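Both get_global_pretrained_openfold and _get_orig_weights use the same lazy, module-level cache so that expensive loads happen at most once per test session. The pattern in isolation, with a cheap stand-in for the load:

    _heavy = None

    def get_heavy():
        global _heavy
        if _heavy is None:
            _heavy = list(range(10**6))  # stand-in for an expensive load
        return _heavy

    assert get_heavy() is get_heavy()  # repeated calls return the cached object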
...
...
@@ -84,22 +86,19 @@ def _get_orig_weights():
 def _remove_key_prefix(d, prefix):
     for k, v in list(d.items()):
-        if (k.startswith(prefix)):
+        if k.startswith(prefix):
             d.pop(k)
-            d[k[len(prefix):]] = v
+            d[k[len(prefix) :]] = v


 def fetch_alphafold_module_weights(weight_path):
     orig_weights = _get_orig_weights()
     params = {k: v for k, v in orig_weights.items() if weight_path in k}
-    if ('/' in weight_path):
-        spl = weight_path.split('/')
+    if "/" in weight_path:
+        spl = weight_path.split("/")
         spl = spl if len(spl[-1]) != 0 else spl[:-1]
         module_name = spl[-1]
-        prefix = '/'.join(spl[:-1]) + '/'
+        prefix = "/".join(spl[:-1]) + "/"
         _remove_key_prefix(params, prefix)
     params = alphafold.model.utils.flat_params_to_haiku(params)
     return params
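fetch_alphafold_module_weights filters the flat AlphaFold parameter dict by substring, strips the module path, and hands the result to Haiku. A small worked example of the in-place prefix stripping (the keys and values here are dummies, not real parameter names):

    params = {
        "evoformer/msa_transition//weights": 1,
        "evoformer/msa_transition//bias": 2,
    }
    _remove_key_prefix(params, "evoformer/")
    # params is now {"msa_transition//weights": 1, "msa_transition//bias": 2}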
tests/config.py  View file @ 07e64267
 import ml_collections as mlc

-consts = mlc.ConfigDict({
+consts = mlc.ConfigDict(
+    {
         "batch_size": 2,
         "n_res": 11,
         "n_seq": 13,
...
...
@@ -14,4 +15,5 @@ consts = mlc.ConfigDict({
         "c_s": 384,
         "c_t": 64,
         "c_e": 64,
-})
+    }
+)
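ml_collections.ConfigDict is used here mainly for dotted attribute access to the test constants. A quick sketch of the behavior these tests rely on:

    import ml_collections as mlc

    c = mlc.ConfigDict({"n_res": 11, "c_z": 128})
    assert c.n_res == c["n_res"] == 11  # attribute and key access are equivalent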
tests/data_utils.py  View file @ 07e64267
...
...
@@ -18,7 +18,7 @@ from scipy.spatial.transform import Rotation
 def random_template_feats(n_templ, n, batch_size=None):
     b = []
-    if (batch_size is not None):
+    if batch_size is not None:
         b.append(batch_size)
     batch = {
         "template_mask": np.random.randint(0, 2, (*b, n_templ)),
...
...
@@ -28,28 +28,31 @@ def random_template_feats(n_templ, n, batch_size=None):
         "template_all_atom_masks": np.random.randint(
             0, 2, (*b, n_templ, n, 37)
         ),
         "template_all_atom_positions": np.random.rand(*b, n_templ, n, 37, 3)
         * 10,
     }
     batch = {k: v.astype(np.float32) for k, v in batch.items()}
     batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
     return batch


 def random_extra_msa_feats(n_extra, n, batch_size=None):
     b = []
-    if (batch_size is not None):
+    if batch_size is not None:
         b.append(batch_size)
     batch = {
         "extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(np.int64),
         "extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
         "extra_deletion_value": np.random.rand(*b, n_extra, n).astype(np.float32),
         "extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
     }
     return batch
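How these helpers are typically called in the tests below; the shapes follow directly from the definitions above (the concrete sizes are arbitrary):

    feats = random_template_feats(n_templ=4, n=16, batch_size=2)
    assert feats["template_mask"].shape == (2, 4)
    assert feats["template_all_atom_positions"].shape == (2, 4, 16, 37, 3)

    extra = random_extra_msa_feats(n_extra=8, n=16)  # no batch dimension
    assert extra["extra_msa"].shape == (8, 16)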
...
...
@@ -63,7 +66,9 @@ def random_affines_vector(dim):
     for i in range(prod_dim):
         affines[i, :4] = Rotation.random(random_state=42).as_quat()
-        affines[i, 4:] = np.random.rand(3,).astype(np.float32)
+        affines[i, 4:] = np.random.rand(
+            3,
+        ).astype(np.float32)

     return affines.reshape(*dim, 7)
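Each affine is packed as a 7-vector: a unit quaternion in the first four entries and a translation in the last three. A quick sanity check of that layout (scipy's as_quat returns normalized quaternions):

    import numpy as np

    a = random_affines_vector((3,))
    assert a.shape == (3, 7)
    assert np.allclose(np.linalg.norm(a[..., :4], axis=-1), 1.0, atol=1e-6)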
...
...
@@ -77,9 +82,10 @@ def random_affines_4x4(dim):
     for i in range(prod_dim):
         affines[i, :3, :3] = Rotation.random(random_state=42).as_matrix()
-        affines[i, :3, 3] = np.random.rand(3,).astype(np.float32)
+        affines[i, :3, 3] = np.random.rand(
+            3,
+        ).astype(np.float32)
     affines[:, 3, 3] = 1

     return affines.reshape(*dim, 4, 4)
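Similarly, the 4x4 variant packs a rotation matrix and a translation into a homogeneous transform; the rotation block should be orthonormal and the corner entry 1:

    import numpy as np

    a = random_affines_4x4((2, 5))
    rots = a[..., :3, :3]
    assert np.allclose(rots @ np.swapaxes(rots, -1, -2), np.eye(3), atol=1e-5)
    assert np.all(a[..., 3, 3] == 1)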
tests/test_embedders.py  View file @ 07e64267
...
...
@@ -84,9 +84,7 @@ class TestTemplateAngleEmbedder(unittest.TestCase):
         x = torch.rand((batch_size, n_templ, n_res, template_angle_dim))
         x = tae(x)
-        self.assertTrue(
-            x.shape == (batch_size, n_templ, n_res, c_m)
-        )
+        self.assertTrue(x.shape == (batch_size, n_templ, n_res, c_m))


 class TestTemplatePairEmbedder(unittest.TestCase):
...
...
@@ -105,11 +103,8 @@ class TestTemplatePairEmbedder(unittest.TestCase):
         x = torch.rand((batch_size, n_templ, n_res, n_res, template_pair_dim))
         x = tpe(x)
-        self.assertTrue(
-            x.shape == (batch_size, n_templ, n_res, n_res, c_t)
-        )
+        self.assertTrue(x.shape == (batch_size, n_templ, n_res, n_res, c_t))


 if __name__ == "__main__":
     unittest.main()
tests/test_evoformer.py  View file @ 07e64267
...
...
@@ -24,7 +24,7 @@ from openfold.utils.tensor_utils import tree_map
 import tests.compare_utils as compare_utils
 from tests.config import consts

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
...
@@ -91,7 +91,8 @@ class TestEvoformerStack(unittest.TestCase):
         config = compare_utils.get_alphafold_config()
         c_e = config.model.embeddings_and_evoformer.evoformer
         ei = alphafold.model.modules.EvoformerIteration(
-            c_e, config.model.global_config, is_extra_msa=False)
+            c_e, config.model.global_config, is_extra_msa=False
+        )
         return ei(activations, masks, is_training=False)

     f = hk.transform(run_ei)
...
...
@@ -100,13 +101,13 @@ class TestEvoformerStack(unittest.TestCase):
         n_seq = consts.n_seq

         activations = {
-            'msa': np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
-            'pair': np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
+            "msa": np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
+            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
         }

         masks = {
-            'msa': np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
-            'pair': np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
+            "msa": np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
+            "pair": np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
         }

         params = compare_utils.fetch_alphafold_module_weights(
...
...
@@ -115,9 +116,7 @@ class TestEvoformerStack(unittest.TestCase):
         params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

         key = jax.random.PRNGKey(42)
-        out_gt = f.apply(
-            params, key, activations, masks
-        )
+        out_gt = f.apply(params, key, activations, masks)
         jax.tree_map(lambda x: x.block_until_ready(), out_gt)
         out_gt_msa = torch.as_tensor(np.array(out_gt["msa"]))
         out_gt_pair = torch.as_tensor(np.array(out_gt["pair"]))
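The tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) idiom appears throughout these tests: AlphaFold stacks the parameters of repeated blocks along a leading axis, and indexing with [0] selects the first block. A minimal stand-in for the pattern (the 48-block shape below is illustrative):

    import numpy as np

    def tree_map(fn, tree, leaf_type):
        # Simplified version of openfold.utils.tensor_utils.tree_map
        if isinstance(tree, dict):
            return {k: tree_map(fn, v, leaf_type) for k, v in tree.items()}
        if isinstance(tree, leaf_type):
            return fn(tree)
        raise ValueError("Not supported")

    params = {"msa_transition": {"weights": np.zeros((48, 64, 256))}}
    block0 = tree_map(lambda n: n[0], params, np.ndarray)
    assert block0["msa_transition"]["weights"].shape == (64, 256)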
...
...
@@ -134,9 +133,8 @@ class TestEvoformerStack(unittest.TestCase):
         out_repro_msa = out_repro_msa.cpu()
         out_repro_pair = out_repro_pair.cpu()

-        assert(torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps))
-        assert(torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps))
+        assert torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps)
+        assert torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps)


 class TestExtraMSAStack(unittest.TestCase):
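One subtlety worth flagging in these assertions: the comparison with consts.eps happens inside torch.max, so max reduces a boolean tensor and the assertion holds if any element is within tolerance. A strictly elementwise check would apply torch.all to the comparison instead, e.g.:

    import torch

    a, b, eps = torch.rand(4), torch.rand(4), 1e-6
    loose = torch.max(torch.abs(a - b) < eps)   # True if ANY element is close
    strict = torch.all(torch.abs(a - b) < eps)  # True only if ALL elements are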
...
...
@@ -180,8 +178,24 @@ class TestExtraMSAStack(unittest.TestCase):
         m = torch.rand((batch_size, s_t, n_res, c_m))
         z = torch.rand((batch_size, n_res, n_res, c_z))
-        msa_mask = torch.randint(0, 2, size=(batch_size, s_t, n_res,))
-        pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res,))
+        msa_mask = torch.randint(
+            0,
+            2,
+            size=(
+                batch_size,
+                s_t,
+                n_res,
+            ),
+        )
+        pair_mask = torch.randint(
+            0,
+            2,
+            size=(
+                batch_size,
+                n_res,
+                n_res,
+            ),
+        )

         shape_z_before = z.shape
...
...
@@ -216,7 +230,7 @@ class TestMSATransition(unittest.TestCase):
             msa_trans = alphafold.model.modules.Transition(
                 c_e.msa_transition,
                 config.model.global_config,
-                name="msa_transition"
+                name="msa_transition",
             )
             act = msa_trans(act=msa_act, mask=msa_mask)
             return act
...
...
@@ -227,25 +241,29 @@ class TestMSATransition(unittest.TestCase):
         n_seq = consts.n_seq

         msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
-        msa_mask = np.ones((n_seq, n_res)).astype(np.float32) # no mask here either
+        msa_mask = np.ones((n_seq, n_res)).astype(np.float32)  # no mask here either

         # Fetch pretrained parameters (but only from one block)]
         params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
-            "msa_transition"
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + "msa_transition"
         )
         params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

-        out_gt = f.apply(
-            params, None, msa_act, msa_mask
-        ).block_until_ready()
+        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
         out_gt = torch.as_tensor(np.array(out_gt))

         model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.evoformer.blocks[0].msa_transition(
-            torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
-            mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
-        ).cpu()
+        out_repro = (
+            model.evoformer.blocks[0]
+            .msa_transition(
+                torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
+                mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
+            )
+            .cpu()
+        )

         self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
...
...
tests/test_feats.py  View file @ 07e64267
...
...
@@ -33,7 +33,7 @@ import tests.compare_utils as compare_utils
 from tests.config import consts
 from tests.data_utils import random_affines_4x4

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
...
@@ -95,9 +95,9 @@ class TestFeats(unittest.TestCase):
         aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
         all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
-        all_atom_mask = np.random.randint(
-            0, 2, (n_templ, n_res, 37)
-        ).astype(np.float32)
+        all_atom_mask = np.random.randint(0, 2, (n_templ, n_res, 37)).astype(
+            np.float32
+        )

         out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
         out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)
...
...
@@ -114,16 +114,17 @@ class TestFeats(unittest.TestCase):
         # This function is extremely sensitive to floating point imprecisions,
         # so it is given much greater latitude in comparison tests.
         self.assertTrue(
-            torch.mean(
-                torch.abs(out_gt["torsion_angles_sin_cos"] - tasc)
-            ) < 0.01
+            torch.mean(torch.abs(out_gt["torsion_angles_sin_cos"] - tasc))
+            < 0.01
         )
         self.assertTrue(
-            torch.mean(
-                torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc)
-            ) < 0.01
+            torch.mean(torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc))
+            < 0.01
         )
         self.assertTrue(
-            torch.max(
-                torch.abs(out_gt["torsion_angles_mask"] - tam)
-            ) < consts.eps
+            torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam))
+            < consts.eps
         )

     @compare_utils.skip_unless_alphafold_installed()
     def test_atom37_to_frames_compare(self):
...
...
@@ -138,15 +139,17 @@ class TestFeats(unittest.TestCase):
         batch = {
             "aatype": np.random.randint(0, 21, (n_res,)),
-            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(np.float32),
-            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
+            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
+                np.float32
+            ),
+            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
+                np.float32
+            ),
         }

         out_gt = f.apply({}, None, **batch)
         to_tensor = lambda t: torch.tensor(np.array(t))
-        out_gt = {
-            k: to_tensor(v) for k, v in out_gt.items()
-        }
+        out_gt = {k: to_tensor(v) for k, v in out_gt.items()}

         def flat12_to_4x4(flat12):
             rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
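The helper above is truncated in this view after the rotation reshape. A plausible completion of the 12-to-4x4 packing, offered as a hedged sketch rather than the repository's exact code:

    import torch

    def flat12_to_4x4(flat12):
        # 9 rotation entries + 3 translation entries -> homogeneous 4x4
        rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
        trans = flat12[..., 9:]
        out = torch.zeros(*flat12.shape[:-1], 4, 4)
        out[..., :3, :3] = rot
        out[..., :3, 3] = trans
        out[..., 3, 3] = 1
        return out

    assert flat12_to_4x4(torch.rand(5, 12)).shape == (5, 4, 4)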
...
...
@@ -172,7 +175,7 @@ class TestFeats(unittest.TestCase):
         out_repro = data_transforms.atom37_to_frames(batch)
         out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

         for k, v in out_gt.items():
             self.assertTrue(
                 torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
             )
...
...
@@ -201,9 +204,7 @@ class TestFeats(unittest.TestCase):
     @compare_utils.skip_unless_alphafold_installed()
     def test_torsion_angles_to_frames_compare(self):
         def run_torsion_angles_to_frames(
-            aatype,
-            backb_to_global,
-            torsion_angles_sin_cos
+            aatype, backb_to_global, torsion_angles_sin_cos
         ):
             return alphafold.model.all_atom.torsion_angles_to_frames(
                 aatype,
...
...
@@ -223,9 +224,7 @@ class TestFeats(unittest.TestCase):
         torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)

-        out_gt = f.apply(
-            {}, None, aatype, rigids, torsion_angles_sin_cos
-        )
+        out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)
         jax.tree_map(lambda x: x.block_until_ready(), out_gt)
...
...
@@ -237,9 +236,7 @@ class TestFeats(unittest.TestCase):
         )

         # Convert the Rigids to 4x4 transformation tensors
-        rots_gt = list(
-            map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot)
-        )
+        rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot))
         trans_gt = list(
             map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
         )
...
...
@@ -296,9 +293,7 @@ class TestFeats(unittest.TestCase):
         rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
         transformations = T.from_4x4(torch.as_tensor(affines).float())

-        out_gt = f.apply(
-            {}, None, aatype, rigids
-        )
+        out_gt = f.apply({}, None, aatype, rigids)
         jax.tree_map(lambda x: x.block_until_ready(), out_gt)

         out_gt = torch.stack(
             [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
...
...
tests/test_import_weights.py  View file @ 07e64267
...
...
@@ -30,7 +30,8 @@ class TestImportWeights(unittest.TestCase):
         model = AlphaFold(c.model)
         import_jax_weights_(
-            model, npz_path,
+            model,
+            npz_path,
         )

         data = np.load(npz_path)
...
...
@@ -38,23 +39,34 @@ class TestImportWeights(unittest.TestCase):
         test_pairs = [
             # Normal linear weight
             (
                 torch.as_tensor(
                     data[prefix + "structure_module/initial_projection//weights"]
                 ).transpose(-1, -2),
-                model.structure_module.linear_in.weight),
+                model.structure_module.linear_in.weight,
+            ),
             # Normal layer norm param
             (
                 torch.as_tensor(
                     data[prefix + "evoformer/prev_pair_norm//offset"],
                 ),
-                model.recycling_embedder.layer_norm_z.bias),
+                model.recycling_embedder.layer_norm_z.bias,
+            ),
             # From a stack
             (
                 torch.as_tensor(
                     data[
                         prefix
                         + (
                             "evoformer/evoformer_iteration/outer_product_mean/"
                             "left_projection//weights"
                         )
-                    ][1].transpose(-1, -2)),
-                model.evoformer.blocks[1].outer_product_mean.linear_1.weight,),
+                    ][1].transpose(-1, -2)
+                ),
+                model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
+            ),
         ]

         for w_alpha, w_repro in test_pairs:
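The .transpose(-1, -2) in the first and third pairs reflects the layout difference between the two frameworks: Haiku's Linear stores weights as (in_features, out_features), while torch.nn.Linear.weight is (out_features, in_features). In isolation:

    import numpy as np
    import torch

    w_jax = np.random.rand(64, 128)  # Haiku layout: (in, out)
    w_torch = torch.as_tensor(w_jax).transpose(-1, -2)  # torch layout: (out, in)
    assert w_torch.shape == (128, 64)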
...
...
tests/test_loss.py  View file @ 07e64267
...
...
@@ -49,7 +49,7 @@ import tests.compare_utils as compare_utils
 from tests.config import consts
 from tests.data_utils import random_affines_vector, random_affines_4x4

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
...
@@ -99,7 +99,14 @@ class TestLoss(unittest.TestCase):
         pred_pos = torch.rand(bs, n, 14, 3)
         pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
         residue_index = torch.arange(n).unsqueeze(0)
-        aatype = torch.randint(0, 22, (bs, n,))
+        aatype = torch.randint(
+            0,
+            22,
+            (
+                bs,
+                n,
+            ),
+        )

         between_residue_bond_loss(
             pred_pos,
...
...
@@ -122,14 +129,13 @@ class TestLoss(unittest.TestCase):
         n_res = consts.n_res

         pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
-        pred_atom_mask = np.random.randint(
-            0, 2, (n_res, 14)
-        ).astype(np.float32)
+        pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
         residue_index = np.arange(n_res)
         aatype = np.random.randint(0, 22, (n_res,))

         out_gt = f.apply(
             {},
             None,
             pred_pos,
             pred_atom_mask,
             residue_index,
...
...
@@ -151,7 +157,6 @@ class TestLoss(unittest.TestCase):
                 torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
             )

-
     def test_run_between_residue_clash_loss(self):
         bs = consts.batch_size
         n = consts.n_res
...
...
@@ -185,10 +190,13 @@ class TestLoss(unittest.TestCase):
         pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
         atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
         atom_radius = np.random.rand(n_res, 14).astype(np.float32)
-        res_ind = np.arange(n_res,)
+        res_ind = np.arange(
+            n_res,
+        )

         out_gt = f.apply(
             {},
             None,
             pred_pos,
             atom_exists,
             atom_radius,
...
...
@@ -242,7 +250,6 @@ class TestLoss(unittest.TestCase):
             os.chdir(cwd)
             return loss

-
         f = hk.transform(run_fsv)
         n_res = consts.n_res
...
...
@@ -251,30 +258,25 @@ class TestLoss(unittest.TestCase):
             "atom14_atom_exists": np.random.randint(0, 2, (n_res, 14)),
             "residue_index": np.arange(n_res),
             "aatype": np.random.randint(0, 20, (n_res,)),
-            "residx_atom14_to_atom37": np.random.randint(0, 37, (n_res, 14)).astype(np.int64),
+            "residx_atom14_to_atom37": np.random.randint(
+                0, 37, (n_res, 14)
+            ).astype(np.int64),
         }
         pred_pos = np.random.rand(n_res, 14, 3)

-        config = mlc.ConfigDict({
+        config = mlc.ConfigDict(
+            {
                 "clash_overlap_tolerance": 1.5,
                 "violation_tolerance_factor": 12.0,
-        })
+            }
+        )

-        out_gt = f.apply(
-            {}, None, batch, pred_pos, config
-        )
+        out_gt = f.apply({}, None, batch, pred_pos, config)
         out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
         out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

         batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
         out_repro = find_structural_violations(
             batch,
             torch.tensor(pred_pos).cuda(),
...
...
@@ -284,7 +286,7 @@ class TestLoss(unittest.TestCase):
         def compare(out):
             gt, repro = out
-            assert(torch.max(torch.abs(gt - repro)) < consts.eps)
+            assert torch.max(torch.abs(gt - repro)) < consts.eps

         dict_multimap(compare, [out_gt, out_repro])
...
...
@@ -304,12 +306,15 @@ class TestLoss(unittest.TestCase):
             "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
             "aatype": np.random.randint(0, 20, (n_res,)),
             "atom14_gt_positions": np.random.rand(n_res, 14, 3),
-            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(np.float32),
-            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
-            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(np.float32),
+            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
+                np.float32
+            ),
+            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
+                np.float32
+            ),
+            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
+                np.float32
+            ),
         }

         def _build_extra_feats_np():
...
...
@@ -325,9 +330,7 @@ class TestLoss(unittest.TestCase):
         out_gt = f.apply({}, None, batch, atom14_pred_pos)
         out_gt = jax.tree_map(lambda x: torch.tensor(np.array(x)), out_gt)

-        batch = tree_map(
-            lambda x: torch.tensor(x).cuda(), batch, np.ndarray
-        )
+        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
         atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

         out_repro = compute_renamed_ground_truth(batch, atom14_pred_pos)
...
...
@@ -358,19 +361,16 @@ class TestLoss(unittest.TestCase):
         batch = {
             "true_msa": np.random.randint(0, 21, (n_res, n_seq)),
-            "bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(np.float32),
+            "bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(
+                np.float32
+            ),
         }

         out_gt = f.apply({}, None, value, batch)["loss"]
         out_gt = torch.tensor(np.array(out_gt))

-        value = tree_map(
-            lambda x: torch.tensor(x).cuda(), value, np.ndarray
-        )
-        batch = tree_map(
-            lambda x: torch.tensor(x).cuda(), batch, np.ndarray
-        )
+        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
+        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

         with torch.no_grad():
             out_repro = masked_msa_loss(
...
...
@@ -385,6 +385,7 @@ class TestLoss(unittest.TestCase):
     def test_distogram_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_distogram = config.model.heads.distogram

+
         def run_distogram_loss(value, batch):
             dist_head = alphafold.model.modules.DistogramHead(
                 c_distogram,
                 config.model.global_config
...
...
@@ -396,33 +397,27 @@ class TestLoss(unittest.TestCase):
         n_res = consts.n_res

         value = {
-            "logits": np.random.rand(n_res, n_res, c_distogram.num_bins).astype(np.float32),
+            "logits": np.random.rand(
+                n_res, n_res, c_distogram.num_bins
+            ).astype(np.float32),
             "bin_edges": np.linspace(
                 c_distogram.first_break,
                 c_distogram.last_break,
                 c_distogram.num_bins,
-            )
-            ),
+            ),
         }
         batch = {
             "pseudo_beta": np.random.rand(n_res, 3).astype(np.float32),
-            "pseudo_beta_mask": np.random.randint(0, 2, (n_res,))
+            "pseudo_beta_mask": np.random.randint(0, 2, (n_res,)),
         }

         out_gt = f.apply({}, None, value, batch)["loss"]
         out_gt = torch.tensor(np.array(out_gt))

-        value = tree_map(
-            lambda x: torch.tensor(x).cuda(), value, np.ndarray
-        )
-        batch = tree_map(
-            lambda x: torch.tensor(x).cuda(), batch, np.ndarray
-        )
+        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
+        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

         with torch.no_grad():
             out_repro = distogram_loss(
...
@@ -441,6 +436,7 @@ class TestLoss(unittest.TestCase):
     def test_experimentally_resolved_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_experimentally_resolved = config.model.heads.experimentally_resolved

+
         def run_experimentally_resolved_loss(value, batch):
             er_head = alphafold.model.modules.ExperimentallyResolvedHead(
                 c_experimentally_resolved,
                 config.model.global_config
...
...
@@ -458,19 +454,15 @@ class TestLoss(unittest.TestCase):
         batch = {
             "all_atom_mask": np.random.randint(0, 2, (n_res, 37)),
             "atom37_atom_exists": np.random.randint(0, 2, (n_res, 37)),
-            "resolution": np.array(1.0)
+            "resolution": np.array(1.0),
         }

         out_gt = f.apply({}, None, value, batch)["loss"]
         out_gt = torch.tensor(np.array(out_gt))

-        value = tree_map(
-            lambda x: torch.tensor(x).cuda(), value, np.ndarray
-        )
-        batch = tree_map(
-            lambda x: torch.tensor(x).cuda(), batch, np.ndarray
-        )
+        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
+        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

         with torch.no_grad():
             out_repro = experimentally_resolved_loss(
...
...
@@ -488,9 +480,10 @@ class TestLoss(unittest.TestCase):
     def test_supervised_chi_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_chi_loss = config.model.heads.structure_module

+
         def run_supervised_chi_loss(value, batch):
             ret = {
-                "loss": jax.numpy.array(0.),
+                "loss": jax.numpy.array(0.0),
             }
             alphafold.model.folding.supervised_chi_loss(
                 ret, batch, value, c_chi_loss
...
...
@@ -503,10 +496,12 @@ class TestLoss(unittest.TestCase):
         value = {
             "sidechains": {
-                "angles_sin_cos": np.random.rand(8, n_res, 7, 2).astype(np.float32),
-                "unnormalized_angles_sin_cos": np.random.rand(8, n_res, 7, 2).astype(np.float32),
+                "angles_sin_cos": np.random.rand(8, n_res, 7, 2).astype(
+                    np.float32
+                ),
+                "unnormalized_angles_sin_cos": np.random.rand(
+                    8, n_res, 7, 2
+                ).astype(np.float32),
             }
         }
...
...
@@ -519,13 +514,9 @@ class TestLoss(unittest.TestCase):
         out_gt = f.apply({}, None, value, batch)
         out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

-        value = tree_map(
-            lambda x: torch.tensor(x).cuda(), value, np.ndarray
-        )
-        batch = tree_map(
-            lambda x: torch.tensor(x).cuda(), batch, np.ndarray
-        )
+        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
+        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

         batch["chi_angles_sin_cos"] = torch.stack(
             [
...
...
@@ -539,7 +530,7 @@ class TestLoss(unittest.TestCase):
         out_repro = supervised_chi_loss(
             chi_weight=c_chi_loss.chi_weight,
             angle_norm_weight=c_chi_loss.angle_norm_weight,
-            **{**batch, **value["sidechains"]}
+            **{**batch, **value["sidechains"]},
         )
         out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
...
...
@@ -550,20 +541,24 @@ class TestLoss(unittest.TestCase):
     def test_violation_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_viol = config.model.heads.structure_module

+
         def run_viol_loss(batch, atom14_pred_pos):
             ret = {
-                "loss": np.array(0.).astype(np.float32),
+                "loss": np.array(0.0).astype(np.float32),
             }
             value = {}
-            value["violations"] = (
-                alphafold.model.folding.find_structural_violations(
-                    batch,
-                    atom14_pred_pos,
-                    c_viol,
-                )
-            )
+            value["violations"] = alphafold.model.folding.find_structural_violations(
+                batch,
+                atom14_pred_pos,
+                c_viol,
+            )
             alphafold.model.folding.structural_violation_loss(
-                ret, batch, value, c_viol,
+                ret,
+                batch,
+                value,
+                c_viol,
             )
             return ret["loss"]
...
...
@@ -577,16 +572,14 @@ class TestLoss(unittest.TestCase):
             "aatype": np.random.randint(0, 21, (n_res,)),
         }
         alphafold.model.tf.data_transforms.make_atom14_masks(batch)
-        batch = {
-            k: np.array(v) for k, v in batch.items()
-        }
+        batch = {k: np.array(v) for k, v in batch.items()}
         atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

         out_gt = f.apply({}, None, batch, atom14_pred_pos)
         out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

-        batch = tree_map(
-            lambda n: torch.tensor(n).cuda(), batch, np.ndarray
-        )
+        batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)
         atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

         batch = data_transforms.make_atom14_masks(batch)
...
...
@@ -603,6 +596,7 @@ class TestLoss(unittest.TestCase):
     def test_lddt_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_plddt = config.model.heads.predicted_lddt

+
         def run_plddt_loss(value, batch):
             head = alphafold.model.modules.PredictedLDDTHead(
                 c_plddt,
                 config.model.global_config
...
...
@@ -615,21 +609,25 @@ class TestLoss(unittest.TestCase):
         value = {
             "predicted_lddt": {
-                "logits": np.random.rand(n_res, c_plddt.num_bins).astype(np.float32),
+                "logits": np.random.rand(n_res, c_plddt.num_bins).astype(
+                    np.float32
+                ),
             },
             "structure_module": {
-                "final_atom_positions": np.random.rand(n_res, 37, 3).astype(np.float32),
-            }
+                "final_atom_positions": np.random.rand(n_res, 37, 3).astype(
+                    np.float32
+                ),
+            },
         }
         batch = {
-            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(np.float32),
-            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
-            "resolution": np.array(1.).astype(np.float32),
+            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
+                np.float32
+            ),
+            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
+                np.float32
+            ),
+            "resolution": np.array(1.0).astype(np.float32),
         }

         out_gt = f.apply({}, None, value, batch)
...
...
@@ -652,9 +650,10 @@ class TestLoss(unittest.TestCase):
     def test_backbone_loss(self):
         config = compare_utils.get_alphafold_config()
         c_sm = config.model.heads.structure_module

+
         def run_bb_loss(batch, value):
             ret = {
-                "loss": np.array(0.),
+                "loss": np.array(0.0),
             }
             alphafold.model.folding.backbone_loss(ret, batch, value, c_sm)
             return ret["loss"]
...
...
@@ -665,13 +664,19 @@ class TestLoss(unittest.TestCase):
         batch = {
             "backbone_affine_tensor": random_affines_vector((n_res,)),
-            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
-            "use_clamped_fape": np.array(0.),
+            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
+                np.float32
+            ),
+            "use_clamped_fape": np.array(0.0),
         }
         value = {
-            "traj": random_affines_vector((c_sm.num_layer, n_res,)),
+            "traj": random_affines_vector(
+                (
+                    c_sm.num_layer,
+                    n_res,
+                )
+            ),
         }

         out_gt = f.apply({}, None, batch, value)
...
...
@@ -695,6 +700,7 @@ class TestLoss(unittest.TestCase):
     def test_sidechain_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_sm = config.model.heads.structure_module

+
         def run_sidechain_loss(batch, value, atom14_pred_positions):
             batch = {
                 **batch,
...
...
@@ -702,22 +708,24 @@ class TestLoss(unittest.TestCase):
                     batch["aatype"],
                     batch["all_atom_positions"],
                     batch["all_atom_mask"],
-                )
+                ),
             }
             v = {}
             v["sidechains"] = {}
-            v["sidechains"]["frames"] = (
-                alphafold.model.r3.rigids_from_tensor4x4(
-                    value["sidechains"]["frames"]
-                )
-            )
+            v["sidechains"]["frames"] = alphafold.model.r3.rigids_from_tensor4x4(
+                value["sidechains"]["frames"]
+            )
             v["sidechains"]["atom_pos"] = alphafold.model.r3.vecs_from_tensor(
                 value["sidechains"]["atom_pos"]
             )
-            v.update(alphafold.model.folding.compute_renamed_ground_truth(
-                batch,
-                atom14_pred_positions,
-            ))
+            v.update(
+                alphafold.model.folding.compute_renamed_ground_truth(
+                    batch,
+                    atom14_pred_positions,
+                )
+            )
             value = v
             ret = alphafold.model.folding.sidechain_loss(batch, value, c_sm)
...
...
@@ -730,14 +738,18 @@ class TestLoss(unittest.TestCase):
         batch = {
             "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
             "aatype": np.random.randint(0, 20, (n_res,)),
-            "atom14_gt_positions": np.random.rand(n_res, 14, 3).astype(np.float32),
-            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(np.float32),
-            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(np.float32),
-            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
+            "atom14_gt_positions": np.random.rand(n_res, 14, 3).astype(
+                np.float32
+            ),
+            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
+                np.float32
+            ),
+            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
+                np.float32
+            ),
+            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
+                np.float32
+            ),
         }

         def _build_extra_feats_np():
...
...
@@ -751,10 +763,9 @@ class TestLoss(unittest.TestCase):
         value = {
             "sidechains": {
                 "frames": random_affines_4x4((c_sm.num_layer, n_res, 8)),
-                "atom_pos": np.random.rand(c_sm.num_layer, n_res, 14, 3).astype(np.float32),
+                "atom_pos": np.random.rand(
+                    c_sm.num_layer, n_res, 14, 3
+                ).astype(np.float32),
             }
         }
...
...
@@ -784,6 +795,7 @@ class TestLoss(unittest.TestCase):
     def test_tm_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_tm = config.model.heads.predicted_aligned_error

+
         def run_tm_loss(representations, batch, value):
             head = alphafold.model.modules.PredictedAlignedErrorHead(
                 c_tm,
                 config.model.global_config
...
...
@@ -798,15 +810,15 @@ class TestLoss(unittest.TestCase):
         n_res = consts.n_res

         representations = {
-            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
+            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(
+                np.float32
+            ),
         }
         batch = {
             "backbone_affine_tensor": random_affines_vector((n_res,)),
-            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
-            "resolution": np.array(1.).astype(np.float32),
+            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
+                np.float32
+            ),
+            "resolution": np.array(1.0).astype(np.float32),
         }
         value = {
...
...
@@ -827,11 +839,11 @@ class TestLoss(unittest.TestCase):
         batch = tree_map(to_tensor, batch, np.ndarray)
         value = tree_map(to_tensor, value, np.ndarray)

-        batch["backbone_affine_tensor"] = (
-            affine_vector_to_4x4(batch["backbone_affine_tensor"])
-        )
+        batch["backbone_affine_tensor"] = affine_vector_to_4x4(
+            batch["backbone_affine_tensor"]
+        )
-        value["structure_module"]["final_affines"] = (
-            affine_vector_to_4x4(value["structure_module"]["final_affines"])
-        )
+        value["structure_module"]["final_affines"] = affine_vector_to_4x4(
+            value["structure_module"]["final_affines"]
+        )

         model = compare_utils.get_global_pretrained_openfold()
...
...
tests/test_model.py  View file @ 07e64267
...
...
@@ -29,7 +29,7 @@ from tests.data_utils import (
     random_extra_msa_feats,
 )

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
...
@@ -51,28 +51,21 @@ class TestModel(unittest.TestCase):
         model = AlphaFold(c)

         batch = {}
-        tf = torch.randint(
-            c.input_embedder.tf_dim - 1, size=(n_res,)
-        )
+        tf = torch.randint(c.input_embedder.tf_dim - 1, size=(n_res,))
         batch["target_feat"] = nn.functional.one_hot(
             tf, c.input_embedder.tf_dim
         ).float()
         batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
         batch["residue_index"] = torch.arange(n_res)
-        batch["msa_feat"] = torch.rand(
-            (n_seq, n_res, c.input_embedder.msa_dim)
-        )
+        batch["msa_feat"] = torch.rand((n_seq, n_res, c.input_embedder.msa_dim))
         t_feats = random_template_feats(n_templ, n_res)
         batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
         extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
         batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
         batch["msa_mask"] = torch.randint(
             low=0, high=2, size=(n_seq, n_res)
         ).float()
-        batch["seq_mask"] = torch.randint(
-            low=0, high=2, size=(n_res,)
-        ).float()
+        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
         batch.update(make_atom14_masks(batch))

         add_recycling_dims = lambda t: (
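The target features above are built by one-hot encoding random residue indices; torch.argmax then recovers the indices, which is why batch["aatype"] round-trips exactly. In miniature:

    import torch
    import torch.nn as nn

    tf = torch.tensor([0, 2, 1])
    oh = nn.functional.one_hot(tf, num_classes=3).float()
    # oh == [[1., 0., 0.], [0., 0., 1.], [0., 1., 0.]]
    assert torch.equal(torch.argmax(oh, dim=-1), tf)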
...
...
@@ -89,12 +82,14 @@ class TestModel(unittest.TestCase):
         config = compare_utils.get_alphafold_config()
         model = alphafold.model.modules.AlphaFold(config.model)
         return model(
-            batch=batch, is_training=False, return_representations=True,
+            batch=batch,
+            is_training=False,
+            return_representations=True,
         )

     f = hk.transform(run_alphafold)
-    params = compare_utils.fetch_alphafold_module_weights('')
+    params = compare_utils.fetch_alphafold_module_weights("")
     with open("tests/test_data/sample_feats.pickle", "rb") as fp:
         batch = pickle.load(fp)
...
...
@@ -108,13 +103,13 @@ class TestModel(unittest.TestCase):
         out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
         out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))

-        batch = {
-            k: torch.as_tensor(v).cuda() for k, v in batch.items()
-        }
+        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
         batch["aatype"] = batch["aatype"].long()
         batch["template_aatype"] = batch["template_aatype"].long()
         batch["extra_msa"] = batch["extra_msa"].long()
-        batch["residx_atom37_to_atom14"] = batch["residx_atom37_to_atom14"].long()
+        batch["residx_atom37_to_atom14"] = batch[
+            "residx_atom37_to_atom14"
+        ].long()

         # Move the recycling dimension to the end
         move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
...
...
@@ -130,4 +125,3 @@ class TestModel(unittest.TestCase):
         out_repro = out_repro.squeeze(0)

         self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 1e-3))
tests/test_msa.py  View file @ 07e64267
...
...
@@ -24,7 +24,7 @@ from openfold.utils.tensor_utils import tree_map
 import tests.compare_utils as compare_utils
 from tests.config import consts

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
...
@@ -39,7 +39,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
         c_z = consts.c_z
         c = 52
         no_heads = 4
         chunk_size = None

         mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads, chunk_size)
...
...
@@ -58,12 +58,9 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
             config = compare_utils.get_alphafold_config()
             c_e = config.model.embeddings_and_evoformer.evoformer
             msa_row = alphafold.model.modules.MSARowAttentionWithPairBias(
-                c_e.msa_row_attention_with_pair_bias,
-                config.model.global_config
+                c_e.msa_row_attention_with_pair_bias, config.model.global_config
             )
-            act = msa_row(
-                msa_act=msa_act, msa_mask=msa_mask, pair_act=pair_act
-            )
+            act = msa_row(msa_act=msa_act, msa_mask=msa_mask, pair_act=pair_act)
             return act

         f = hk.transform(run_msa_row_att)
...
...
@@ -72,15 +69,15 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
         n_seq = consts.n_seq

         msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
-        msa_mask = np.random.randint(
-            low=0, high=2, size=(n_seq, n_res)
-        ).astype(np.float32)
+        msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res)).astype(
+            np.float32
+        )
         pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)

         # Fetch pretrained parameters (but only from one block)]
         params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
-            "msa_row_attention"
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + "msa_row_attention"
         )
         params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
...
...
@@ -90,11 +87,15 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
         out_gt = torch.as_tensor(np.array(out_gt))

         model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.evoformer.blocks[0].msa_att_row(
-            torch.as_tensor(msa_act).cuda(),
-            torch.as_tensor(pair_act).cuda(),
-            torch.as_tensor(msa_mask).cuda(),
-        ).cpu()
+        out_repro = (
+            model.evoformer.blocks[0]
+            .msa_att_row(
+                torch.as_tensor(msa_act).cuda(),
+                torch.as_tensor(pair_act).cuda(),
+                torch.as_tensor(msa_mask).cuda(),
+            )
+            .cpu()
+        )

         self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
...
...
@@ -124,12 +125,9 @@ class TestMSAColumnAttention(unittest.TestCase):
             config = compare_utils.get_alphafold_config()
             c_e = config.model.embeddings_and_evoformer.evoformer
             msa_col = alphafold.model.modules.MSAColumnAttention(
-                c_e.msa_column_attention,
-                config.model.global_config
+                c_e.msa_column_attention, config.model.global_config
             )
-            act = msa_col(
-                msa_act=msa_act, msa_mask=msa_mask
-            )
+            act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
             return act

         f = hk.transform(run_msa_col_att)
...
...
@@ -138,27 +136,29 @@ class TestMSAColumnAttention(unittest.TestCase):
         n_seq = consts.n_seq

         msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
-        msa_mask = np.random.randint(
-            low=0, high=2, size=(n_seq, n_res)
-        ).astype(np.float32)
+        msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res)).astype(
+            np.float32
+        )

         # Fetch pretrained parameters (but only from one block)]
         params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
-            "msa_column_attention"
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + "msa_column_attention"
         )
         params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

-        out_gt = f.apply(
-            params, None, msa_act, msa_mask
-        ).block_until_ready()
+        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
         out_gt = torch.as_tensor(np.array(out_gt))

         model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.evoformer.blocks[0].msa_att_col(
-            torch.as_tensor(msa_act).cuda(),
-            torch.as_tensor(msa_mask).cuda(),
-        ).cpu()
+        out_repro = (
+            model.evoformer.blocks[0]
+            .msa_att_col(
+                torch.as_tensor(msa_act).cuda(),
+                torch.as_tensor(msa_mask).cuda(),
+            )
+            .cpu()
+        )

         self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
...
...
@@ -190,7 +190,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
             msa_col = alphafold.model.modules.MSAColumnGlobalAttention(
                 c_e.msa_column_attention,
                 config.model.global_config,
-                name="msa_column_global_attention"
+                name="msa_column_global_attention",
             )
             act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
             return act
...
...
@@ -206,21 +206,23 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
         # Fetch pretrained parameters (but only from one block)]
         params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/extra_msa_stack/" +
-            "msa_column_global_attention"
+            "alphafold/alphafold_iteration/evoformer/extra_msa_stack/"
+            + "msa_column_global_attention"
         )
         params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

-        out_gt = f.apply(
-            params, None, msa_act, msa_mask
-        ).block_until_ready()
+        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
         out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))

         model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.extra_msa_stack.stack.blocks[0].msa_att_col(
-            torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
-            mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
-        ).cpu()
+        out_repro = (
+            model.extra_msa_stack.stack.blocks[0]
+            .msa_att_col(
+                torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
+                mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
+            )
+            .cpu()
+        )

         self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
...
...
tests/test_outer_product_mean.py  View file @ 07e64267
...
...
@@ -19,7 +19,8 @@ from openfold.model.outer_product_mean import OuterProductMean
 from openfold.utils.tensor_utils import tree_map

 import tests.compare_utils as compare_utils
 from tests.config import consts

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
...
@@ -40,7 +41,8 @@ class TestOuterProductMean(unittest.TestCase):
         m = opm(m, mask)

         self.assertTrue(
-            m.shape == (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
+            m.shape
+            == (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
         )

     @compare_utils.skip_unless_alphafold_installed()
...
...
@@ -63,27 +65,29 @@ class TestOuterProductMean(unittest.TestCase):
         c_m = consts.c_m

         msa_act = np.random.rand(n_seq, n_res, c_m).astype(np.float32) * 100
-        msa_mask = np.random.randint(
-            low=0, high=2, size=(n_seq, n_res)
-        ).astype(np.float32)
+        msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res)).astype(
+            np.float32
+        )

         # Fetch pretrained parameters (but only from one block)]
         params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/" +
-            "evoformer_iteration/outer_product_mean"
+            "alphafold/alphafold_iteration/evoformer/"
+            + "evoformer_iteration/outer_product_mean"
         )
         params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

-        out_gt = f.apply(
-            params, None, msa_act, msa_mask
-        ).block_until_ready()
+        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
         out_gt = torch.as_tensor(np.array(out_gt))

         model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.evoformer.blocks[0].outer_product_mean(
-            torch.as_tensor(msa_act).cuda(),
-            mask=torch.as_tensor(msa_mask).cuda(),
-        ).cpu()
+        out_repro = (
+            model.evoformer.blocks[0]
+            .outer_product_mean(
+                torch.as_tensor(msa_act).cuda(),
+                mask=torch.as_tensor(msa_mask).cuda(),
+            )
+            .cpu()
+        )

         # Even when correct, OPM has large, precision-related errors. It gets
         # a special pass from consts.eps.
...
...
tests/test_pair_transition.py  View file @ 07e64267
...
...
@@ -20,7 +20,7 @@ from openfold.utils.tensor_utils import tree_map
 import tests.compare_utils as compare_utils
 from tests.config import consts

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
...
@@ -52,7 +52,7 @@ class TestPairTransition(unittest.TestCase):
             pt = alphafold.model.modules.Transition(
                 c_e.pair_transition,
                 config.model.global_config,
-                name="pair_transition"
+                name="pair_transition",
             )
             act = pt(act=pair_act, mask=pair_mask)
             return act
...
...
@@ -66,26 +66,26 @@ class TestPairTransition(unittest.TestCase):
         # Fetch pretrained parameters (but only from one block)]
         params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
-            "pair_transition"
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + "pair_transition"
         )
         params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

-        out_gt = f.apply(
-            params, None, pair_act, pair_mask
-        ).block_until_ready()
+        out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
         out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))

         model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.evoformer.blocks[0].pair_transition(
-            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
-            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
-        ).cpu()
+        out_repro = (
+            model.evoformer.blocks[0]
+            .pair_transition(
+                torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
+                mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
+            )
+            .cpu()
+        )

         self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))


 if __name__ == "__main__":
     unittest.main()
tests/test_structure_module.py  View file @ 07e64267
...
...
@@ -39,7 +39,7 @@ from tests.data_utils import (
     random_affines_4x4,
 )

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
...
@@ -89,9 +89,7 @@ class TestStructureModule(unittest.TestCase):
         out = sm(s, z, f)

-        self.assertTrue(
-            out["frames"].shape == (no_layers, batch_size, n, 4, 4)
-        )
+        self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
         self.assertTrue(
             out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
         )
...
...
@@ -121,14 +119,13 @@ class TestStructureModule(unittest.TestCase):
         config = compare_utils.get_alphafold_config()
         c_sm = config.model.heads.structure_module
         c_global = config.model.global_config

+
         def run_sm(representations, batch):
             sm = alphafold.model.folding.StructureModule(c_sm, c_global)
             representations = {
-                k: jax.lax.stop_gradient(v)
-                for k, v in representations.items()
+                k: jax.lax.stop_gradient(v) for k, v in representations.items()
             }
-            batch = {
-                k: jax.lax.stop_gradient(v) for k, v in batch.items()
-            }
+            batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
             return sm(representations, batch, is_training=False)

         f = hk.transform(run_sm)
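jax.lax.stop_gradient marks its argument as a constant for differentiation, so the ground-truth structure module is evaluated without gradients flowing into its inputs. Its semantics in a couple of lines:

    import jax
    import jax.numpy as jnp

    f = lambda x: jnp.sum(jax.lax.stop_gradient(x) * x)
    # d/dx of sum(c * x) with c = stop_gradient(x) is just c, i.e. x itself:
    print(jax.grad(f)(jnp.array([1.0, 2.0])))  # [1. 2.]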
...
...
@@ -136,26 +133,21 @@ class TestStructureModule(unittest.TestCase):
         n_res = 200

         representations = {
-            'single': np.random.rand(n_res, consts.c_s).astype(np.float32),
-            'pair': np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
+            "single": np.random.rand(n_res, consts.c_s).astype(np.float32),
+            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
         }

         batch = {
-            'seq_mask': np.random.randint(0, 2, (n_res,)).astype(np.float32),
-            'aatype': np.random.randint(0, 21, (n_res,)),
+            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
+            "aatype": np.random.randint(0, 21, (n_res,)),
         }

-        batch['atom14_atom_exists'] = np.take(
-            restype_atom14_mask, batch['aatype'], axis=0
-        )
+        batch["atom14_atom_exists"] = np.take(
+            restype_atom14_mask, batch["aatype"], axis=0
+        )
-        batch['atom37_atom_exists'] = np.take(
-            restype_atom37_mask, batch['aatype'], axis=0
-        )
+        batch["atom37_atom_exists"] = np.take(
+            restype_atom37_mask, batch["aatype"], axis=0
+        )

         batch.update(make_atom14_masks_np(batch))
...
...
@@ -165,9 +157,7 @@ class TestStructureModule(unittest.TestCase):
         )
         key = jax.random.PRNGKey(42)
-        out_gt = f.apply(
-            params, key, representations, batch
-        )
+        out_gt = f.apply(params, key, representations, batch)
         out_gt = torch.as_tensor(
             np.array(out_gt["final_atom14_positions"].block_until_ready())
         )
...
...
@@ -246,7 +236,7 @@ class TestInvariantPointAttention(unittest.TestCase):
             inputs_1d=act,
             inputs_2d=static_feat_2d,
             mask=mask,
-            affine=affine
+            affine=affine,
         )
         return attn
...
...
@@ -263,15 +253,13 @@ class TestInvariantPointAttention(unittest.TestCase):
         affines = random_affines_4x4((n_res,))
         rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
         quats = alphafold.model.r3.rigids_to_quataffine(rigids)
-        transformations = T.from_4x4(
-            torch.as_tensor(affines).float().cuda()
-        )
+        transformations = T.from_4x4(torch.as_tensor(affines).float().cuda())
         sample_affine = quats

         ipa_params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/structure_module/" +
-            "fold_iteration/invariant_point_attention"
+            "alphafold/alphafold_iteration/structure_module/"
+            + "fold_iteration/invariant_point_attention"
         )

         out_gt = f.apply(
...
...
tests/test_template.py  View file @ 07e64267
...
...
@@ -24,7 +24,7 @@ import tests.compare_utils as compare_utils
 from tests.config import consts
 from tests.data_utils import random_template_feats

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
...
@@ -67,8 +67,8 @@ class TestTemplatePairStack(unittest.TestCase):
         n_res = consts.n_res
         blocks_per_ckpt = None
         chunk_size = 4
         inf = 1e7
         eps = 1e-7

         tpe = TemplatePairStack(
             c_t,
...
...
@@ -100,7 +100,7 @@ class TestTemplatePairStack(unittest.TestCase):
             tps = alphafold.model.modules.TemplatePairStack(
                 c_ee.template.template_pair_stack,
                 config.model.global_config,
-                name="template_pair_stack"
+                name="template_pair_stack",
             )
             act = tps(pair_act, pair_mask, is_training=False)
             ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
...
...
@@ -117,13 +117,15 @@ class TestTemplatePairStack(unittest.TestCase):
         ).astype(np.float32)

         params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/template_embedding/" +
-            "single_template_embedding/template_pair_stack"
+            "alphafold/alphafold_iteration/evoformer/template_embedding/"
+            + "single_template_embedding/template_pair_stack"
         )
-        params.update(compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/template_embedding/" +
-            "single_template_embedding/output_layer_norm"
-        ))
+        params.update(
+            compare_utils.fetch_alphafold_module_weights(
+                "alphafold/alphafold_iteration/evoformer/template_embedding/"
+                + "single_template_embedding/output_layer_norm"
+            )
+        )

         out_gt = f.apply(
             params, jax.random.PRNGKey(42), pair_act, pair_mask
...
...
@@ -147,7 +149,7 @@ class Template(unittest.TestCase):
         config = compare_utils.get_alphafold_config()
         te = alphafold.model.modules.TemplateEmbedding(
             config.model.embeddings_and_evoformer.template,
-            config.model.global_config
+            config.model.global_config,
         )
         act = te(pair, batch, mask_2d, is_training=False)
         return act
...
...
@@ -176,7 +178,7 @@ class Template(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = model.embed_templates(
             {k: torch.as_tensor(v).cuda() for k, v in batch.items()},
             torch.as_tensor(pair_act).cuda(),
             torch.as_tensor(pair_mask).cuda(),
             templ_dim=0,
...
...
tests/test_triangular_attention.py  View file @ 07e64267
...
...
@@ -21,7 +21,7 @@ from openfold.utils.tensor_utils import tree_map
 import tests.compare_utils as compare_utils
 from tests.config import consts

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
...
@@ -34,12 +34,7 @@ class TestTriangularAttention(unittest.TestCase):
         no_heads = 4
         starting = True

-        tan = TriangleAttention(
-            c_z,
-            c,
-            no_heads,
-            starting
-        )
+        tan = TriangleAttention(c_z, c, no_heads, starting)

         batch_size = consts.batch_size
         n_res = consts.n_res
...
...
@@ -53,16 +48,18 @@ class TestTriangularAttention(unittest.TestCase):
     def _tri_att_compare(self, starting=False):
         name = (
-            "triangle_attention_" +
-            ("starting" if starting else "ending") +
-            "_node"
+            "triangle_attention_"
+            + ("starting" if starting else "ending")
+            + "_node"
         )

         def run_tri_att(pair_act, pair_mask):
             config = compare_utils.get_alphafold_config()
             c_e = config.model.embeddings_and_evoformer.evoformer
             tri_att = alphafold.model.modules.TriangleAttention(
-                c_e.triangle_attention_starting_node if starting else
-                c_e.triangle_attention_ending_node,
+                c_e.triangle_attention_starting_node
+                if starting
+                else c_e.triangle_attention_ending_node,
                 config.model.global_config,
                 name=name,
             )
...
...
@@ -78,20 +75,19 @@ class TestTriangularAttention(unittest.TestCase):
         # Fetch pretrained parameters (but only from one block)]
         params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
-            name
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + name
         )
         params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

-        out_gt = f.apply(
-            params, None, pair_act, pair_mask
-        ).block_until_ready()
+        out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
         out_gt = torch.as_tensor(np.array(out_gt))

         model = compare_utils.get_global_pretrained_openfold()
         module = (
-            model.evoformer.blocks[0].tri_att_start if starting else
-            model.evoformer.blocks[0].tri_att_end
+            model.evoformer.blocks[0].tri_att_start
+            if starting
+            else model.evoformer.blocks[0].tri_att_end
         )
         out_repro = module(
             torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
...
...
tests/test_triangular_multiplicative_update.py  View file @ 07e64267
...
...
@@ -20,7 +20,7 @@ from openfold.utils.tensor_utils import tree_map
 import tests.compare_utils as compare_utils
 from tests.config import consts

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
...
@@ -50,16 +50,17 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
         self.assertTrue(shape_before == shape_after)

     def _tri_mul_compare(self, incoming=False):
-        name = (
-            "triangle_multiplication_" + ("incoming" if incoming else "outgoing")
-        )
+        name = "triangle_multiplication_" + (
+            "incoming" if incoming else "outgoing"
+        )

         def run_tri_mul(pair_act, pair_mask):
             config = compare_utils.get_alphafold_config()
             c_e = config.model.embeddings_and_evoformer.evoformer
             tri_mul = alphafold.model.modules.TriangleMultiplication(
-                c_e.triangle_multiplication_incoming if incoming else c_e.triangle_multiplication_outgoing,
+                c_e.triangle_multiplication_incoming
+                if incoming
+                else c_e.triangle_multiplication_outgoing,
                 config.model.global_config,
                 name=name,
             )
...
...
@@ -76,20 +77,19 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
         # Fetch pretrained parameters (but only from one block)
         params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" + name
+            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+            + name
         )
         params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

-        out_gt = f.apply(
-            params, None, pair_act, pair_mask
-        ).block_until_ready()
+        out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
         out_gt = torch.as_tensor(np.array(out_gt))

         model = compare_utils.get_global_pretrained_openfold()
         module = (
-            model.evoformer.blocks[0].tri_mul_in if incoming else model.evoformer.blocks[0].tri_mul_out
+            model.evoformer.blocks[0].tri_mul_in
+            if incoming
+            else model.evoformer.blocks[0].tri_mul_out
         )
         out_repro = module(
             torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
...
...
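One detail worth noting in the ground-truth path above: JAX dispatches device work asynchronously, so block_until_ready() forces the computation to finish before the result is materialized with np.array and handed to torch. A minimal standalone illustration:

import numpy as np
import jax.numpy as jnp

x = jnp.ones((256, 256))
y = (x @ x).block_until_ready()  # wait for the device computation to finish
y_np = np.array(y)               # safe to convert/compare from here on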
@@ -109,4 +109,3 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
 if __name__ == "__main__":
     unittest.main()
tests/test_utils.py
View file @ 07e64267
...
...
@@ -20,17 +20,21 @@ from openfold.utils.affine_utils import T, quat_to_rot
 from openfold.utils.tensor_utils import chunk_layer

-X_90_ROT = torch.tensor([
-    [1, 0, 0],
-    [0, 0, -1],
-    [0, 1, 0],
-])
+X_90_ROT = torch.tensor(
+    [
+        [1, 0, 0],
+        [0, 0, -1],
+        [0, 1, 0],
+    ]
+)

-X_NEG_90_ROT = torch.tensor([
-    [1, 0, 0],
-    [0, 0, 1],
-    [0, -1, 0],
-])
+X_NEG_90_ROT = torch.tensor(
+    [
+        [1, 0, 0],
+        [0, 0, 1],
+        [0, -1, 0],
+    ]
+)
 class TestAffineT(unittest.TestCase):
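A quick sanity check on these fixtures: the two matrices are +90 and -90 degree rotations about the x-axis, so composing them recovers the identity (floats are used below so torch.allclose applies):

import torch

X_90_ROT = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
X_NEG_90_ROT = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, -1.0, 0.0]])

assert torch.allclose(X_90_ROT @ X_NEG_90_ROT, torch.eye(3))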
...
...
@@ -53,7 +57,7 @@ class TestAffineT(unittest.TestCase):
         batch_size = 2
         transf = [
             [1, 0, 0, 1],
             [0, 0, -1, 2],
             [0, 1, 0, 3],
             [0, 0, 0, 1],
         ]
...
...
@@ -62,10 +66,7 @@ class TestAffineT(unittest.TestCase):
         true_rot = transf[:3, :3]
         true_trans = transf[:3, 3]

-        transf = torch.stack(
-            [transf for _ in range(batch_size)],
-            dim=0
-        )
+        transf = torch.stack([transf for _ in range(batch_size)], dim=0)

         t = T.from_4x4(transf)
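For reference, the convention T.from_4x4 is being tested against: the upper-left 3x3 block of the homogeneous matrix is the rotation R, the last column holds the translation t, and applying the transform to a point x is R @ x + t. A self-contained check using the matrix from this test:

import torch

transf = torch.tensor([
    [1.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, -1.0, 2.0],
    [0.0, 1.0, 0.0, 3.0],
    [0.0, 0.0, 0.0, 1.0],
])
R, t = transf[:3, :3], transf[:3, 3]
x = torch.tensor([1.0, 2.0, 3.0])
x_h = torch.cat([x, torch.ones(1)])  # homogeneous coordinates
assert torch.allclose((transf @ x_h)[:3], R @ x + t)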
...
...
@@ -78,8 +79,7 @@ class TestAffineT(unittest.TestCase):
         batch_size = 2
         n = 5

-        transf = T(
-            torch.rand((batch_size, n, 3, 3)),
-            torch.rand((batch_size, n, 3))
-        )
+        transf = T(
+            torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3))
+        )

         self.assertTrue(transf.shape == (batch_size, n))
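The assertion above reflects T's shape semantics: a T built from (batch_size, n, 3, 3) rotations and (batch_size, n, 3) translations reports only the batch shape, i.e. the rotation shape with the trailing (3, 3) stripped. In sketch form (assuming the same import this test file uses):

from openfold.utils.affine_utils import T
import torch

transf = T(torch.rand((2, 5, 3, 3)), torch.rand((2, 5, 3)))
assert transf.shape == (2, 5)  # batch shape only; the (3, 3) is stripped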
...
...
@@ -88,8 +88,7 @@ class TestAffineT(unittest.TestCase):
         batch_size = 2
         n = 5

-        transf = T(
-            torch.rand((batch_size, n, 3, 3)),
-            torch.rand((batch_size, n, 3))
-        )
+        transf = T(
+            torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3))
+        )

         transf_concat = T.concat([transf, transf], dim=0)
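T.concat mirrors torch.cat over the batch dimensions, so concatenating two (2, 5) transforms along dim=0 should yield a (4, 5) transform. The assertion below is inferred from that semantics, not copied from the truncated hunk:

from openfold.utils.affine_utils import T
import torch

transf = T(torch.rand((2, 5, 3, 3)), torch.rand((2, 5, 3)))
transf_concat = T.concat([transf, transf], dim=0)
assert transf_concat.shape == (4, 5)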
...
...