OpenDAS / OpenFold · Commits

Commit 07e64267, authored Oct 16, 2021 by Gustaf Ahdritz
parent de07730f

    Standardize code style

Showing 20 changed files with 1350 additions and 1389 deletions (+1350, -1389)
openfold/utils/import_weights.py    +105 -113
openfold/utils/loss.py              +414 -471
openfold/utils/tensor_utils.py      +55 -52
tests/compare_utils.py              +21 -22
tests/config.py                     +17 -15
tests/data_utils.py                 +23 -17
tests/test_embedders.py             +12 -17
tests/test_evoformer.py             +65 -47
tests/test_feats.py                 +77 -82
tests/test_import_weights.py        +29 -17
tests/test_loss.py                  +255 -243
tests/test_model.py                 +23 -29
tests/test_msa.py                   +60 -58
tests/test_outer_product_mean.py    +25 -21
tests/test_pair_transition.py       +23 -23
tests/test_structure_module.py      +54 -66
tests/test_template.py              +34 -32
tests/test_triangular_attention.py  +19 -23
tests/test_triangular_multiplicative_update.py  +17 -18
tests/test_utils.py                 +22 -23
openfold/utils/import_weights.py
@@ -26,28 +26,25 @@ _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"

# With Param, a poor man's enum with attributes (Rust-style)
class ParamType(Enum):
    LinearWeight = partial(  # hack: partial prevents fns from becoming methods
        lambda w: w.transpose(-1, -2)
    )
    LinearWeightMHA = partial(
        lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
    )
    LinearMHAOutputWeight = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
    LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1))
    LinearWeightOPM = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
    Other = partial(lambda w: w)

    def __init__(self, fn):
        self.transformation = fn


@dataclass
class Param:
    param: Union[torch.Tensor, List[torch.Tensor]]
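A note on the `partial` wrapper above: a bare `lambda` assigned in an `Enum` body is treated as a method rather than as a member value, so wrapping each transformation in `functools.partial` is what keeps it a plain member attribute. A minimal, self-contained sketch of the same pattern (the names here are illustrative, not from this commit):

    from enum import Enum
    from functools import partial

    class Transform(Enum):
        # Without partial, each lambda would become a method of the enum
        # and would never show up as a member.
        DOUBLE = partial(lambda x: 2 * x)
        IDENTITY = partial(lambda x: x)

        def __init__(self, fn):
            # Stash the callable on the member, mirroring ParamType.transformation.
            self.transformation = fn

    assert Transform.DOUBLE.transformation(3) == 6
    assert Transform.IDENTITY.transformation("w") == "w"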
@@ -58,16 +55,17 @@ class Param:

def _process_translations_dict(d, top_layer=True):
    flat = {}
    for k, v in d.items():
        if type(v) == dict:
            prefix = _NPZ_KEY_PREFIX if top_layer else ""
            sub_flat = {
                (prefix + "/".join([k, k_prime])): v_prime
                for k_prime, v_prime in _process_translations_dict(
                    v, top_layer=False
                ).items()
            }
            flat.update(sub_flat)
        else:
            k = "/" + k if not top_layer else k
            flat[k] = v

    return flat
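For intuition: `_process_translations_dict` flattens the nested spec into AlphaFold's "/"-joined key space, prefixing only the top level with `_NPZ_KEY_PREFIX` and giving leaf parameters a leading "/". A hypothetical mini-example (assuming the prefix shown above):

    nested = {"evoformer": {"msa_transition": {"weights": "p"}}}
    flat = _process_translations_dict(nested)
    # {"alphafold/alphafold_iteration/evoformer/msa_transition//weights": "p"}
    # Note the double slash: leaves below the top layer gain a leading "/",
    # which reproduces the module//parameter convention of the NPZ keys.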
@@ -75,29 +73,29 @@ def _process_translations_dict(d, top_layer=True):

def stacked(param_dict_list, out=None):
    """
    Args:
        param_dict_list:
            A list of (nested) Param dicts to stack. The structure of
            each dict must be identical (down to the ParamTypes of
            "parallel" Params). There must be at least one dict
            in the list.
    """
    if out is None:
        out = {}
    template = param_dict_list[0]
    for k, _ in template.items():
        v = [d[k] for d in param_dict_list]
        if type(v[0]) is dict:
            out[k] = {}
            stacked(v, out=out[k])
        elif type(v[0]) is Param:
            stacked_param = Param(
                param=[param.param for param in v],
                param_type=v[0].param_type,
                stacked=True,
            )

            out[k] = stacked_param

    return out
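`stacked` mirrors how AlphaFold stores the weights of repeated blocks (e.g. the 48 Evoformer iterations) as single stacked arrays: given one Param dict per block, it returns a dict of the same shape whose leaves carry lists of the parallel tensors with `stacked=True`, so that `assign` can later `torch.unbind` each stacked NPZ array across them. A toy sketch (the `blocks` variable is hypothetical):

    # One identically-structured dict per block...
    block_params = [{"w": LinearWeight(b.linear.weight)} for b in blocks]
    # ...merged into a single dict of stacked Params:
    merged = stacked(block_params)
    assert merged["w"].stacked and len(merged["w"].param) == len(blocks)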
@@ -107,12 +105,12 @@ def assign(translation_dict, orig_weights):
    with torch.no_grad():
        weights = torch.as_tensor(orig_weights[k])
        ref, param_type = param.param, param.param_type
        if param.stacked:
            weights = torch.unbind(weights, 0)
        else:
            weights = [weights]
            ref = [ref]

        try:
            weights = list(map(param_type.transformation, weights))
            for p, w in zip(ref, weights):
@@ -121,36 +119,25 @@ def assign(translation_dict, orig_weights):
                print(k)
                print(ref[0].shape)
                print(weights[0].shape)
                raise


def import_jax_weights_(model, npz_path, version="model_1"):
    data = np.load(npz_path)

    #######################
    # Some templates
    #######################
    LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
    LinearBias = lambda l: (Param(l))
    LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
    LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
    LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))

    LinearParams = lambda l: {
        "weights": LinearWeight(l.weight),
@@ -167,7 +154,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
        "key_w": LinearWeightMHA(att.linear_k.weight),
        "value_w": LinearWeightMHA(att.linear_v.weight),
        "output_w": Param(
            att.linear_o.weight,
            param_type=ParamType.LinearMHAOutputWeight,
        ),
        "output_b": LinearBias(att.linear_o.bias),
    }

@@ -205,7 +193,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    # see commit b88f8da on the Alphafold repo
    # Alphafold swaps the pseudocode's a and b between the incoming/outcoming
    # iterations of triangle multiplication, which is confusing and not
    # reproduced in our implementation.
    TriMulInParams = lambda tri_mul: {
        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),

@@ -231,7 +219,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    MSAGlobalAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": GlobalAttentionParams(matt.global_attention),
    }

    MSAAttPairBiasParams = lambda matt: dict(

@@ -247,8 +235,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):
        "kv_scalar": LinearParams(ipa.linear_kv),
        "q_point_local": LinearParams(ipa.linear_q_points),
        "kv_point_local": LinearParams(ipa.linear_kv_points),
        "trainable_point_weights": Param(
            param=ipa.head_weights, param_type=ParamType.Other
        ),
        "attention_2d": LinearParams(ipa.linear_b),
        "output_projection": LinearParams(ipa.linear_out),
    }

@@ -276,7 +265,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    }

    def EvoformerBlockParams(b, is_extra_msa=False):
        if is_extra_msa:
            col_att_name = "msa_column_global_attention"
            msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
        else:

@@ -284,8 +273,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):
            msa_col_att_params = MSAAttParams(b.msa_att_col)

        d = {
            "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
                b.msa_att_row
            ),
            col_att_name: msa_col_att_params,
            "msa_transition": MSATransitionParams(b.msa_transition),
            "outer_product_mean": OuterProductMeanParams(b.outer_product_mean),

@@ -316,10 +306,9 @@ def import_jax_weights_(model, npz_path, version="model_1"):
                "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
                "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
                "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
            },
        }

    ############################
    # translations dict overflow
    ############################

@@ -330,14 +319,10 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    )

    ems_blocks = model.extra_msa_stack.stack.blocks
    ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])

    evo_blocks = model.evoformer.blocks
    evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])

    translations = {
        "evoformer": {
@@ -346,101 +331,108 @@ def import_jax_weights_(model, npz_path, version="model_1"):
            "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
            "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
            "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
            "prev_msa_first_row_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_m
            ),
            "prev_pair_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_z
            ),
            "pair_activiations": LinearParams(
                model.input_embedder.linear_relpos
            ),
            "template_embedding": {
                "single_template_embedding": {
                    "embedding2d": LinearParams(
                        model.template_pair_embedder.linear
                    ),
                    "template_pair_stack": {
                        "__layer_stack_no_state": tps_blocks_params,
                    },
                    "output_layer_norm": LayerNormParams(
                        model.template_pair_stack.layer_norm
                    ),
                },
                "attention": AttentionParams(model.template_pointwise_att.mha),
            },
            "extra_msa_activations": LinearParams(
                model.extra_msa_embedder.linear
            ),
            "extra_msa_stack": ems_blocks_params,
            "template_single_embedding": LinearParams(
                model.template_angle_embedder.linear_1
            ),
            "template_projection": LinearParams(
                model.template_angle_embedder.linear_2
            ),
            "evoformer_iteration": evo_blocks_params,
            "single_activations": LinearParams(model.evoformer.linear),
        },
        "structure_module": {
            "single_layer_norm": LayerNormParams(
                model.structure_module.layer_norm_s
            ),
            "initial_projection": LinearParams(
                model.structure_module.linear_in
            ),
            "pair_layer_norm": LayerNormParams(
                model.structure_module.layer_norm_z
            ),
            "fold_iteration": FoldIterationParams(model.structure_module),
        },
        "predicted_lddt_head": {
            "input_layer_norm": LayerNormParams(
                model.aux_heads.plddt.layer_norm
            ),
            "act_0": LinearParams(model.aux_heads.plddt.linear_1),
            "act_1": LinearParams(model.aux_heads.plddt.linear_2),
            "logits": LinearParams(model.aux_heads.plddt.linear_3),
        },
        "distogram_head": {
            "half_logits": LinearParams(model.aux_heads.distogram.linear),
        },
        "experimentally_resolved_head": {
            "logits": LinearParams(
                model.aux_heads.experimentally_resolved.linear
            ),
        },
        "masked_msa_head": {
            "logits": LinearParams(model.aux_heads.masked_msa.linear),
        },
    }

    no_templ = [
        "model_3",
        "model_4",
        "model_5",
        "model_3_ptm",
        "model_4_ptm",
        "model_5_ptm",
    ]
    if version in no_templ:
        evo_dict = translations["evoformer"]
        keys = list(evo_dict.keys())
        for k in keys:
            if "template_" in k:
                evo_dict.pop(k)

    if "_ptm" in version:
        translations["predicted_aligned_error_head"] = {
            "logits": LinearParams(model.aux_heads.tm.linear)
        }

    # Flatten keys and insert missing key prefixes
    flat = _process_translations_dict(translations)

    # Sanity check
    keys = list(data.keys())
    flat_keys = list(flat.keys())
    incorrect = [k for k in flat_keys if k not in keys]
    missing = [k for k in keys if k not in flat_keys]
    # print(f"Incorrect: {incorrect}")
    # print(f"Missing: {missing}")

    assert len(incorrect) == 0
    # assert(sorted(list(flat.keys())) == sorted(list(data.keys())))

    # Set weights
    assign(flat, data)
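Taken together, the import path is: build the nested `translations` spec from the model's submodules, flatten it into AlphaFold's "/"-delimited NPZ key space with `_process_translations_dict`, sanity-check the flattened keys against the keys actually present in the `.npz` file, and finally copy each (transformed) array into the matching PyTorch parameter via `assign` under `torch.no_grad()`. A quick way to inspect the source key space (the file path here is illustrative):

    import numpy as np

    # AlphaFold parameter archives are flat npz files keyed by strings like
    # "alphafold/alphafold_iteration/evoformer/...".
    data = np.load("params_model_1.npz")
    print(sorted(data.keys())[:3])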
openfold/utils/loss.py
@@ -25,8 +25,8 @@ from openfold.np import residue_constants
from openfold.utils import feats
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
    masked_mean,
    permute_final_dims,
    batched_gather,

@@ -49,9 +49,9 @@ def sigmoid_cross_entropy(logits, labels):
def torsion_angle_loss(
    a,  # [*, N, 7, 2]
    a_gt,  # [*, N, 7, 2]
    a_alt_gt,  # [*, N, 7, 2]
):
    # [*, N, 7]
    norm = torch.norm(a, dim=-1)
@@ -81,7 +81,7 @@ def compute_fape(
    positions_mask: torch.Tensor,
    length_scale: float,
    l1_clamp_distance: Optional[float] = None,
    eps=1e-8,
) -> torch.Tensor:
    # [*, N_frames, N_pts, 3]
    local_pred_pos = pred_frames.invert()[..., None].apply(

@@ -91,10 +91,10 @@ def compute_fape(
        target_positions[..., None, :, :],
    )

    error_dist = torch.sqrt(
        torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
    )

    if l1_clamp_distance is not None:
        error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)

    normed_error = error_dist / length_scale

@@ -111,7 +111,9 @@ def compute_fape(
    #
    # ("roughly" because eps is necessarily duplicated in the latter
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = (
        normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
    )
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
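The staged reduction above is, as the surrounding comment notes, roughly a masked mean over frames and points computed without ever materializing the full normalizer, which keeps intermediate sums small enough for FP16. A toy equivalence check (illustrative shapes, not OpenFold code):

    import torch

    eps = 1e-8
    err = torch.rand(2, 4, 6)  # [*, N_frames, N_pts], already masked
    frames_mask = torch.ones(2, 4)
    positions_mask = torch.ones(2, 6)

    # Staged reduction, as in compute_fape:
    a = torch.sum(err, dim=-1)
    a = a / (eps + torch.sum(frames_mask, dim=-1))[..., None]
    a = torch.sum(a, dim=-1)
    a = a / (eps + torch.sum(positions_mask, dim=-1))

    # One-shot mean over both dims (equal up to the duplicated eps):
    b = torch.sum(err, dim=(-1, -2)) / (4 * 6)
    assert torch.allclose(a, b, atol=1e-5)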
@@ -126,14 +128,14 @@ def backbone_loss(
    backbone_affine_mask: torch.Tensor,
    traj: torch.Tensor,
    use_clamped_fape: Optional[torch.Tensor] = None,
    clamp_distance: float = 10.0,
    loss_unit_distance: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    pred_aff = T.from_tensor(traj)
    gt_aff = T.from_tensor(backbone_affine_tensor)

    fape_loss = compute_fape(
        pred_aff,
        gt_aff[..., None, :],

@@ -145,7 +147,7 @@ def backbone_loss(
        length_scale=loss_unit_distance,
        eps=eps,
    )
    if use_clamped_fape is not None:
        unclamped_fape_loss = compute_fape(
            pred_aff,
            gt_aff[..., None, :],

@@ -158,9 +160,8 @@ def backbone_loss(
            eps=eps,
        )

        fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
            1 - use_clamped_fape
        )

    # Take the mean over the layer dimension
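`use_clamped_fape` acts as a per-example interpolation weight between the clamped and unclamped FAPE variants; in the AlphaFold supplement, clamping is applied to roughly 90% of training samples. A tiny numeric illustration (values hypothetical):

    # With use_clamped_fape = 0.9, the blended loss is
    # 0.9 * clamped + 0.1 * unclamped:
    clamped, unclamped, w = 2.0, 5.0, 0.9
    blended = clamped * w + unclamped * (1 - w)
    assert abs(blended - 2.3) < 1e-9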
@@ -177,42 +178,31 @@ def sidechain_loss(
    renamed_atom14_gt_positions: torch.Tensor,
    renamed_atom14_gt_exists: torch.Tensor,
    alt_naming_is_better: torch.Tensor,
    clamp_distance: float = 10.0,
    length_scale: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    renamed_gt_frames = (
        1.0 - alt_naming_is_better[..., None, None, None]
    ) * rigidgroups_gt_frames + alt_naming_is_better[
        ..., None, None, None
    ] * rigidgroups_alt_gt_frames

    # Steamroll the inputs
    sidechain_frames = sidechain_frames[-1]
    batch_dims = sidechain_frames.shape[:-4]
    sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
    sidechain_frames = T.from_4x4(sidechain_frames)
    renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
    renamed_gt_frames = T.from_4x4(renamed_gt_frames)
    rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
    sidechain_atom_pos = sidechain_atom_pos[-1]
    sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
    renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
        *batch_dims, -1, 3
    )
    renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)

    fape = compute_fape(
        sidechain_frames,
@@ -235,19 +225,17 @@ def fape_loss(
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    bb_loss = backbone_loss(
        traj=out["sm"]["frames"],
        **{**batch, **config.backbone},
    )

    sc_loss = sidechain_loss(
        out["sm"]["sidechain_frames"],
        out["sm"]["positions"],
        **{**batch, **config.sidechain},
    )

    return config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss


def supervised_chi_loss(
@@ -264,10 +252,11 @@ def supervised_chi_loss(
) -> torch.Tensor:
    pred_angles = angles_sin_cos[..., 3:, :]
    residue_type_one_hot = torch.nn.functional.one_hot(
        aatype,
        residue_constants.restype_num + 1,
    )
    chi_pi_periodic = torch.einsum(
        "...ij,jk->ik",
        residue_type_one_hot.type(angles_sin_cos.dtype),
        angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
    )

@@ -276,11 +265,9 @@ def supervised_chi_loss(
    shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
    true_chi_shifted = shifted_mask * true_chi
    sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
    sq_chi_error_shifted = torch.sum(
        (true_chi_shifted - pred_angles) ** 2, dim=-1
    )
    sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)

    # The ol' switcheroo

@@ -295,14 +282,14 @@ def supervised_chi_loss(
    loss = loss + chi_weight * sq_chi_loss

    angle_norm = torch.sqrt(
        torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
    )
    norm_error = torch.abs(angle_norm - 1.0)
    norm_error = norm_error.permute(
        *range(len(norm_error.shape))[1:-2], 0, -2, -1
    )
    angle_norm_loss = masked_mean(
        seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
    )

    loss = loss + angle_norm_weight * angle_norm_loss
@@ -312,14 +299,13 @@ def supervised_chi_loss(

def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
    num_bins = logits.shape[-1]
    bin_width = 1.0 / num_bins
    bounds = torch.arange(
        start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device
    )
    probs = torch.nn.functional.softmax(logits, dim=-1)
    pred_lddt_ca = torch.sum(
        probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
        dim=-1,
    )
    return pred_lddt_ca * 100
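In other words, pLDDT is the expectation of the bin centers under the predicted distribution, scaled to 0-100. A quick check with two bins (illustrative):

    import torch

    # Two bins with centers 0.25 and 0.75; uniform probabilities give an
    # expected lDDT of 0.5, i.e. a pLDDT of 50.
    logits = torch.zeros(1, 2)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    centers = torch.tensor([0.25, 0.75])
    assert torch.isclose(torch.sum(probs * centers) * 100, torch.tensor(50.0))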
@@ -331,7 +317,7 @@ def lddt_loss(
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    resolution: torch.Tensor,
    cutoff: float = 15.0,
    no_bins: int = 50,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,

@@ -339,55 +325,57 @@ def lddt_loss(
    **kwargs,
) -> torch.Tensor:
    n = all_atom_mask.shape[-2]

    ca_pos = residue_constants.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim

    dmat_true = torch.sqrt(
        eps
        + torch.sum(
            (
                all_atom_positions[..., None, :]
                - all_atom_positions[..., None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    dmat_pred = torch.sqrt(
        eps
        + torch.sum(
            (
                all_atom_pred_pos[..., None, :]
                - all_atom_pred_pos[..., None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )
    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * permute_final_dims(all_atom_mask, (1, 0))
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    dist_l1 = torch.abs(dmat_true - dmat_pred)

    score = (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )
    score = score * 0.25

    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=-1))
    score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))

    score = score.detach()

    bin_index = torch.floor(score * no_bins).long()
    bin_index = torch.clamp(bin_index, max=(no_bins - 1))
    lddt_ca_one_hot = torch.nn.functional.one_hot(
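For reference, the four indicator terms above implement the lDDT convention of averaging distance-preservation rates at 0.5, 1, 2, and 4 Å tolerances. A tiny standalone check (illustrative values only):

    import torch

    # A true distance predicted 1.2 Å off passes the 2 Å and 4 Å tolerances
    # but fails 0.5 Å and 1 Å, so its lDDT credit is 0.5.
    dist_l1 = torch.tensor([1.2])
    credit = sum((dist_l1 < t).float() for t in (0.5, 1.0, 2.0, 4.0)) * 0.25
    assert credit.item() == 0.5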
@@ -396,40 +384,39 @@ def lddt_loss(
    errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
    all_atom_mask = all_atom_mask.squeeze(-1)
    loss = torch.sum(errors * all_atom_mask, dim=-1) / (
        eps + torch.sum(all_atom_mask, dim=-1)
    )

    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    return loss


def distogram_loss(
    logits,
    pseudo_beta,
    pseudo_beta_mask,
    min_bin=2.3125,
    max_bin=21.6875,
    no_bins=64,
    eps=1e-6,
    **kwargs,
):
    boundaries = torch.linspace(
        min_bin,
        max_bin,
        no_bins - 1,
        device=logits.device,
    )
    boundaries = boundaries ** 2

    dists = torch.sum(
        (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
        dim=-1,
        keepdims=True,
    )

    true_bins = torch.sum(dists > boundaries, dim=-1)

@@ -442,7 +429,7 @@ def distogram_loss(
    square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]

    # FP16-friendly sum. Equivalent to:
    # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
    #         (eps + torch.sum(square_mask, dim=(-1, -2))))
    denom = eps + torch.sum(square_mask, dim=(-1, -2))
    mean = errors * square_mask
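The split reduction avoids summing the full error matrix in one go: dividing by `denom` before the final sum keeps running values near the size of a mean rather than a sum, which matters in FP16, where totals can overflow half precision's roughly 65k range. A sketch of the pattern in isolation (illustrative shapes):

    import torch

    errors = torch.rand(8, 8)
    mask = torch.ones(8, 8)
    eps = 1e-6

    denom = eps + torch.sum(mask, dim=(-1, -2))
    staged = torch.sum(torch.sum(errors * mask / denom, dim=-1), dim=-1)
    direct = torch.sum(errors * mask, dim=(-1, -2)) / denom
    assert torch.allclose(staged, direct, atol=1e-6)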
@@ -450,7 +437,7 @@ def distogram_loss(
    mean = mean / denom[..., None]
    mean = torch.sum(mean, dim=-1)

    return mean


def _calculate_bin_centers(boundaries: torch.Tensor):

@@ -469,7 +456,7 @@ def _calculate_expected_aligned_error(
    bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
    return (
        torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
        bin_centers[-1],
    )

@@ -480,7 +467,7 @@ def compute_predicted_aligned_error(
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.

    Args:
        logits: [*, num_res, num_res, num_bins] the logits output from
            PredictedAlignedErrorHead.

@@ -494,18 +481,16 @@ def compute_predicted_aligned_error(
        max_predicted_aligned_error: [*] the maximum predicted error possible.
    """
    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
    (
        predicted_aligned_error,
        max_predicted_aligned_error,
    ) = _calculate_expected_aligned_error(
        alignment_confidence_breaks=boundaries,
        aligned_distance_error_probs=aligned_confidence_probs,
    )

    return {
@@ -523,14 +508,11 @@ def compute_tm(
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    if residue_weights is None:
        residue_weights = logits.new_ones(logits.shape[-2])

    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    bin_centers = _calculate_bin_centers(boundaries)

@@ -538,11 +520,11 @@ def compute_tm(
    n = logits.shape[-2]
    clipped_n = max(n, 19)

    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8

    probs = torch.nn.functional.softmax(logits, dim=-1)

    tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
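The `d0` above is the standard TM-score normalization of Zhang and Skolnick, d0(N) = 1.24 * (N - 15)^(1/3) - 1.8, with N clipped from below so short chains don't produce degenerate values; each error bin then contributes 1 / (1 + d^2 / d0^2). A quick numeric check (illustrative):

    # For a 100-residue chain, d0 is roughly 3.65 Å.
    clipped_n = max(100, 19)
    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
    assert abs(d0 - 3.65) < 0.01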
@@ -554,12 +536,12 @@ def compute_tm(

def tm_loss(
    logits,
    final_affine_tensor,
    backbone_affine_tensor,
    backbone_affine_mask,
    resolution,
    max_bin=31,
    no_bins=64,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
    eps=1e-8,

@@ -573,25 +555,18 @@ def tm_loss(
        return affine.invert()[..., None].apply(pts)

    sq_diff = torch.sum(
        (_points(pred_affine) - _points(backbone_affine)) ** 2, dim=-1
    )

    sq_diff = sq_diff.detach()

    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )
    boundaries = boundaries ** 2
    true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)

    errors = softmax_cross_entropy(
        logits, torch.nn.functional.one_hot(true_bins, no_bins)
    )

    square_mask = (

@@ -599,15 +574,14 @@ def tm_loss(
    )

    loss = torch.sum(errors * square_mask, dim=-1)

    scale = 0.5  # hack to help FP16 training along
    denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
    loss = loss / denom[..., None]
    loss = torch.sum(loss, dim=-1)
    loss = loss * scale

    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    return loss
@@ -623,11 +597,11 @@ def between_residue_bond_loss(
    eps=1e-6,
) -> Dict[str, torch.Tensor]:
    """Flat-bottom loss to penalize structural violations between residues.

    This is a loss penalizing any violation of the geometry around the peptide
    bond between consecutive amino acids. This loss corresponds to
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.

    Args:
        pred_atom_positions: Atom positions in atom37/14 representation
        pred_atom_mask: Atom mask in atom37/14 representation

@@ -638,7 +612,7 @@ def between_residue_bond_loss(
            of pdb distributions
        tolerance_factor_hard: hard tolerance factor measured in standard deviations
            of pdb distributions

    Returns:
        Dict containing:
            * 'c_n_loss_mean': Loss for peptide bond length violations

@@ -659,126 +633,116 @@ def between_residue_bond_loss(
    next_n_mask = pred_atom_mask[..., 1:, 0]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]
    has_no_gap_mask = (
        residue_index[..., 1:] - residue_index[..., :-1]
    ) == 1.0

    # Compute loss for the C--N bond.
    c_n_bond_length = torch.sqrt(
        eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
    )

    # The C-N bond to proline has slightly different length because of the ring.
    next_is_proline = (
        aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
    )
    gt_length = (
        ~next_is_proline
    ) * residue_constants.between_res_bond_length_c_n[
        0
    ] + next_is_proline * residue_constants.between_res_bond_length_c_n[
        1
    ]
    gt_stddev = (
        ~next_is_proline
    ) * residue_constants.between_res_bond_length_stddev_c_n[
        0
    ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
        1
    ]
    c_n_bond_length_error = torch.sqrt(
        eps + (c_n_bond_length - gt_length) ** 2
    )
    c_n_loss_per_residue = torch.nn.functional.relu(
        c_n_bond_length_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * has_no_gap_mask
    c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_violation_mask = mask * (
        c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
    )

    # Compute loss for the angles.
    ca_c_bond_length = torch.sqrt(
        eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
    )
    n_ca_bond_length = torch.sqrt(
        eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
    )

    c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
    c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
    n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]

    ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
    gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
    ca_c_n_cos_angle_error = torch.sqrt(
        eps + (ca_c_n_cos_angle - gt_angle) ** 2
    )
    ca_c_n_loss_per_residue = torch.nn.functional.relu(
        ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
    ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    ca_c_n_violation_mask = mask * (
        ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
    gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
    c_n_ca_cos_angle_error = torch.sqrt(
        eps + torch.square(c_n_ca_cos_angle - gt_angle)
    )
    c_n_ca_loss_per_residue = torch.nn.functional.relu(
        c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
    c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_ca_violation_mask = mask * (
        c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    # Compute a per residue loss (equally distribute the loss to both
    # neighbouring residues).
    per_residue_loss_sum = (
        c_n_loss_per_residue
        + ca_c_n_loss_per_residue
        + c_n_ca_loss_per_residue
    )
    per_residue_loss_sum = 0.5 * (
        torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
        + torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
    )

    # Compute hard violations.
    violation_mask = torch.max(
        torch.stack(
            [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
            dim=-2,
        ),
        dim=-2,
    )[0]
    violation_mask = torch.maximum(
        torch.nn.functional.pad(violation_mask, (0, 1)),
        torch.nn.functional.pad(violation_mask, (1, 0)),
    )

    return {
        "c_n_loss_mean": c_n_loss,
        "ca_c_n_loss_mean": ca_c_n_loss,
        "c_n_ca_loss_mean": c_n_ca_loss,
        "per_residue_loss_sum": per_residue_loss_sum,
        "per_residue_violation_mask": violation_mask,
    }
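The relu(error - tolerance * stddev) construction is what makes this a flat-bottom penalty: deviations within `tolerance_factor_soft` standard deviations of the reference geometry cost nothing, and the loss grows linearly beyond that. A one-line illustration (values hypothetical; OpenFold's default tolerance factor is 12):

    import torch

    # With gt_stddev = 0.02 and a soft tolerance of 12 sigmas, errors under
    # 0.24 Å are free; an error of 0.30 Å is penalized by 0.06.
    err = torch.tensor([0.10, 0.30])
    penalty = torch.nn.functional.relu(err - 12.0 * 0.02)
    assert torch.allclose(penalty, torch.tensor([0.0, 0.06]))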
@@ -792,12 +756,12 @@ def between_residue_clash_loss(
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """Loss to penalize steric clashes between residues.

    This is a loss penalizing any steric clashes due to non bonded atoms in
    different peptides coming too close. This loss corresponds to the part with
    different residues of
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

    Args:
        atom14_pred_positions: Predicted positions of atoms in
            global prediction frame

@@ -807,7 +771,7 @@ def between_residue_clash_loss(
        residue_index: Residue index for given amino acid.
        overlap_tolerance_soft: Soft tolerance factor.
        overlap_tolerance_hard: Hard tolerance factor.

    Returns:
        Dict containing:
            * 'mean_loss': average clash loss

@@ -816,33 +780,36 @@ def between_residue_clash_loss(
            shape (N, 14)
    """
    fp_type = atom14_pred_positions.dtype

    # Create the distance matrix.
    # (N, N, 14, 14)
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, None, :, None, :]
                - atom14_pred_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Create the mask for valid distances.
    # shape (N, N, 14, 14)
    dists_mask = (
        atom14_atom_exists[..., :, None, :, None]
        * atom14_atom_exists[..., None, :, None, :]
    ).type(fp_type)

    # Mask out all the duplicate entries in the lower triangular matrix.
    # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
    # are handled separately.
    dists_mask = dists_mask * (
        residue_index[..., :, None, None, None]
        < residue_index[..., None, :, None, None]
    )

    # Backbone C--N bond between subsequent residues is no clash.
    c_one_hot = torch.nn.functional.one_hot(
        residue_index.new_tensor(2), num_classes=14

@@ -860,74 +827,69 @@ def between_residue_clash_loss(
    n_one_hot = n_one_hot.type(fp_type)

    neighbour_mask = (
        residue_index[..., :, None, None, None] + 1
    ) == residue_index[..., None, :, None, None]
    c_n_bonds = (
        neighbour_mask
        * c_one_hot[..., None, None, :, None]
        * n_one_hot[..., None, None, None, :]
    )
    dists_mask = dists_mask * (1.0 - c_n_bonds)

    # Disulfide bridge between two cysteines is no clash.
    cys = residue_constants.restype_name_to_atom14_names["CYS"]
    cys_sg_idx = cys.index("SG")
    cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
    cys_sg_idx = cys_sg_idx.reshape(
        *((1,) * len(residue_index.shape[:-1])), 1
    ).squeeze(-1)

    cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
    disulfide_bonds = (
        cys_sg_one_hot[..., None, None, :, None]
        * cys_sg_one_hot[..., None, None, None, :]
    )
    dists_mask = dists_mask * (1.0 - disulfide_bonds)

    # Compute the lower bound for the allowed distances.
    # shape (N, N, 14, 14)
    dists_lower_bound = dists_mask * (
        atom14_atom_radius[..., :, None, :, None]
        + atom14_atom_radius[..., None, :, None, :]
    )

    # Compute the error.
    # shape (N, N, 14, 14)
    dists_to_low_error = dists_mask * torch.nn.functional.relu(
        dists_lower_bound - overlap_tolerance_soft - dists
    )

    # Compute the mean loss.
    # shape ()
    mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))

    # Compute the per atom loss sum.
    # shape (N, 14)
    per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
        dists_to_low_error, axis=(-3, -1)
    )

    # Compute the hard clash mask.
    # shape (N, N, 14, 14)
    clash_mask = dists_mask * (
        dists < (dists_lower_bound - overlap_tolerance_hard)
    )

    # Compute the per atom clash.
    # shape (N, 14)
    per_atom_clash_mask = torch.maximum(
        torch.amax(clash_mask, axis=(-4, -2)),
        torch.amax(clash_mask, axis=(-3, -1)),
    )

    return {
        "mean_loss": mean_loss,  # shape ()
        "per_atom_loss_sum": per_atom_loss_sum,  # shape (N, 14)
        "per_atom_clash_mask": per_atom_clash_mask,  # shape (N, 14)
    }
@@ -940,54 +902,53 @@ def within_residue_violations(
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """Loss to penalize steric clashes within residues.

    This is a loss penalizing any steric violations or clashes of non-bonded atoms
    in a given peptide. This loss corresponds to the part with
    the same residues of
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

    Args:
        atom14_pred_positions ([*, N, 14, 3]):
            Predicted positions of atoms in global prediction frame.
        atom14_atom_exists ([*, N, 14]):
            Mask denoting whether atom at positions exists for given
            amino acid type
        atom14_dists_lower_bound ([*, N, 14]):
            Lower bound on allowed distances.
        atom14_dists_upper_bound ([*, N, 14]):
            Upper bound on allowed distances
        tighten_bounds_for_loss ([*, N]):
            Extra factor to tighten loss

    Returns:
        Dict containing:
            * 'per_atom_loss_sum' ([*, N, 14]):
                sum of all clash losses per atom, shape
            * 'per_atom_clash_mask' ([*, N, 14]):
                mask whether atom clashes with any other atom shape
    """
    # Compute the mask for each residue.
    dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
    dists_masks = dists_masks.reshape(
        *((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
    )
    dists_masks = (
        atom14_atom_exists[..., :, :, None]
        * atom14_atom_exists[..., :, None, :]
        * dists_masks
    )

    # Distance matrix
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, :, None, :]
                - atom14_pred_positions[..., :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

@@ -999,34 +960,26 @@ def within_residue_violations(
        dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
    )

    loss = dists_masks * (dists_to_low_error + dists_to_high_error)

    # Compute the per atom loss sum.
    per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)

    # Compute the violations mask.
    violations = dists_masks * (
        (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
    )

    # Compute the per atom violations.
    per_atom_violations = torch.maximum(
        torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]
    )

    return {
        "per_atom_loss_sum": per_atom_loss_sum,
        "per_atom_violations": per_atom_violations,
    }


def find_structural_violations(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
@@ -1035,7 +988,7 @@ def find_structural_violations(
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes several checks for structural violations."""
    # Compute between residue backbone violations of bonds and angles.
    connection_violations = between_residue_bond_loss(
        pred_atom_positions=atom14_pred_positions,

@@ -1043,9 +996,9 @@ def find_structural_violations(
        residue_index=batch["residue_index"],
        aatype=batch["aatype"],
        tolerance_factor_soft=violation_tolerance_factor,
        tolerance_factor_hard=violation_tolerance_factor,
    )

    # Compute the Van der Waals radius for every atom
    # (the first letter of the atom name is the element type).
    # Shape: (N, 14).

@@ -1053,14 +1006,12 @@ def find_structural_violations(
        residue_constants.van_der_waals_radius[name[0]]
        for name in residue_constants.atom_types
    ]
    atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
    atom14_atom_radius = (
        batch["atom14_atom_exists"]
        * atomtype_radius[batch["residx_atom14_to_atom37"]]
    )

    # Compute the between residue clash loss.
    between_residue_clashes = between_residue_clash_loss(
        atom14_pred_positions=atom14_pred_positions,

@@ -1068,32 +1019,28 @@ def find_structural_violations(
        atom14_atom_radius=atom14_atom_radius,
        residue_index=batch["residue_index"],
        overlap_tolerance_soft=clash_overlap_tolerance,
        overlap_tolerance_hard=clash_overlap_tolerance,
    )

    # Compute all within-residue violations (clashes,
    # bond length and angle violations).
    restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
        overlap_tolerance=clash_overlap_tolerance,
        bond_length_tolerance_factor=violation_tolerance_factor,
    )
    atom14_atom_exists = batch["atom14_atom_exists"]
    atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
        restype_atom14_bounds["lower_bound"]
    )[batch["aatype"]]
    atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
        restype_atom14_bounds["upper_bound"]
    )[batch["aatype"]]
    residue_violations = within_residue_violations(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=batch["atom14_atom_exists"],
        atom14_dists_lower_bound=atom14_dists_lower_bound,
        atom14_dists_upper_bound=atom14_dists_upper_bound,
        tighten_bounds_for_loss=0.0,
    )

    # Combine them to a single per-residue violation mask (used later for LDDT).
@@ -1104,49 +1051,52 @@ def find_structural_violations(
            torch.max(
                between_residue_clashes["per_atom_clash_mask"], dim=-1
            )[0],
            torch.max(
                residue_violations["per_atom_violations"], dim=-1
            )[0],
        ],
        dim=-1,
    ),
        dim=-1,
    )[0]

    return {
        "between_residues": {
            "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"],  # ()
            "angles_ca_c_n_loss_mean": connection_violations[
                "ca_c_n_loss_mean"
            ],  # ()
            "angles_c_n_ca_loss_mean": connection_violations[
                "c_n_ca_loss_mean"
            ],  # ()
            "connections_per_residue_loss_sum": connection_violations[
                "per_residue_loss_sum"
            ],  # (N)
            "connections_per_residue_violation_mask": connection_violations[
                "per_residue_violation_mask"
            ],  # (N)
            "clashes_mean_loss": between_residue_clashes["mean_loss"],  # ()
            "clashes_per_atom_loss_sum": between_residue_clashes[
                "per_atom_loss_sum"
            ],  # (N, 14)
            "clashes_per_atom_clash_mask": between_residue_clashes[
                "per_atom_clash_mask"
            ],  # (N, 14)
        },
        "within_residues": {
            "per_atom_loss_sum": residue_violations["per_atom_loss_sum"],  # (N, 14)
            "per_atom_violations": residue_violations["per_atom_violations"],  # (N, 14),
        },
        "total_per_residue_violations_mask": per_residue_violations_mask,  # (N)
    }


def find_structural_violations_np(
    batch: Dict[str, np.ndarray],
    atom14_pred_positions: np.ndarray,
    config: ml_collections.ConfigDict,
) -> Dict[str, np.ndarray]:
    to_tensor = lambda x: torch.tensor(x)
    batch = tree_map(to_tensor, batch, np.ndarray)

@@ -1161,17 +1111,17 @@ def find_structural_violations_np(

def extreme_ca_ca_distance_violations(
    pred_atom_positions: torch.Tensor,  # (N, 37(14), 3)
    pred_atom_mask: torch.Tensor,  # (N, 37(14))
    residue_index: torch.Tensor,  # (N)
    max_angstrom_tolerance=1.5,
    eps=1e-6,
) -> torch.Tensor:
    """Counts residues whose Ca is a large distance from its neighbour.

    Measures the fraction of CA-CA pairs between consecutive amino acids that are
    more than 'max_angstrom_tolerance' apart.

    Args:
        pred_atom_positions: Atom positions in atom37/14 representation
        pred_atom_mask: Atom mask in atom37/14 representation
@@ -1185,13 +1135,13 @@ def extreme_ca_ca_distance_violations(
    this_ca_mask = pred_atom_mask[..., :-1, 1]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]
    has_no_gap_mask = (
        residue_index[..., 1:] - residue_index[..., :-1]
    ) == 1.0
    ca_ca_distance = torch.sqrt(
        eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)
    )
    violations = (
        ca_ca_distance - residue_constants.ca_ca
    ) > max_angstrom_tolerance
    mask = this_ca_mask * next_ca_mask * has_no_gap_mask
    mean = masked_mean(mask, violations, -1)
    return mean
@@ -1202,18 +1152,18 @@ def compute_violation_metrics(
    atom14_pred_positions: torch.Tensor,  # (N, 14, 3)
    violations: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """Compute several metrics to assess the structural violations."""
    ret = {}
    extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
        pred_atom_positions=atom14_pred_positions,
        pred_atom_mask=batch["atom14_atom_exists"],
        residue_index=batch["residue_index"],
    )
    ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
    ret["violations_between_residue_bond"] = masked_mean(
        batch["seq_mask"],
        violations["between_residues"][
            "connections_per_residue_violation_mask"
        ],
        dim=-1,
    )

@@ -1221,7 +1171,7 @@ def compute_violation_metrics(
        mask=batch["seq_mask"],
        value=torch.max(
            violations["between_residues"]["clashes_per_atom_clash_mask"],
            dim=-1,
        )[0],
        dim=-1,
    )

@@ -1250,7 +1200,6 @@ def compute_violation_metrics_np(
    atom14_pred_positions = to_tensor(atom14_pred_positions)
    violations = tree_map(to_tensor, violations, np.ndarray)

    out = compute_violation_metrics(batch, atom14_pred_positions, violations)

    to_np = lambda x: np.array(x)
@@ -1265,15 +1214,15 @@ def violation_loss(
) -> torch.Tensor:
    num_atoms = torch.sum(atom14_atom_exists)

    l_clash = torch.sum(
        violations["between_residues"]["clashes_per_atom_loss_sum"]
        + violations["within_residues"]["per_atom_loss_sum"]
    )
    l_clash = l_clash / (eps + num_atoms)
    loss = (
        violations["between_residues"]["bonds_c_n_loss_mean"]
        + violations["between_residues"]["angles_ca_c_n_loss_mean"]
        + violations["between_residues"]["angles_c_n_ca_loss_mean"]
        + l_clash
    )

    return loss

@@ -1286,12 +1235,12 @@ def compute_renamed_ground_truth(
) -> Dict[str, torch.Tensor]:
    """
    Find optimal renaming of ground truth based on the predicted positions.

    Alg. 26 "renameSymmetricGroundTruthAtoms"

    This renamed ground truth is then used for all losses,
    such that each loss moves the atoms in the same direction.

    Args:
        batch: Dictionary containing:
            * atom14_gt_positions: Ground truth positions.
...
...
@@ -1313,50 +1262,53 @@ def compute_renamed_ground_truth(
     """
     pred_dists = torch.sqrt(
         eps
         + torch.sum(
             (
                 atom14_pred_positions[..., None, :, None, :]
                 - atom14_pred_positions[..., None, :, None, :, :]
             )
             ** 2,
             dim=-1,
         )
     )

     atom14_gt_positions = batch["atom14_gt_positions"]
     gt_dists = torch.sqrt(
         eps
         + torch.sum(
             (
                 atom14_gt_positions[..., None, :, None, :]
                 - atom14_gt_positions[..., None, :, None, :, :]
             )
             ** 2,
             dim=-1,
         )
     )

     atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
     alt_gt_dists = torch.sqrt(
         eps
         + torch.sum(
             (
                 atom14_alt_gt_positions[..., None, :, None, :]
                 - atom14_alt_gt_positions[..., None, :, None, :, :]
             )
             ** 2,
             dim=-1,
         )
     )

     lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
     alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)

     atom14_gt_exists = batch["atom14_gt_exists"]
     atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
     mask = (
         atom14_gt_exists[..., None, :, None]
         * atom14_atom_is_ambiguous[..., None, :, None]
         * atom14_gt_exists[..., None, :, None, :]
-        * (1. - atom14_atom_is_ambiguous[..., None, :, None, :])
+        * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
     )

     per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
...
@@ -1366,16 +1318,16 @@ def compute_renamed_ground_truth(
     alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)

-    renamed_atom14_gt_positions = (
-        (1. - alt_naming_is_better[..., None, None]) * atom14_gt_positions
-        + alt_naming_is_better[..., None, None] * atom14_alt_gt_positions
-    )
+    renamed_atom14_gt_positions = (
+        1.0 - alt_naming_is_better[..., None, None]
+    ) * atom14_gt_positions + alt_naming_is_better[
+        ..., None, None
+    ] * atom14_alt_gt_positions

-    renamed_atom14_gt_mask = (
-        (1. - alt_naming_is_better[..., None]) * atom14_gt_exists
-        + alt_naming_is_better[..., None] * batch["atom14_alt_gt_exists"]
-    )
+    renamed_atom14_gt_mask = (
+        1.0 - alt_naming_is_better[..., None]
+    ) * atom14_gt_exists + alt_naming_is_better[
+        ..., None
+    ] * batch["atom14_alt_gt_exists"]

     return {
         "alt_naming_is_better": alt_naming_is_better,
...
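The two blends above are per-residue selection written with broadcasting: where alt_naming_is_better is 1 the alternative ground truth wins, where it is 0 the original is kept. A toy demonstration of the same arithmetic (standalone, not the repo's code):

    import torch

    alt_is_better = torch.tensor([0.0, 1.0])         # one flag per residue
    gt = torch.tensor([[1.0, 1.0], [2.0, 2.0]])      # "original" truth
    alt_gt = torch.tensor([[9.0, 9.0], [8.0, 8.0]])  # swapped-atom truth

    renamed = (1.0 - alt_is_better[..., None]) * gt + alt_is_better[..., None] * alt_gt
    print(renamed)  # tensor([[1., 1.], [8., 8.]]): row 0 from gt, row 1 from alt_gt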
@@ -1398,10 +1350,9 @@ def experimentally_resolved_loss(
     loss = torch.sum(errors * atom37_atom_exists, dim=-1)
     loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
     loss = torch.sum(loss, dim=-1)
     loss = loss * (
         (resolution >= min_resolution) & (resolution <= max_resolution)
     )

     return loss
...
@@ -1409,10 +1360,9 @@ def experimentally_resolved_loss(
 def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
     errors = softmax_cross_entropy(
         logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
     )

     # FP16-friendly averaging. Equivalent to:
     # loss = (
     #     torch.sum(errors * bert_mask, dim=(-1, -2)) /
...
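The comment above is cut off in this view, but it contrasts the staged reduction that follows it with a one-shot masked mean. A sketch of the presumed rationale (my reading, not text from the diff): reducing one dimension at a time keeps each partial sum small enough to stay finite in half precision, while computing the same quantity.

    import torch

    errors = torch.rand(13, 11)                        # e.g. (n_seq, n_res) CE terms
    bert_mask = torch.randint(0, 2, (13, 11)).float()
    eps = 1e-8

    # One-shot masked mean, as in the commented-out form above:
    one_shot = torch.sum(errors * bert_mask, dim=(-1, -2)) / (
        eps + torch.sum(bert_mask, dim=(-1, -2))
    )
    # Same quantity accumulated one dimension at a time:
    staged = torch.sum(torch.sum(errors * bert_mask, dim=-1), dim=-1) / (
        eps + torch.sum(bert_mask, dim=(-1, -2))
    )
    assert torch.allclose(one_shot, staged)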
@@ -1435,7 +1385,7 @@ def compute_drmsd(structure_1, structure_2):
     d1 = d1 ** 2
     d2 = d2 ** 2

     d1 = torch.sqrt(torch.sum(d1, dim=-1))
     d2 = torch.sqrt(torch.sum(d2, dim=-1))
...
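For orientation, dRMSD compares the two structures' internal pairwise-distance matrices, which is what the d1/d2 lines above are building. A self-contained re-derivation of the quantity (an illustration, not the repo's compute_drmsd):

    import torch

    def toy_drmsd(s1, s2):
        # s1, s2: (N, 3) coordinates; compare their distance matrices.
        d1 = torch.cdist(s1, s1)
        d2 = torch.cdist(s2, s2)
        n = s1.shape[-2]
        return torch.sqrt(torch.sum((d1 - d2) ** 2) / (n * (n - 1)))

    s1, s2 = torch.rand(11, 3), torch.rand(11, 3)
    print(toy_drmsd(s1, s2))  # invariant to rigid motions of either structure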
@@ -1450,81 +1400,74 @@ def compute_drmsd(structure_1, structure_2):
 class AlphaFoldLoss(nn.Module):
-    """ Aggregation of the various losses described in the supplement """
+    """Aggregation of the various losses described in the supplement"""

     def __init__(self, config):
         super(AlphaFoldLoss, self).__init__()
         self.config = config

     def forward(self, out, batch):
-        if ("violation" not in out.keys() and
-                self.config.violation.weight):
+        if "violation" not in out.keys() and self.config.violation.weight:
             out["violation"] = find_structural_violations(
                 batch,
                 out["sm"]["positions"][-1],
                 **self.config.violation,
             )

-        if ("renamed_atom14_gt_positions" not in out.keys()):
-            batch.update(compute_renamed_ground_truth(
-                batch,
-                out["sm"]["positions"][-1],
-            ))
+        if "renamed_atom14_gt_positions" not in out.keys():
+            batch.update(
+                compute_renamed_ground_truth(
+                    batch,
+                    out["sm"]["positions"][-1],
+                )
+            )

         loss_fns = {
             "distogram": lambda: distogram_loss(
                 logits=out["distogram_logits"],
                 **{**batch, **self.config.distogram},
             ),
             "experimentally_resolved": lambda: experimentally_resolved_loss(
                 logits=out["experimentally_resolved_logits"],
                 **{**batch, **self.config.experimentally_resolved},
             ),
             "fape": lambda: fape_loss(
                 out,
                 batch,
                 self.config.fape,
             ),
             "lddt": lambda: lddt_loss(
                 logits=out["lddt_logits"],
                 all_atom_pred_pos=out["final_atom_positions"],
                 **{**batch, **self.config.lddt},
             ),
             "masked_msa": lambda: masked_msa_loss(
                 logits=out["masked_msa_logits"],
                 **{**batch, **self.config.masked_msa},
             ),
             "supervised_chi": lambda: supervised_chi_loss(
                 out["sm"]["angles"],
                 out["sm"]["unnormalized_angles"],
                 **{**batch, **self.config.supervised_chi},
             ),
             "violation": lambda: violation_loss(
                 out["violation"],
                 **batch,
             ),
             "tm": lambda: tm_loss(
                 logits=out["tm_logits"],
                 **{**batch, **out, **self.config.tm},
             ),
         }

         cum_loss = 0
         for k, loss_fn in loss_fns.items():
             weight = self.config[k].weight
-            if (weight):
-                #print(k)
+            if weight:
+                # print(k)
                 loss = loss_fn()
-                #print(weight * loss)
+                # print(weight * loss)
                 cum_loss = cum_loss + weight * loss
-                #print(cum_loss)
+                # print(cum_loss)

         return cum_loss
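The forward pass above is a dispatch table: each head's loss is wrapped in a zero-argument lambda so it is only evaluated when its configured weight is nonzero. A stripped-down sketch of that pattern (toy losses and weights, not the module itself):

    import torch

    def weighted_loss_sum(loss_fns, weights):
        # Evaluate only the losses whose weight is nonzero, then sum.
        cum_loss = 0
        for name, loss_fn in loss_fns.items():
            weight = weights[name]
            if weight:
                cum_loss = cum_loss + weight * loss_fn()
        return cum_loss

    loss_fns = {
        "distogram": lambda: torch.tensor(0.30),
        "fape": lambda: torch.tensor(1.20),
        "violation": lambda: torch.tensor(0.05),
    }
    weights = {"distogram": 0.3, "fape": 1.0, "violation": 0.0}  # violation skipped
    print(weighted_loss_sum(loss_fns, weights))  # tensor(1.2900)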
openfold/utils/tensor_utils.py  (View file @ 07e64267)
...
@@ -49,11 +49,11 @@ def dict_multimap(fn, dicts):
     new_dict = {}
     for k, v in first.items():
         all_v = [d[k] for d in dicts]
-        if (type(v) is dict):
+        if type(v) is dict:
             new_dict[k] = dict_multimap(fn, all_v)
         else:
             new_dict[k] = fn(all_v)

     return new_dict
...
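dict_multimap walks a list of identically-structured nested dicts and applies fn to the list of values collected at each leaf (first above is presumably dicts[0]). A quick usage sketch:

    import torch
    from openfold.utils.tensor_utils import dict_multimap

    d1 = {"a": torch.tensor(1.0), "nested": {"b": torch.tensor(2.0)}}
    d2 = {"a": torch.tensor(3.0), "nested": {"b": torch.tensor(4.0)}}

    stacked = dict_multimap(torch.stack, [d1, d2])
    # {"a": tensor([1., 3.]), "nested": {"b": tensor([2., 4.])}}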
@@ -83,7 +83,7 @@ def batched_gather(data, inds, dim=0, no_batch_dims=0):
 def dict_map(fn, dic, leaf_type):
     new_dict = {}
     for k, v in dic.items():
-        if (type(v) is dict):
+        if type(v) is dict:
             new_dict[k] = dict_map(fn, v, leaf_type)
         else:
             new_dict[k] = tree_map(fn, v, leaf_type)
...
@@ -92,76 +92,77 @@ def dict_map(fn, dic, leaf_type):
 def tree_map(fn, tree, leaf_type):
-    if (isinstance(tree, dict)):
+    if isinstance(tree, dict):
         return dict_map(fn, tree, leaf_type)
-    elif (isinstance(tree, list)):
+    elif isinstance(tree, list):
         return [tree_map(fn, x, leaf_type) for x in tree]
-    elif (isinstance(tree, tuple)):
+    elif isinstance(tree, tuple):
         return tuple([tree_map(fn, x, leaf_type) for x in tree])
-    elif (isinstance(tree, leaf_type)):
+    elif isinstance(tree, leaf_type):
         return fn(tree)
     else:
         print(type(tree))
         raise ValueError("Not supported")

 tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)

 def chunk_layer(
     layer: Callable,
     inputs: Dict[str, Any],
     chunk_size: int,
     no_batch_dims: int,
 ) -> Any:
     """
     Implements the "chunking" procedure described in section 1.11.8.

     Layer outputs and inputs are assumed to be simple "pytrees,"
     consisting only of (arbitrarily nested) lists, tuples, and dicts with
     torch.Tensor leaves.

     Args:
         layer:
             The layer to be applied chunk-wise
         inputs:
             A (non-nested) dictionary of keyworded inputs. All leaves must
             be tensors and must share the same batch dimensions.
         chunk_size:
             The number of sub-batches per chunk. If multiple batch
             dimensions are specified, a "sub-batch" is defined as a single
             indexing of all batch dimensions simultaneously (s.t. the
             number of sub-batches is the product of the batch dimensions).
         no_batch_dims:
             How many of the initial dimensions of each input tensor can
             be considered batch dimensions.
     Returns:
         The reassembled output of the layer on the inputs.
     """
-    if (not (len(inputs) > 0)):
+    if not (len(inputs) > 0):
         raise ValueError("Must provide at least one input")

     def fetch_dims(tree):
         shapes = []
         tree_type = type(tree)
-        if (tree_type is dict):
+        if tree_type is dict:
             for v in tree.values():
                 shapes.extend(fetch_dims(v))
-        elif (tree_type is list or tree_type is tuple):
+        elif tree_type is list or tree_type is tuple:
             for t in tree:
                 shapes.extend(fetch_dims(t))
-        elif (tree_type is torch.Tensor):
+        elif tree_type is torch.Tensor:
             shapes.append(tree.shape)
         else:
             raise ValueError("Not supported")

         return shapes

     initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
     orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

     def prep_inputs(t):
         # TODO: make this more memory efficient. This sucks
-        if (not sum(t.shape[:no_batch_dims]) == no_batch_dims):
+        if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
             t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
         t = t.reshape(-1, *t.shape[no_batch_dims:])
         return t
...
@@ -172,40 +173,42 @@ def chunk_layer(
     for d in orig_batch_dims:
         flat_batch_dim *= d

-    no_chunks = (
-        flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)
-    )
+    no_chunks = flat_batch_dim // chunk_size + (
+        flat_batch_dim % chunk_size != 0
+    )

     i = 0
     out = None
     for _ in range(no_chunks):
         # Chunk the input
         select_chunk = (
             lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
         )
         chunks = tensor_tree_map(select_chunk, flattened_inputs)

         # Run the layer on the chunk
         output_chunk = layer(**chunks)

         # Allocate space for the output
-        if (out is None):
+        if out is None:
             allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
             out = tensor_tree_map(allocate, output_chunk)

         # Put the chunk in its pre-allocated space
         out_type = type(output_chunk)
-        if (out_type is dict):
+        if out_type is dict:
             def assign(d1, d2):
                 for k, v in d1.items():
-                    if (type(v) is dict):
+                    if type(v) is dict:
                         assign(v, d2[k])
                     else:
                         v[i : i + chunk_size] = d2[k]

             assign(out, output_chunk)
-        elif (out_type is tuple):
+        elif out_type is tuple:
             for x1, x2 in zip(out, output_chunk):
                 x1[i : i + chunk_size] = x2
-        elif (out_type is torch.Tensor):
+        elif out_type is torch.Tensor:
             out[i : i + chunk_size] = output_chunk
         else:
             raise ValueError("Not supported")
...
@@ -214,4 +217,4 @@ def chunk_layer(
     reshape = lambda t: t.reshape(orig_batch_dims + t.shape[1:])
     out = tensor_tree_map(reshape, out)

     return out
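A usage sketch of chunk_layer as specified by its docstring above: two batch dimensions are flattened into 6 sub-batches, processed 4 at a time, and the output is reassembled to the original batch shape. The toy layer is mine; the call signature is taken from the code above.

    import torch
    from openfold.utils.tensor_utils import chunk_layer

    def toy_layer(x, y):
        # Any pytree-in/pytree-out callable works; here a plain tensor op.
        return x * y + 1

    x = torch.rand(2, 3, 5)  # batch dims (2, 3), feature dim 5
    y = torch.rand(2, 3, 5)

    out = chunk_layer(toy_layer, {"x": x, "y": y}, chunk_size=4, no_batch_dims=2)
    assert out.shape == (2, 3, 5)
    assert torch.allclose(out, toy_layer(x, y))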
tests/compare_utils.py  (View file @ 07e64267)
...
@@ -15,7 +15,7 @@ from openfold.utils.import_weights import import_jax_weights_
 from tests.config import consts

 # Give JAX some GPU memory discipline
 # (by default it hogs 90% of GPU memory. This disables that behavior and also
 # forces it to proactively free memory that it allocates)
 os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
 os.environ["JAX_PLATFORM_NAME"] = "gpu"
...
@@ -30,17 +30,15 @@ def skip_unless_alphafold_installed():
 def import_alphafold():
     """
     If AlphaFold is installed using the provided setuptools script, this
     is necessary to expose all of AlphaFold's precious insides
     """
-    if ("alphafold" in sys.modules):
+    if "alphafold" in sys.modules:
         return sys.modules["alphafold"]

     module = importlib.import_module("alphafold")

     # Forcefully import alphafold's submodules
-    submodules = pkgutil.walk_packages(
-        module.__path__, prefix=("alphafold.")
-    )
+    submodules = pkgutil.walk_packages(module.__path__, prefix=("alphafold."))
     for submodule_info in submodules:
         importlib.import_module(submodule_info.name)

     sys.modules["alphafold"] = module
...
@@ -57,16 +55,18 @@ def get_alphafold_config():
 _param_path = "openfold/resources/params/params_model_1_ptm.npz"
 _model = None

 def get_global_pretrained_openfold():
     global _model
-    if (_model is None):
+    if _model is None:
         _model = AlphaFold(model_config("model_1_ptm").model)
         _model = _model.eval()
-        if (not os.path.exists(_param_path)):
+        if not os.path.exists(_param_path):
             raise FileNotFoundError(
                 """Cannot load pretrained parameters. Make sure to run the
                 installation script before running tests."""
             )
         import_jax_weights_(_model, _param_path, version="model_1_ptm")
         _model = _model.cuda()
...
@@ -74,9 +74,11 @@ def get_global_pretrained_openfold():
 _orig_weights = None

 def _get_orig_weights():
     global _orig_weights
-    if (_orig_weights is None):
+    if _orig_weights is None:
         _orig_weights = np.load(_param_path)

     return _orig_weights
...
@@ -84,22 +86,19 @@ def _get_orig_weights():
 def _remove_key_prefix(d, prefix):
     for k, v in list(d.items()):
-        if (k.startswith(prefix)):
+        if k.startswith(prefix):
             d.pop(k)
-            d[k[len(prefix):]] = v
+            d[k[len(prefix) :]] = v

 def fetch_alphafold_module_weights(weight_path):
     orig_weights = _get_orig_weights()
     params = {k: v for k, v in orig_weights.items() if weight_path in k}
-    if ('/' in weight_path):
-        spl = weight_path.split('/')
+    if "/" in weight_path:
+        spl = weight_path.split("/")
         spl = spl if len(spl[-1]) != 0 else spl[:-1]
         module_name = spl[-1]
-        prefix = '/'.join(spl[:-1]) + '/'
+        prefix = "/".join(spl[:-1]) + "/"
         _remove_key_prefix(params, prefix)

     params = alphafold.model.utils.flat_params_to_haiku(params)

     return params
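_remove_key_prefix rewrites flat parameter names in place so that flat_params_to_haiku sees module-relative keys. A toy sketch of the same transformation (hypothetical keys; the real ones come from the AlphaFold npz):

    params = {
        "evoformer/evoformer_iteration/msa_transition//weights": 1,
        "evoformer/evoformer_iteration/msa_transition//bias": 2,
    }
    prefix = "evoformer/evoformer_iteration/"
    for k in list(params):
        if k.startswith(prefix):
            params[k[len(prefix):]] = params.pop(k)
    # {"msa_transition//weights": 1, "msa_transition//bias": 2}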
tests/config.py  (View file @ 07e64267)
 import ml_collections as mlc

-consts = mlc.ConfigDict({
-    "batch_size": 2,
-    "n_res": 11,
-    "n_seq": 13,
-    "n_templ": 3,
-    "n_extra": 17,
-    "eps": 5e-4,
-    # For compatibility with DeepMind's pretrained weights, it's easiest for
-    # everyone if these take their real values.
-    "c_m": 256,
-    "c_z": 128,
-    "c_s": 384,
-    "c_t": 64,
-    "c_e": 64,
-})
+consts = mlc.ConfigDict(
+    {
+        "batch_size": 2,
+        "n_res": 11,
+        "n_seq": 13,
+        "n_templ": 3,
+        "n_extra": 17,
+        "eps": 5e-4,
+        # For compatibility with DeepMind's pretrained weights, it's easiest for
+        # everyone if these take their real values.
+        "c_m": 256,
+        "c_z": 128,
+        "c_s": 384,
+        "c_t": 64,
+        "c_e": 64,
+    }
+)
tests/data_utils.py  (View file @ 07e64267)
...
@@ -18,7 +18,7 @@ from scipy.spatial.transform import Rotation
 def random_template_feats(n_templ, n, batch_size=None):
     b = []
-    if (batch_size is not None):
+    if batch_size is not None:
         b.append(batch_size)
     batch = {
         "template_mask": np.random.randint(0, 2, (*b, n_templ)),
...
@@ -28,28 +28,31 @@ def random_template_feats(n_templ, n, batch_size=None):
         "template_all_atom_masks": np.random.randint(
             0, 2, (*b, n_templ, n, 37)
         ),
         "template_all_atom_positions": np.random.rand(*b, n_templ, n, 37, 3)
         * 10,
     }
     batch = {k: v.astype(np.float32) for k, v in batch.items()}
     batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
     return batch

 def random_extra_msa_feats(n_extra, n, batch_size=None):
     b = []
-    if (batch_size is not None):
+    if batch_size is not None:
         b.append(batch_size)
     batch = {
         "extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(np.int64),
         "extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
         "extra_deletion_value": np.random.rand(*b, n_extra, n).astype(np.float32),
         "extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
     }
     return batch
...
@@ -63,7 +66,9 @@ def random_affines_vector(dim):
     for i in range(prod_dim):
         affines[i, :4] = Rotation.random(random_state=42).as_quat()
-        affines[i, 4:] = np.random.rand(3,).astype(np.float32)
+        affines[i, 4:] = np.random.rand(
+            3,
+        ).astype(np.float32)

     return affines.reshape(*dim, 7)
...
@@ -77,9 +82,10 @@ def random_affines_4x4(dim):
     for i in range(prod_dim):
         affines[i, :3, :3] = Rotation.random(random_state=42).as_matrix()
-        affines[i, :3, 3] = np.random.rand(3,).astype(np.float32)
+        affines[i, :3, 3] = np.random.rand(
+            3,
+        ).astype(np.float32)
     affines[:, 3, 3] = 1

     return affines.reshape(*dim, 4, 4)
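These helpers let the tests fabricate correctly shaped feature dicts without real data. A usage sketch (shape checks only; the values are random by construction):

    from tests.data_utils import random_template_feats, random_extra_msa_feats

    feats = random_template_feats(n_templ=3, n=11, batch_size=2)
    assert feats["template_all_atom_positions"].shape == (2, 3, 11, 37, 3)
    assert feats["template_aatype"].dtype.name == "int64"

    extra = random_extra_msa_feats(n_extra=17, n=11)  # no batch dim
    assert extra["extra_msa"].shape == (17, 11)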
tests/test_embedders.py  (View file @ 07e64267)
...
@@ -24,30 +24,30 @@ from openfold.model.embedders import (
 class TestInputEmbedder(unittest.TestCase):
     def test_shape(self):
         tf_dim = 2
         msa_dim = 3
         c_z = 5
         c_m = 7
         relpos_k = 11
         b = 13
         n_res = 17
         n_clust = 19

         tf = torch.rand((b, n_res, tf_dim))
         ri = torch.rand((b, n_res))
         msa = torch.rand((b, n_clust, n_res, msa_dim))

         ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)

         msa_emb, pair_emb = ie(tf, ri, msa)
         self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m))
         self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))

 class TestRecyclingEmbedder(unittest.TestCase):
     def test_shape(self):
         batch_size = 2
         n = 3
         c_z = 5
...
@@ -66,7 +66,7 @@ class TestRecyclingEmbedder(unittest.TestCase):
         self.assertTrue(z.shape == (batch_size, n, n, c_z))
         self.assertTrue(m_1.shape == (batch_size, n, c_m))

 class TestTemplateAngleEmbedder(unittest.TestCase):
     def test_shape(self):
...
@@ -80,13 +80,11 @@ class TestTemplateAngleEmbedder(unittest.TestCase):
             template_angle_dim,
             c_m,
         )
         x = torch.rand((batch_size, n_templ, n_res, template_angle_dim))
         x = tae(x)
-        self.assertTrue(
-            x.shape == (batch_size, n_templ, n_res, c_m)
-        )
+        self.assertTrue(x.shape == (batch_size, n_templ, n_res, c_m))

 class TestTemplatePairEmbedder(unittest.TestCase):
...
@@ -96,20 +94,17 @@ class TestTemplatePairEmbedder(unittest.TestCase):
         n_res = 5
         template_pair_dim = 7
         c_t = 11

         tpe = TemplatePairEmbedder(
             template_pair_dim,
             c_t,
         )
         x = torch.rand((batch_size, n_templ, n_res, n_res, template_pair_dim))
         x = tpe(x)
-        self.assertTrue(
-            x.shape == (batch_size, n_templ, n_res, n_res, c_t)
-        )
+        self.assertTrue(x.shape == (batch_size, n_templ, n_res, n_res, c_t))

 if __name__ == "__main__":
     unittest.main()
tests/test_evoformer.py  (View file @ 07e64267)
...
@@ -24,14 +24,14 @@ from openfold.utils.tensor_utils import tree_map
 import tests.compare_utils as compare_utils
 from tests.config import consts

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk

 class TestEvoformerStack(unittest.TestCase):
     def test_shape(self):
         batch_size = consts.batch_size
         n_seq = consts.n_seq
         n_res = consts.n_res
...
@@ -91,56 +91,54 @@ class TestEvoformerStack(unittest.TestCase):
         config = compare_utils.get_alphafold_config()
         c_e = config.model.embeddings_and_evoformer.evoformer

         def run_ei(activations, masks):
             ei = alphafold.model.modules.EvoformerIteration(
                 c_e, config.model.global_config, is_extra_msa=False
             )
             return ei(activations, masks, is_training=False)

         f = hk.transform(run_ei)

         n_res = consts.n_res
         n_seq = consts.n_seq

         activations = {
-            'msa': np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
-            'pair': np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
+            "msa": np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
+            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
         }

         masks = {
-            'msa': np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
-            'pair': np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
+            "msa": np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
+            "pair": np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
         }

         params = compare_utils.fetch_alphafold_module_weights(
             "alphafold/alphafold_iteration/evoformer/evoformer_iteration"
         )
         params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
         key = jax.random.PRNGKey(42)
         out_gt = f.apply(params, key, activations, masks)
         jax.tree_map(lambda x: x.block_until_ready(), out_gt)
         out_gt_msa = torch.as_tensor(np.array(out_gt["msa"]))
         out_gt_pair = torch.as_tensor(np.array(out_gt["pair"]))

         model = compare_utils.get_global_pretrained_openfold()
         out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
             torch.as_tensor(activations["msa"]).cuda(),
             torch.as_tensor(activations["pair"]).cuda(),
             torch.as_tensor(masks["msa"]).cuda(),
             torch.as_tensor(masks["pair"]).cuda(),
             _mask_trans=False,
         )
         out_repro_msa = out_repro_msa.cpu()
         out_repro_pair = out_repro_pair.cpu()

-        assert(torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps))
-        assert(torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps))
+        assert torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps)
+        assert torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps)

 class TestExtraMSAStack(unittest.TestCase):
     def test_shape(self):
         batch_size = 2
         s_t = 23
         n_res = 5
...
@@ -180,8 +178,24 @@ class TestExtraMSAStack(unittest.TestCase):
         m = torch.rand((batch_size, s_t, n_res, c_m))
         z = torch.rand((batch_size, n_res, n_res, c_z))
-        msa_mask = torch.randint(0, 2, size=(batch_size, s_t, n_res,))
-        pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res,))
+        msa_mask = torch.randint(
+            0,
+            2,
+            size=(
+                batch_size,
+                s_t,
+                n_res,
+            ),
+        )
+        pair_mask = torch.randint(
+            0,
+            2,
+            size=(
+                batch_size,
+                n_res,
+                n_res,
+            ),
+        )

         shape_z_before = z.shape
...
@@ -191,7 +205,7 @@ class TestExtraMSAStack(unittest.TestCase):
 class TestMSATransition(unittest.TestCase):
     def test_shape(self):
         batch_size = 2
         s_t = 3
         n_r = 5
...
@@ -214,39 +228,43 @@ class TestMSATransition(unittest.TestCase):
         config = compare_utils.get_alphafold_config()
         c_e = config.model.embeddings_and_evoformer.evoformer

         def run_msa_transition(msa_act, msa_mask):
             msa_trans = alphafold.model.modules.Transition(
                 c_e.msa_transition,
                 config.model.global_config,
-                name="msa_transition"
+                name="msa_transition",
             )
             act = msa_trans(act=msa_act, mask=msa_mask)
             return act

         f = hk.transform(run_msa_transition)

         n_res = consts.n_res
         n_seq = consts.n_seq

         msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
         msa_mask = np.ones((n_seq, n_res)).astype(np.float32)  # no mask here either

         # Fetch pretrained parameters (but only from one block)]
         params = compare_utils.fetch_alphafold_module_weights(
             "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
             + "msa_transition"
         )
         params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

         out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
         out_gt = torch.as_tensor(np.array(out_gt))

         model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.evoformer.blocks[0].msa_transition(
-            torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
-            mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
-        ).cpu()
+        out_repro = (
+            model.evoformer.blocks[0]
+            .msa_transition(
+                torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
+                mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
+            )
+            .cpu()
+        )

         self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
...
tests/test_feats.py  (View file @ 07e64267)
...
@@ -26,14 +26,14 @@ from openfold.np.residue_constants import (
 from openfold.utils.affine_utils import T
 import openfold.utils.feats as feats
 from openfold.utils.tensor_utils import (
     tree_map,
     tensor_tree_map,
 )
 import tests.compare_utils as compare_utils
 from tests.config import consts
 from tests.data_utils import random_affines_4x4

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
@@ -48,21 +48,21 @@ class TestFeats(unittest.TestCase):
                 all_atom_pos,
                 all_atom_mask,
             )

         f = hk.transform(test_pbf)

         n_res = consts.n_res

         aatype = np.random.randint(0, 22, (n_res,))
         all_atom_pos = np.random.rand(n_res, 37, 3).astype(np.float32)
         all_atom_mask = np.random.randint(0, 2, (n_res, 37))

         out_gt_pos, out_gt_mask = f.apply(
             {}, None, aatype, all_atom_pos, all_atom_mask
         )
         out_gt_pos = torch.tensor(np.array(out_gt_pos.block_until_ready()))
         out_gt_mask = torch.tensor(np.array(out_gt_mask.block_until_ready()))

         out_repro_pos, out_repro_mask = feats.pseudo_beta_fn(
             torch.tensor(aatype).cuda(),
             torch.tensor(all_atom_pos).cuda(),
...
@@ -70,7 +70,7 @@ class TestFeats(unittest.TestCase):
         )
         out_repro_pos = out_repro_pos.cpu()
         out_repro_mask = out_repro_mask.cpu()

         self.assertTrue(torch.max(torch.abs(out_gt_pos - out_repro_pos)) < consts.eps)
...
@@ -82,26 +82,26 @@ class TestFeats(unittest.TestCase):
     def test_atom37_to_torsion_angles_compare(self):
         def run_test(aatype, all_atom_pos, all_atom_mask):
             return alphafold.model.all_atom.atom37_to_torsion_angles(
                 aatype,
                 all_atom_pos,
                 all_atom_mask,
                 placeholder_for_undefined=False,
             )

         f = hk.transform(run_test)

         n_templ = 7
         n_res = 13

         aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
         all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
-        all_atom_mask = np.random.randint(
-            0, 2, (n_templ, n_res, 37)
-        ).astype(np.float32)
+        all_atom_mask = np.random.randint(0, 2, (n_templ, n_res, 37)).astype(
+            np.float32
+        )

         out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
         out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)

         out_repro = feats.atom37_to_torsion_angles(
             torch.as_tensor(aatype).cuda(),
             torch.as_tensor(all_atom_pos).cuda(),
...
@@ -110,20 +110,21 @@ class TestFeats(unittest.TestCase):
         tasc = out_repro["torsion_angles_sin_cos"].cpu()
         atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
         tam = out_repro["torsion_angles_mask"].cpu()

         # This function is extremely sensitive to floating point imprecisions,
         # so it is given much greater latitude in comparison tests.
         self.assertTrue(
-            torch.mean(
-                torch.abs(out_gt["torsion_angles_sin_cos"] - tasc)
-            ) < 0.01
+            torch.mean(torch.abs(out_gt["torsion_angles_sin_cos"] - tasc))
+            < 0.01
         )
         self.assertTrue(
-            torch.mean(
-                torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc)
-            ) < 0.01
+            torch.mean(torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc))
+            < 0.01
         )
         self.assertTrue(
-            torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam)) <
-            consts.eps
+            torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam))
+            < consts.eps
         )

     @compare_utils.skip_unless_alphafold_installed()
     def test_atom37_to_frames_compare(self):
...
@@ -131,48 +132,50 @@ class TestFeats(unittest.TestCase):
             return alphafold.model.all_atom.atom37_to_frames(
                 aatype, all_atom_positions, all_atom_mask
             )

         f = hk.transform(run_atom37_to_frames)

         n_res = consts.n_res

         batch = {
             "aatype": np.random.randint(0, 21, (n_res,)),
             "all_atom_positions": np.random.rand(n_res, 37, 3).astype(np.float32),
             "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
         }

         out_gt = f.apply({}, None, **batch)
         to_tensor = lambda t: torch.tensor(np.array(t))
         out_gt = {k: to_tensor(v) for k, v in out_gt.items()}

         def flat12_to_4x4(flat12):
             rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
             trans = flat12[..., 9:]
             four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
             four_by_four[..., :3, :3] = rot
             four_by_four[..., :3, 3] = trans
             four_by_four[..., 3, 3] = 1

             return four_by_four

         out_gt["rigidgroups_gt_frames"] = flat12_to_4x4(
             out_gt["rigidgroups_gt_frames"]
         )
         out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4(
             out_gt["rigidgroups_alt_gt_frames"]
         )

         to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
         batch = tree_map(to_tensor, batch, np.ndarray)

         out_repro = data_transforms.atom37_to_frames(batch)
         out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

         for k, v in out_gt.items():
             self.assertTrue(
                 torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
             )
...
@@ -190,56 +193,50 @@ class TestFeats(unittest.TestCase):
         aas = torch.stack([aas for _ in range(batch_size)])

         frames = feats.torsion_angles_to_frames(
             ts,
             angles,
             aas,
             torch.tensor(restype_rigid_group_default_frame),
         )

         self.assertTrue(frames.shape == (batch_size, n, 8))

     @compare_utils.skip_unless_alphafold_installed()
     def test_torsion_angles_to_frames_compare(self):
         def run_torsion_angles_to_frames(
             aatype, backb_to_global, torsion_angles_sin_cos
         ):
             return alphafold.model.all_atom.torsion_angles_to_frames(
                 aatype,
                 backb_to_global,
                 torsion_angles_sin_cos,
             )

         f = hk.transform(run_torsion_angles_to_frames)

         n_res = consts.n_res
         aatype = np.random.randint(0, 21, size=(n_res,))
         affines = random_affines_4x4((n_res,))
         rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
         transformations = T.from_4x4(torch.as_tensor(affines).float())
         torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)

-        out_gt = f.apply(
-            {}, None, aatype, rigids, torsion_angles_sin_cos
-        )
+        out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)
         jax.tree_map(lambda x: x.block_until_ready(), out_gt)

         out = feats.torsion_angles_to_frames(
             transformations.cuda(),
             torch.as_tensor(torsion_angles_sin_cos).cuda(),
             torch.as_tensor(aatype).cuda(),
             torch.tensor(restype_rigid_group_default_frame).cuda(),
         )

         # Convert the Rigids to 4x4 transformation tensors
-        rots_gt = list(
-            map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot)
-        )
+        rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot))
         trans_gt = list(
             map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
         )
...
@@ -250,9 +247,9 @@ class TestFeats(unittest.TestCase):
         bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
         bottom_row[..., 3] = 1
         transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)

         transforms_repro = out.to_4x4().cpu()
         self.assertTrue(
             torch.max(torch.abs(transforms_gt - transforms_repro) < consts.eps)
         )
...
@@ -275,7 +272,7 @@ class TestFeats(unittest.TestCase):
             torch.tensor(restype_atom14_mask),
             torch.tensor(restype_atom14_rigid_group_positions),
         )

         self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))

     @compare_utils.skip_unless_alphafold_installed()
...
@@ -285,34 +282,32 @@ class TestFeats(unittest.TestCase):
             return am.all_atom.frames_and_literature_positions_to_atom14_pos(
                 aatype, affines
             )

         f = hk.transform(run_f)

         n_res = consts.n_res
         aatype = np.random.randint(0, 21, size=(n_res,))
         affines = random_affines_4x4((n_res, 8))
         rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
         transformations = T.from_4x4(torch.as_tensor(affines).float())

-        out_gt = f.apply(
-            {}, None, aatype, rigids
-        )
+        out_gt = f.apply({}, None, aatype, rigids)
         jax.tree_map(lambda x: x.block_until_ready(), out_gt)
         out_gt = torch.stack(
             [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
         )

         out_repro = feats.frames_and_literature_positions_to_atom14_pos(
             transformations.cuda(),
             torch.as_tensor(aatype).cuda(),
             torch.tensor(restype_rigid_group_default_frame).cuda(),
             torch.tensor(restype_atom14_to_rigid_group).cuda(),
             torch.tensor(restype_atom14_mask).cuda(),
             torch.tensor(restype_atom14_rigid_group_positions).cuda(),
         ).cpu()

         self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
...
tests/test_import_weights.py  (View file @ 07e64267)
...
@@ -24,13 +24,14 @@ from openfold.utils.import_weights import import_jax_weights_
 class TestImportWeights(unittest.TestCase):
     def test_import_jax_weights_(self):
         npz_path = "openfold/resources/params/params_model_1_ptm.npz"
         c = model_config("model_1_ptm")
         c.globals.blocks_per_ckpt = None
         model = AlphaFold(c.model)
-        import_jax_weights_(
-            model, npz_path,
-        )
+        import_jax_weights_(
+            model,
+            npz_path,
+        )

         data = np.load(npz_path)
...
@@ -38,23 +39,34 @@ class TestImportWeights(unittest.TestCase):
         test_pairs = [
             # Normal linear weight
             (
                 torch.as_tensor(
                     data[prefix + "structure_module/initial_projection//weights"]
                 ).transpose(-1, -2),
-                model.structure_module.linear_in.weight
+                model.structure_module.linear_in.weight,
             ),
             # Normal layer norm param
             (
                 torch.as_tensor(
                     data[prefix + "evoformer/prev_pair_norm//offset"],
                 ),
-                model.recycling_embedder.layer_norm_z.bias
+                model.recycling_embedder.layer_norm_z.bias,
             ),
             # From a stack
             (
                 torch.as_tensor(
                     data[
                         prefix
                         + (
                             "evoformer/evoformer_iteration/outer_product_mean/"
                             "left_projection//weights"
                         )
                     ][1].transpose(-1, -2)
                 ),
-                model.evoformer.blocks[1].outer_product_mean.linear_1.weight,),
+                model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
+            ),
         ]

         for w_alpha, w_repro in test_pairs:
...
tests/test_loss.py  (View file @ 07e64267)
...
@@ -41,15 +41,15 @@ from openfold.utils.loss import (
     tm_loss,
 )
 from openfold.utils.tensor_utils import (
     tree_map,
     tensor_tree_map,
     dict_multimap,
 )
 import tests.compare_utils as compare_utils
 from tests.config import consts
 from tests.data_utils import random_affines_vector, random_affines_4x4

-if (compare_utils.alphafold_is_installed()):
+if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
     import jax
     import haiku as hk
...
@@ -99,12 +99,19 @@ class TestLoss(unittest.TestCase):
         pred_pos = torch.rand(bs, n, 14, 3)
         pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
         residue_index = torch.arange(n).unsqueeze(0)
-        aatype = torch.randint(0, 22, (bs, n,))
+        aatype = torch.randint(
+            0,
+            22,
+            (
+                bs,
+                n,
+            ),
+        )

         between_residue_bond_loss(
             pred_pos,
             pred_atom_mask,
             residue_index,
             aatype,
         )
...
@@ -117,27 +124,26 @@ class TestLoss(unittest.TestCase):
             residue_index,
             aatype,
         )

         f = hk.transform(run_brbl)

         n_res = consts.n_res

         pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
-        pred_atom_mask = np.random.randint(
-            0, 2, (n_res, 14)
-        ).astype(np.float32)
+        pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
         residue_index = np.arange(n_res)
         aatype = np.random.randint(0, 22, (n_res,))

         out_gt = f.apply(
             {},
             None,
             pred_pos,
             pred_atom_mask,
             residue_index,
             aatype,
         )
         out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
         out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

         out_repro = between_residue_bond_loss(
             torch.tensor(pred_pos).cuda(),
             torch.tensor(pred_atom_mask).cuda(),
...
@@ -145,13 +151,12 @@ class TestLoss(unittest.TestCase):
             torch.tensor(aatype).cuda(),
         )
         out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

         for k in out_gt.keys():
             self.assertTrue(
                 torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
             )

     def test_run_between_residue_clash_loss(self):
         bs = consts.batch_size
         n = consts.n_res
...
@@ -164,7 +169,7 @@ class TestLoss(unittest.TestCase):
         loss = between_residue_clash_loss(
             pred_pos,
             pred_atom_mask,
             atom14_atom_radius,
             residue_index,
         )
...
@@ -185,10 +190,13 @@ class TestLoss(unittest.TestCase):
         pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
         atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
         atom_radius = np.random.rand(n_res, 14).astype(np.float32)
-        res_ind = np.arange(n_res,)
+        res_ind = np.arange(
+            n_res,
+        )

         out_gt = f.apply(
             {},
             None,
             pred_pos,
             atom_exists,
             atom_radius,
...
@@ -196,7 +204,7 @@ class TestLoss(unittest.TestCase):
         )
         out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
         out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

         out_repro = between_residue_clash_loss(
             torch.tensor(pred_pos).cuda(),
             torch.tensor(atom_exists).cuda(),
...
@@ -204,7 +212,7 @@ class TestLoss(unittest.TestCase):
             torch.tensor(res_ind).cuda(),
         )
         out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

         for k in out_gt.keys():
             self.assertTrue(
                 torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
...
@@ -221,7 +229,7 @@ class TestLoss(unittest.TestCase):
         }
         pred_pos = torch.rand(n, 14, 3)

         config = {
             "clash_overlap_tolerance": 1.5,
             "violation_tolerance_factor": 12.0,
...
@@ -242,50 +250,44 @@ class TestLoss(unittest.TestCase):
             os.chdir(cwd)
             return loss

         f = hk.transform(run_fsv)

         n_res = consts.n_res

         batch = {
             "atom14_atom_exists": np.random.randint(0, 2, (n_res, 14)),
             "residue_index": np.arange(n_res),
             "aatype": np.random.randint(0, 20, (n_res,)),
-            "residx_atom14_to_atom37": np.random.randint(0, 37, (n_res, 14)).astype(np.int64),
+            "residx_atom14_to_atom37": np.random.randint(
+                0, 37, (n_res, 14)
+            ).astype(np.int64),
         }
         pred_pos = np.random.rand(n_res, 14, 3)
-        config = mlc.ConfigDict({
-            "clash_overlap_tolerance": 1.5,
-            "violation_tolerance_factor": 12.0,
-        })
+        config = mlc.ConfigDict(
+            {
+                "clash_overlap_tolerance": 1.5,
+                "violation_tolerance_factor": 12.0,
+            }
+        )

-        out_gt = f.apply(
-            {}, None, batch, pred_pos, config
-        )
+        out_gt = f.apply({}, None, batch, pred_pos, config)
         out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
         out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

         batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

         out_repro = find_structural_violations(
             batch,
             torch.tensor(pred_pos).cuda(),
             **config,
         )
         out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

         def compare(out):
             gt, repro = out
-            assert(torch.max(torch.abs(gt - repro)) < consts.eps)
+            assert torch.max(torch.abs(gt - repro)) < consts.eps

         dict_multimap(compare, [out_gt, out_repro])

     @compare_utils.skip_unless_alphafold_installed()
...
@@ -295,44 +297,45 @@ class TestLoss(unittest.TestCase):
                 batch,
                 atom14_pred_pos,
             )

         f = hk.transform(run_crgt)

         n_res = consts.n_res

         batch = {
             "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
             "aatype": np.random.randint(0, 20, (n_res,)),
             "atom14_gt_positions": np.random.rand(n_res, 14, 3),
             "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(np.float32),
             "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
             "all_atom_positions": np.random.rand(n_res, 37, 3).astype(np.float32),
         }

         def _build_extra_feats_np():
             b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
             b = data_transforms.make_atom14_masks(b)
             b = data_transforms.make_atom14_positions(b)
             return tensor_tree_map(lambda t: np.array(t), b)

         batch = _build_extra_feats_np()
         atom14_pred_pos = np.random.rand(n_res, 14, 3)

         out_gt = f.apply({}, None, batch, atom14_pred_pos)
         out_gt = jax.tree_map(lambda x: torch.tensor(np.array(x)), out_gt)

         batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
         atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

         out_repro = compute_renamed_ground_truth(batch, atom14_pred_pos)
         out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

         for k in out_repro:
             self.assertTrue(
                 torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
...
@@ -346,84 +349,76 @@ class TestLoss(unittest.TestCase):
                 config.model.heads.masked_msa, config.model.global_config
             )
             return msa_head.loss(value, batch)

         f = hk.transform(run_msa_loss)

         n_res = consts.n_res
         n_seq = consts.n_seq

         value = {
             "logits": np.random.rand(n_res, n_seq, 23).astype(np.float32),
         }
         batch = {
             "true_msa": np.random.randint(0, 21, (n_res, n_seq)),
             "bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(np.float32),
         }

         out_gt = f.apply({}, None, value, batch)["loss"]
         out_gt = torch.tensor(np.array(out_gt))

         value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
         batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

         with torch.no_grad():
             out_repro = masked_msa_loss(
                 value["logits"],
                 **batch,
             )
         out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

         self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

     @compare_utils.skip_unless_alphafold_installed()
     def test_distogram_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_distogram = config.model.heads.distogram

         def run_distogram_loss(value, batch):
             dist_head = alphafold.model.modules.DistogramHead(
                 c_distogram, config.model.global_config
             )
             return dist_head.loss(value, batch)

         f = hk.transform(run_distogram_loss)

         n_res = consts.n_res

         value = {
             "logits": np.random.rand(n_res, n_res, c_distogram.num_bins).astype(np.float32),
             "bin_edges": np.linspace(
                 c_distogram.first_break,
                 c_distogram.last_break,
                 c_distogram.num_bins,
             ),
         }
         batch = {
             "pseudo_beta": np.random.rand(n_res, 3).astype(np.float32),
-            "pseudo_beta_mask": np.random.randint(0, 2, (n_res,))
+            "pseudo_beta_mask": np.random.randint(0, 2, (n_res,)),
         }

         out_gt = f.apply({}, None, value, batch)["loss"]
         out_gt = torch.tensor(np.array(out_gt))

         value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
         batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

         with torch.no_grad():
             out_repro = distogram_loss(
                 logits=value["logits"],
...
@@ -431,66 +426,64 @@ class TestLoss(unittest.TestCase):
                 max_bin=c_distogram.last_break,
                 no_bins=c_distogram.num_bins,
                 **batch,
             )
         out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

         self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

     @compare_utils.skip_unless_alphafold_installed()
     def test_experimentally_resolved_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_experimentally_resolved = config.model.heads.experimentally_resolved

         def run_experimentally_resolved_loss(value, batch):
             er_head = alphafold.model.modules.ExperimentallyResolvedHead(
                 c_experimentally_resolved, config.model.global_config
             )
             return er_head.loss(value, batch)

         f = hk.transform(run_experimentally_resolved_loss)

         n_res = consts.n_res

         value = {
             "logits": np.random.rand(n_res, 37).astype(np.float32),
         }
         batch = {
             "all_atom_mask": np.random.randint(0, 2, (n_res, 37)),
             "atom37_atom_exists": np.random.randint(0, 2, (n_res, 37)),
-            "resolution": np.array(1.0)
+            "resolution": np.array(1.0),
         }

         out_gt = f.apply({}, None, value, batch)["loss"]
         out_gt = torch.tensor(np.array(out_gt))

         value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
         batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

         with torch.no_grad():
             out_repro = experimentally_resolved_loss(
                 logits=value["logits"],
                 min_resolution=c_experimentally_resolved.min_resolution,
                 max_resolution=c_experimentally_resolved.max_resolution,
                 **batch,
             )
         out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

         self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

     @compare_utils.skip_unless_alphafold_installed()
     def test_supervised_chi_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_chi_loss = config.model.heads.structure_module

         def run_supervised_chi_loss(value, batch):
             ret = {
-                "loss": jax.numpy.array(0.),
+                "loss": jax.numpy.array(0.0),
             }
             alphafold.model.folding.supervised_chi_loss(
                 ret, batch, value, c_chi_loss
...
value
=
{
"sidechains"
:
{
"angles_sin_cos"
:
np
.
random
.
rand
(
8
,
n_res
,
7
,
2
).
astype
(
np
.
float32
),
"unnormalized_angles_sin_cos"
:
np
.
random
.
rand
(
8
,
n_res
,
7
,
2
).
astype
(
np
.
float32
),
"angles_sin_cos"
:
np
.
random
.
rand
(
8
,
n_res
,
7
,
2
).
astype
(
np
.
float32
),
"unnormalized_angles_sin_cos"
:
np
.
random
.
rand
(
8
,
n_res
,
7
,
2
).
astype
(
np
.
float32
),
}
}
...
...
@@ -519,13 +514,9 @@ class TestLoss(unittest.TestCase):
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
value
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
value
,
np
.
ndarray
)
value
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
value
,
np
.
ndarray
)
batch
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
batch
=
tree_map
(
lambda
x
:
torch
.
tensor
(
x
).
cuda
(),
batch
,
np
.
ndarray
)
batch
[
"chi_angles_sin_cos"
]
=
torch
.
stack
(
[
...
...
@@ -539,9 +530,9 @@ class TestLoss(unittest.TestCase):
out_repro
=
supervised_chi_loss
(
chi_weight
=
c_chi_loss
.
chi_weight
,
angle_norm_weight
=
c_chi_loss
.
angle_norm_weight
,
**
{
**
batch
,
**
value
[
"sidechains"
]}
)
**
{
**
batch
,
**
value
[
"sidechains"
]}
,
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
...
...
@@ -550,111 +541,119 @@ class TestLoss(unittest.TestCase):
     def test_violation_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_viol = config.model.heads.structure_module

         def run_viol_loss(batch, atom14_pred_pos):
             ret = {
-                "loss": np.array(0.).astype(np.float32),
+                "loss": np.array(0.0).astype(np.float32),
             }
             value = {}
-            value["violations"] = (
-                alphafold.model.folding.find_structural_violations(
-                    batch,
-                    atom14_pred_pos,
-                    c_viol,
-                )
-            )
+            value["violations"] = alphafold.model.folding.find_structural_violations(
+                batch,
+                atom14_pred_pos,
+                c_viol,
+            )
             alphafold.model.folding.structural_violation_loss(
-                ret, batch, value, c_viol,
+                ret,
+                batch,
+                value,
+                c_viol,
             )
             return ret["loss"]

         f = hk.transform(run_viol_loss)

         n_res = consts.n_res

         batch = {
             "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
             "residue_index": np.arange(n_res),
             "aatype": np.random.randint(0, 21, (n_res,)),
         }
         alphafold.model.tf.data_transforms.make_atom14_masks(batch)
         batch = {k: np.array(v) for k, v in batch.items()}
         atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

         out_gt = f.apply({}, None, batch, atom14_pred_pos)
         out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

         batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)
         atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()
         batch = data_transforms.make_atom14_masks(batch)

         out_repro = violation_loss(
             find_structural_violations(batch, atom14_pred_pos, **c_viol),
             **batch,
         )
         out_repro = out_repro.cpu()

         self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

     @compare_utils.skip_unless_alphafold_installed()
     def test_lddt_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_plddt = config.model.heads.predicted_lddt

         def run_plddt_loss(value, batch):
             head = alphafold.model.modules.PredictedLDDTHead(
                 c_plddt, config.model.global_config
             )
             return head.loss(value, batch)

         f = hk.transform(run_plddt_loss)

         n_res = consts.n_res

         value = {
             "predicted_lddt": {
                 "logits": np.random.rand(n_res, c_plddt.num_bins).astype(
                     np.float32
                 ),
             },
             "structure_module": {
-                "final_atom_positions": np.random.rand(
-                    n_res, 37, 3
-                ).astype(np.float32),
-            }
+                "final_atom_positions": np.random.rand(n_res, 37, 3).astype(
+                    np.float32
+                ),
+            },
         }
         batch = {
-            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(np.float32),
-            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
-            "resolution": np.array(1.).astype(np.float32),
+            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
+                np.float32
+            ),
+            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
+                np.float32
+            ),
+            "resolution": np.array(1.0).astype(np.float32),
         }

         out_gt = f.apply({}, None, value, batch)
         out_gt = torch.tensor(np.array(out_gt["loss"]))

         to_tensor = lambda t: torch.tensor(t).cuda()
         value = tree_map(to_tensor, value, np.ndarray)
         batch = tree_map(to_tensor, batch, np.ndarray)

         out_repro = lddt_loss(
             logits=value["predicted_lddt"]["logits"],
             all_atom_pred_pos=value["structure_module"]["final_atom_positions"],
             **{**batch, **c_plddt},
         )
         out_repro = out_repro.cpu()

         self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

     @compare_utils.skip_unless_alphafold_installed()
     def test_backbone_loss(self):
         config = compare_utils.get_alphafold_config()
         c_sm = config.model.heads.structure_module

         def run_bb_loss(batch, value):
             ret = {
-                "loss": np.array(0.),
+                "loss": np.array(0.0),
             }
             alphafold.model.folding.backbone_loss(ret, batch, value, c_sm)
             return ret["loss"]
...
@@ -665,13 +664,19 @@ class TestLoss(unittest.TestCase):
         batch = {
             "backbone_affine_tensor": random_affines_vector((n_res,)),
-            "backbone_affine_mask": np.random.randint(
-                0, 2, (n_res,)
-            ).astype(np.float32),
-            "use_clamped_fape": np.array(0.),
+            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
+                np.float32
+            ),
+            "use_clamped_fape": np.array(0.0),
         }
         value = {
-            "traj": random_affines_vector((c_sm.num_layer, n_res,)),
+            "traj": random_affines_vector(
+                (
+                    c_sm.num_layer,
+                    n_res,
+                )
+            ),
         }

         out_gt = f.apply({}, None, batch, value)
...
@@ -695,6 +700,7 @@ class TestLoss(unittest.TestCase):
def
test_sidechain_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
c_sm
=
config
.
model
.
heads
.
structure_module
def
run_sidechain_loss
(
batch
,
value
,
atom14_pred_positions
):
batch
=
{
**
batch
,
...
@@ -702,88 +708,94 @@ class TestLoss(unittest.TestCase):
                    batch["aatype"],
                    batch["all_atom_positions"],
                    batch["all_atom_mask"],
                ),
            }
            v = {}
            v["sidechains"] = {}
            v["sidechains"]["frames"] = alphafold.model.r3.rigids_from_tensor4x4(
                value["sidechains"]["frames"]
            )
            v["sidechains"]["atom_pos"] = alphafold.model.r3.vecs_from_tensor(
                value["sidechains"]["atom_pos"]
            )
            v.update(
                alphafold.model.folding.compute_renamed_ground_truth(
                    batch,
                    atom14_pred_positions,
                )
            )
            value = v
            ret = alphafold.model.folding.sidechain_loss(batch, value, c_sm)
            return ret["loss"]

        f = hk.transform(run_sidechain_loss)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3).astype(
                np.float32
            ),
            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
                np.float32
            ),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        def _build_extra_feats_np():
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b = data_transforms.make_atom14_masks(b)
            b = data_transforms.make_atom14_positions(b)
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        value = {
            "sidechains": {
                "frames": random_affines_4x4((c_sm.num_layer, n_res, 8)),
                "atom_pos": np.random.rand(
                    c_sm.num_layer, n_res, 14, 3
                ).astype(np.float32),
            }
        }

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        out_gt = f.apply({}, None, batch, value, atom14_pred_pos)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda t: torch.tensor(t).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)
        atom14_pred_pos = to_tensor(atom14_pred_pos)

        batch = data_transforms.atom37_to_frames(batch)
        batch.update(compute_renamed_ground_truth(batch, atom14_pred_pos))

        out_repro = sidechain_loss(
            sidechain_frames=value["sidechains"]["frames"],
            sidechain_atom_pos=value["sidechains"]["atom_pos"],
            **{**batch, **c_sm},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
    @compare_utils.skip_unless_alphafold_installed()
    def test_tm_loss_compare(self):
        config = compare_utils.get_alphafold_config()
        c_tm = config.model.heads.predicted_aligned_error

        def run_tm_loss(representations, batch, value):
            head = alphafold.model.modules.PredictedAlignedErrorHead(
                c_tm, config.model.global_config
...
@@ -792,58 +804,58 @@ class TestLoss(unittest.TestCase):
            v.update(value)
            v["predicted_aligned_error"] = head(representations, batch, False)
            return head.loss(v, batch)["loss"]

        f = hk.transform(run_tm_loss)

        n_res = consts.n_res

        representations = {
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(
                np.float32
            ),
        }

        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                np.float32
            ),
            "resolution": np.array(1.0).astype(np.float32),
        }

        value = {
            "structure_module": {
                "final_affines": random_affines_vector((n_res,)),
            }
        }

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/predicted_aligned_error_head"
        )

        out_gt = f.apply(params, None, representations, batch, value)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda n: torch.tensor(n).cuda()
        representations = tree_map(to_tensor, representations, np.ndarray)
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)

        batch["backbone_affine_tensor"] = affine_vector_to_4x4(
            batch["backbone_affine_tensor"]
        )
        value["structure_module"]["final_affines"] = affine_vector_to_4x4(
            value["structure_module"]["final_affines"]
        )

        model = compare_utils.get_global_pretrained_openfold()
        logits = model.aux_heads.tm(representations["pair"])
        out_repro = tm_loss(
            logits=logits,
            final_affine_tensor=value["structure_module"]["final_affines"],
            **{**batch, **c_tm},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
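affine_vector_to_4x4 above turns the 7-component quaternion-plus-translation vectors produced by random_affines_vector into 4x4 transformation matrices. A generic sketch of that conversion using the standard quaternion-to-rotation formula (the helper name and layout here are assumptions for illustration, not OpenFold's implementation):

import numpy as np

def quat_trans_to_4x4(v):
    # v = (qw, qx, qy, qz, tx, ty, tz); normalize the quaternion first
    w, x, y, z = v[:4] / np.linalg.norm(v[:4])
    rot = np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])
    m = np.eye(4)
    m[:3, :3], m[:3, 3] = rot, v[4:]
    return m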
...
tests/test_model.py View file @ 07e64267
...
@@ -29,7 +29,7 @@ from tests.data_utils import (
    random_extra_msa_feats,
)

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
...
@@ -43,36 +43,29 @@ class TestModel(unittest.TestCase):
        n_extra_seq = consts.n_extra

        c = model_config("model_1").model
        c.no_cycles = 2
        c.evoformer_stack.no_blocks = 4  # no need to go overboard here
        c.evoformer_stack.blocks_per_ckpt = None  # don't want to set up
        # deepspeed for this test
        model = AlphaFold(c)

        batch = {}
        tf = torch.randint(c.input_embedder.tf_dim - 1, size=(n_res,))
        batch["target_feat"] = nn.functional.one_hot(
            tf, c.input_embedder.tf_dim
        ).float()
        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
        batch["residue_index"] = torch.arange(n_res)
        batch["msa_feat"] = torch.rand((n_seq, n_res, c.input_embedder.msa_dim))
        t_feats = random_template_feats(n_templ, n_res)
        batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
        extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
        batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
        batch["msa_mask"] = torch.randint(
            low=0, high=2, size=(n_seq, n_res)
        ).float()
        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
        batch.update(make_atom14_masks(batch))

        add_recycling_dims = lambda t: (
...
@@ -80,7 +73,7 @@ class TestModel(unittest.TestCase):
        )
        batch = tensor_tree_map(add_recycling_dims, batch)

        with torch.no_grad():
            out = model(batch)

    @compare_utils.skip_unless_alphafold_installed()
...
@@ -89,12 +82,14 @@ class TestModel(unittest.TestCase):
            config = compare_utils.get_alphafold_config()
            model = alphafold.model.modules.AlphaFold(config.model)
            return model(
                batch=batch,
                is_training=False,
                return_representations=True,
            )

        f = hk.transform(run_alphafold)

        params = compare_utils.fetch_alphafold_module_weights("")

        with open("tests/test_data/sample_feats.pickle", "rb") as fp:
            batch = pickle.load(fp)
...
@@ -107,14 +102,14 @@ class TestModel(unittest.TestCase):
        batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]

        out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
        out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))

        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
        batch["aatype"] = batch["aatype"].long()
        batch["template_aatype"] = batch["template_aatype"].long()
        batch["extra_msa"] = batch["extra_msa"].long()
        batch["residx_atom37_to_atom14"] = batch["residx_atom37_to_atom14"].long()

        # Move the recycling dimension to the end
        move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
...
@@ -130,4 +125,3 @@ class TestModel(unittest.TestCase):
        out_repro = out_repro.squeeze(0)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
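The move_dim lambda in the hunk above rotates the leading recycling dimension to the last axis via permute. A minimal sketch of the shape change (the shapes are illustrative):

import torch

move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)

t = torch.zeros(4, 2, 8, 8)  # (n_recycle, batch, n_res, n_res)
assert move_dim(t).shape == (2, 8, 8, 4)  # recycling dim now last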
tests/test_msa.py View file @ 07e64267
...
@@ -24,14 +24,14 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestMSARowAttentionWithPairBias(unittest.TestCase):
    def test_shape(self):
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        n_res = consts.n_res
...
@@ -39,7 +39,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
        c_z = consts.c_z
        c = 52
        no_heads = 4
        chunk_size = None

        mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads, chunk_size)
...
@@ -58,29 +58,26 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            msa_row = alphafold.model.modules.MSARowAttentionWithPairBias(
                c_e.msa_row_attention_with_pair_bias, config.model.global_config
            )
            act = msa_row(
                msa_act=msa_act, msa_mask=msa_mask, pair_act=pair_act
            )
            return act

        f = hk.transform(run_msa_row_att)

        n_res = consts.n_res
        n_seq = consts.n_seq

        msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
        msa_mask = np.random.randint(
            low=0, high=2, size=(n_seq, n_res)
        ).astype(np.float32)
        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)

        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
            + "msa_row_attention"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
...
@@ -90,17 +87,21 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = (
            model.evoformer.blocks[0]
            .msa_att_row(
                torch.as_tensor(msa_act).cuda(),
                torch.as_tensor(pair_act).cuda(),
                torch.as_tensor(msa_mask).cuda(),
            )
            .cpu()
        )

        self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))


class TestMSAColumnAttention(unittest.TestCase):
    def test_shape(self):
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        n_res = consts.n_res
...
@@ -124,47 +125,46 @@ class TestMSAColumnAttention(unittest.TestCase):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            msa_col = alphafold.model.modules.MSAColumnAttention(
                c_e.msa_column_attention, config.model.global_config
            )
            act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
            return act

        f = hk.transform(run_msa_col_att)

        n_res = consts.n_res
        n_seq = consts.n_seq

        msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
        msa_mask = np.random.randint(
            low=0, high=2, size=(n_seq, n_res)
        ).astype(np.float32)

        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
            + "msa_column_attention"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = (
            model.evoformer.blocks[0]
            .msa_att_col(
                torch.as_tensor(msa_act).cuda(),
                torch.as_tensor(msa_mask).cuda(),
            )
            .cpu()
        )

        self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
class TestMSAColumnGlobalAttention(unittest.TestCase):
    def test_shape(self):
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        n_res = consts.n_res
...
@@ -188,40 +188,42 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            msa_col = alphafold.model.modules.MSAColumnGlobalAttention(
                c_e.msa_column_attention,
                config.model.global_config,
                name="msa_column_global_attention",
            )
            act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
            return act

        f = hk.transform(run_msa_col_global_att)

        n_res = consts.n_res
        n_seq = consts.n_seq
        c_e = consts.c_e

        msa_act = np.random.rand(n_seq, n_res, c_e)
        msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res))

        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/extra_msa_stack/"
            + "msa_column_global_attention"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = (
            model.extra_msa_stack.stack.blocks[0]
            .msa_att_col(
                torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
                mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
            )
            .cpu()
        )

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
...
tests/test_outer_product_mean.py View file @ 07e64267
...
@@ -19,7 +19,8 @@ from openfold.model.outer_product_mean import OuterProductMean
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
...
@@ -40,51 +41,54 @@ class TestOuterProductMean(unittest.TestCase):
        m = opm(m, mask)

        self.assertTrue(
            m.shape
            == (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_opm_compare(self):
        def run_opm(msa_act, msa_mask):
            config = compare_utils.get_alphafold_config()
            c_evo = config.model.embeddings_and_evoformer.evoformer
            opm = alphafold.model.modules.OuterProductMean(
                c_evo.outer_product_mean,
                config.model.global_config,
                consts.c_z,
            )
            act = opm(act=msa_act, mask=msa_mask)
            return act

        f = hk.transform(run_opm)

        n_res = consts.n_res
        n_seq = consts.n_seq
        c_m = consts.c_m

        msa_act = np.random.rand(n_seq, n_res, c_m).astype(np.float32) * 100
        msa_mask = np.random.randint(
            low=0, high=2, size=(n_seq, n_res)
        ).astype(np.float32)

        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/"
            + "evoformer_iteration/outer_product_mean"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = (
            model.evoformer.blocks[0]
            .outer_product_mean(
                torch.as_tensor(msa_act).cuda(),
                mask=torch.as_tensor(msa_mask).cuda(),
            )
            .cpu()
        )

        # Even when correct, OPM has large, precision-related errors. It gets
        # a special pass from consts.eps.
        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4)
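For context, the outer product mean collapses an MSA activation of shape (n_seq, n_res, c) into a pair representation by averaging per-sequence outer products of residue embeddings. A back-of-envelope sketch of that reduction with toy shapes (this omits the projections and masking of the real module):

import torch

n_seq, n_res, c = 6, 8, 4
msa = torch.rand(n_seq, n_res, c)

# Mean over sequences of the outer product between residues i and j
outer = torch.einsum("sic,sjd->ijcd", msa, msa) / n_seq
assert outer.shape == (n_res, n_res, c, c)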
...
tests/test_pair_transition.py View file @ 07e64267
...
@@ -20,14 +20,14 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestPairTransition(unittest.TestCase):
    def test_shape(self):
        c_z = consts.c_z
        n = 4
...
@@ -50,42 +50,42 @@ class TestPairTransition(unittest.TestCase):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            pt = alphafold.model.modules.Transition(
                c_e.pair_transition,
                config.model.global_config,
                name="pair_transition",
            )
            act = pt(act=pair_act, mask=pair_mask)
            return act

        f = hk.transform(run_pair_transition)

        n_res = consts.n_res

        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        pair_mask = np.ones((n_res, n_res)).astype(np.float32)  # no mask

        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
            + "pair_transition"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = (
            model.evoformer.blocks[0]
            .pair_transition(
                torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
                mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
            )
            .cpu()
        )

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


if __name__ == "__main__":
    unittest.main()
tests/test_structure_module.py View file @ 07e64267
...
@@ -23,7 +23,7 @@ from openfold.np.residue_constants import (
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
    restype_atom37_mask,
)
from openfold.model.structure_module import (
    StructureModule,
    StructureModuleTransition,
...
@@ -39,7 +39,7 @@ from tests.data_utils import (
    random_affines_4x4,
)

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
...
@@ -89,9 +89,7 @@ class TestStructureModule(unittest.TestCase):
        out = sm(s, z, f)

        self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
        self.assertTrue(
            out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
        )
...
@@ -121,78 +119,70 @@ class TestStructureModule(unittest.TestCase):
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module
        c_global = config.model.global_config

        def run_sm(representations, batch):
            sm = alphafold.model.folding.StructureModule(c_sm, c_global)
            representations = {
                k: jax.lax.stop_gradient(v)
                for k, v in representations.items()
            }
            batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
            return sm(representations, batch, is_training=False)

        f = hk.transform(run_sm)

        n_res = 200

        representations = {
            "single": np.random.rand(n_res, consts.c_s).astype(np.float32),
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(
                np.float32
            ),
        }

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 21, (n_res,)),
        }

        batch["atom14_atom_exists"] = np.take(
            restype_atom14_mask, batch["aatype"], axis=0
        )
        batch["atom37_atom_exists"] = np.take(
            restype_atom37_mask, batch["aatype"], axis=0
        )

        batch.update(make_atom14_masks_np(batch))

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/structure_module"
        )
        key = jax.random.PRNGKey(42)
        out_gt = f.apply(params, key, representations, batch)
        out_gt = torch.as_tensor(
            np.array(out_gt["final_atom14_positions"].block_until_ready())
        )

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = model.structure_module(
            torch.as_tensor(representations["single"]).cuda(),
            torch.as_tensor(representations["pair"]).cuda(),
            torch.as_tensor(batch["aatype"]).cuda(),
            mask=torch.as_tensor(batch["seq_mask"]).cuda(),
        )

        out_repro = out_repro["positions"][-1].cpu()

        # The structure module, thanks to angle normalization, is very volatile.
        # We only assess the mean here. Heuristically speaking, it seems to
        # have lower error in general on real rather than synthetic data.
        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.01)
class TestBackboneUpdate(unittest.TestCase):
    def test_shape(self):
        batch_size = 2
        n_res = 3
        c_in = 5

        bu = BackboneUpdate(c_in)

        s = torch.rand((batch_size, n_res, c_in))
...
@@ -237,25 +227,25 @@ class TestInvariantPointAttention(unittest.TestCase):
    @compare_utils.skip_unless_alphafold_installed()
    def test_ipa_compare(self):
        def run_ipa(act, static_feat_2d, mask, affine):
            config = compare_utils.get_alphafold_config()
            ipa = alphafold.model.folding.InvariantPointAttention(
                config.model.heads.structure_module,
                config.model.global_config,
            )
            attn = ipa(
                inputs_1d=act,
                inputs_2d=static_feat_2d,
                mask=mask,
                affine=affine,
            )
            return attn

        f = hk.transform(run_ipa)

        n_res = consts.n_res
        c_s = consts.c_s
        c_z = consts.c_z

        sample_act = np.random.rand(n_res, c_s)
        sample_2d = np.random.rand(n_res, n_res, c_z)
        sample_mask = np.ones((n_res, 1))
...
@@ -263,15 +253,13 @@ class TestInvariantPointAttention(unittest.TestCase):
        affines = random_affines_4x4((n_res,))
        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
        quats = alphafold.model.r3.rigids_to_quataffine(rigids)
        transformations = T.from_4x4(torch.as_tensor(affines).float().cuda())
        sample_affine = quats

        ipa_params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/structure_module/"
            + "fold_iteration/invariant_point_attention"
        )

        out_gt = f.apply(
...
@@ -282,17 +270,17 @@ class TestInvariantPointAttention(unittest.TestCase):
        with torch.no_grad():
            model = compare_utils.get_global_pretrained_openfold()
            out_repro = model.structure_module.ipa(
                torch.as_tensor(sample_act).float().cuda(),
                torch.as_tensor(sample_2d).float().cuda(),
                transformations,
                torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
            ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
class TestAngleResnet(unittest.TestCase):
    def test_shape(self):
        batch_size = 2
        n = 3
        c_s = 13
...
@@ -300,7 +288,7 @@ class TestAngleResnet(unittest.TestCase):
        no_layers = 5
        no_angles = 7
        epsilon = 1e-12

        ar = AngleResnet(c_s, c_hidden, no_layers, no_angles, epsilon)

        a = torch.rand((batch_size, n, c_s))
        a_initial = torch.rand((batch_size, n, c_s))
...
tests/test_template.py View file @ 07e64267
...
@@ -24,14 +24,14 @@ import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestTemplatePointwiseAttention(unittest.TestCase):
    def test_shape(self):
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        c_t = consts.c_t
...
@@ -40,7 +40,7 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
        no_heads = 13
        n_res = consts.n_res
        inf = 1e7

        tpa = TemplatePointwiseAttention(
            c_t, c_z, c, no_heads, chunk_size=4, inf=inf
        )
...
@@ -67,8 +67,8 @@ class TestTemplatePairStack(unittest.TestCase):
        n_res = consts.n_res
        blocks_per_ckpt = None
        chunk_size = 4
        inf = 1e7
        eps = 1e-7

        tpe = TemplatePairStack(
            c_t,
...
@@ -98,45 +98,47 @@ class TestTemplatePairStack(unittest.TestCase):
            config = compare_utils.get_alphafold_config()
            c_ee = config.model.embeddings_and_evoformer
            tps = alphafold.model.modules.TemplatePairStack(
                c_ee.template.template_pair_stack,
                config.model.global_config,
                name="template_pair_stack",
            )
            act = tps(pair_act, pair_mask, is_training=False)
            ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
            act = ln(act)
            return act

        f = hk.transform(run_template_pair_stack)

        n_res = consts.n_res

        pair_act = np.random.rand(n_res, n_res, consts.c_t).astype(np.float32)
        pair_mask = np.random.randint(
            low=0, high=2, size=(n_res, n_res)
        ).astype(np.float32)

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/template_embedding/"
            + "single_template_embedding/template_pair_stack"
        )
        params.update(
            compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/output_layer_norm"
            )
        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, pair_mask
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = model.template_pair_stack(
            torch.as_tensor(pair_act).cuda(),
            torch.as_tensor(pair_mask).cuda(),
            _mask_trans=False,
        ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
...
@@ -146,46 +148,46 @@ class Template(unittest.TestCase):
        def test_template_embedding(pair, batch, mask_2d):
            config = compare_utils.get_alphafold_config()
            te = alphafold.model.modules.TemplateEmbedding(
                config.model.embeddings_and_evoformer.template,
                config.model.global_config,
            )
            act = te(pair, batch, mask_2d, is_training=False)
            return act

        f = hk.transform(test_template_embedding)

        n_res = consts.n_res
        n_templ = consts.n_templ

        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        batch = random_template_feats(n_templ, n_res)
        pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)

        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/template_embedding"
        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, batch, pair_mask
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        inds = np.random.randint(0, 21, (n_res,))
        batch["target_feat"] = np.eye(22)[inds]

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = model.embed_templates(
            {k: torch.as_tensor(v).cuda() for k, v in batch.items()},
            torch.as_tensor(pair_act).cuda(),
            torch.as_tensor(pair_mask).cuda(),
            templ_dim=0,
        )

        out_repro = out_repro["template_pair_embedding"]
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


if __name__ == "__main__":
    unittest.main()
tests/test_triangular_attention.py View file @ 07e64267
...
@@ -21,7 +21,7 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
...
@@ -34,12 +34,7 @@ class TestTriangularAttention(unittest.TestCase):
        no_heads = 4
        starting = True

        tan = TriangleAttention(c_z, c, no_heads, starting)

        batch_size = consts.batch_size
        n_res = consts.n_res
...
@@ -53,22 +48,24 @@ class TestTriangularAttention(unittest.TestCase):
    def _tri_att_compare(self, starting=False):
        name = (
            "triangle_attention_"
            + ("starting" if starting else "ending")
            + "_node"
        )

        def run_tri_att(pair_act, pair_mask):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            tri_att = alphafold.model.modules.TriangleAttention(
                c_e.triangle_attention_starting_node
                if starting
                else c_e.triangle_attention_ending_node,
                config.model.global_config,
                name=name,
            )
            act = tri_att(pair_act=pair_act, pair_mask=pair_mask)
            return act

        f = hk.transform(run_tri_att)

        n_res = consts.n_res
...
@@ -78,24 +75,23 @@ class TestTriangularAttention(unittest.TestCase):
        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
            + name
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        module = (
            model.evoformer.blocks[0].tri_att_start
            if starting
            else model.evoformer.blocks[0].tri_att_end
        )
        out_repro = module(
            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
        ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
...
@@ -110,4 +106,4 @@ class TestTriangularAttention(unittest.TestCase):
if __name__ == "__main__":
    unittest.main()
tests/test_triangular_multiplicative_update.py View file @ 07e64267
...
@@ -20,14 +20,14 @@ from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestTriangularMultiplicativeUpdate(unittest.TestCase):
    def test_shape(self):
        c_z = consts.c_z
        c = 11
        outgoing = True
...
@@ -50,22 +50,23 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
        self.assertTrue(shape_before == shape_after)

    def _tri_mul_compare(self, incoming=False):
        name = "triangle_multiplication_" + (
            "incoming" if incoming else "outgoing"
        )

        def run_tri_mul(pair_act, pair_mask):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            tri_mul = alphafold.model.modules.TriangleMultiplication(
                c_e.triangle_multiplication_incoming
                if incoming
                else c_e.triangle_multiplication_outgoing,
                config.model.global_config,
                name=name,
            )
            act = tri_mul(act=pair_act, mask=pair_mask)
            return act

        f = hk.transform(run_tri_mul)

        n_res = consts.n_res
...
@@ -76,24 +77,23 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
            + name
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        module = (
            model.evoformer.blocks[0].tri_mul_in
            if incoming
            else model.evoformer.blocks[0].tri_mul_out
        )
        out_repro = module(
            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
        ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
...
@@ -109,4 +109,3 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
if __name__ == "__main__":
    unittest.main()
tests/test_utils.py View file @ 07e64267
...
@@ -20,17 +20,21 @@ from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.tensor_utils import chunk_layer

X_90_ROT = torch.tensor(
    [
        [1, 0, 0],
        [0, 0, -1],
        [0, 1, 0],
    ]
)

X_NEG_90_ROT = torch.tensor(
    [
        [1, 0, 0],
        [0, 0, 1],
        [0, -1, 0],
    ]
)
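The two constants above are rotations of +90 and -90 degrees about the x axis, so each should invert the other. A quick sanity check (float copies are assumed, since matmul needs a floating dtype):

import torch

X_90 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
X_NEG_90 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, -1.0, 0.0]])

assert torch.allclose(X_90 @ X_NEG_90, torch.eye(3))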
class TestAffineT(unittest.TestCase):
...
@@ -53,7 +57,7 @@ class TestAffineT(unittest.TestCase):
        batch_size = 2

        transf = [
            [1, 0, 0, 1],
            [0, 0, -1, 2],
            [0, 1, 0, 3],
            [0, 0, 0, 1],
        ]
...
@@ -62,10 +66,7 @@ class TestAffineT(unittest.TestCase):
        true_rot = transf[:3, :3]
        true_trans = transf[:3, 3]

        transf = torch.stack([transf for _ in range(batch_size)], dim=0)

        t = T.from_4x4(transf)
...
@@ -78,8 +79,7 @@ class TestAffineT(unittest.TestCase):
        batch_size = 2
        n = 5
        transf = T(
            torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3))
        )

        self.assertTrue(transf.shape == (batch_size, n))
...
@@ -88,12 +88,11 @@ class TestAffineT(unittest.TestCase):
        batch_size = 2
        n = 5
        transf = T(
            torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3))
        )

        transf_concat = T.concat([transf, transf], dim=0)

        self.assertTrue(transf_concat.rots.shape == (batch_size * 2, n, 3, 3))

        transf_concat = T.concat([transf, transf], dim=1)
...
@@ -124,7 +123,7 @@ class TestAffineT(unittest.TestCase):
        x = torch.arange(30)
        x = torch.stack([x, x], dim=0)
        x = x.view(2, -1, 3)  # [2, 10, 3]

        pts = t[..., None].apply(x)
...
@@ -165,4 +164,4 @@ class TestAffineT(unittest.TestCase):
        self.assertTrue(torch.all(chunked["out"] == unchunked["out"]))
        self.assertTrue(
            torch.all(chunked["inner"]["out"] == unchunked["inner"]["out"])
        )
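The assertions above check that chunk_layer's chunked evaluation is bitwise identical to the unchunked one. A generic sketch of the idea (chunked_apply is a hypothetical stand-in, not openfold's chunk_layer):

import torch

def chunked_apply(fn, x, chunk_size):
    # Run fn over slices of x along dim 0 and stitch the results back together
    return torch.cat([fn(c) for c in torch.split(x, chunk_size, dim=0)], dim=0)

fn = lambda t: t * 2
x = torch.arange(10.0)
assert torch.all(chunked_apply(fn, x, 4) == fn(x))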