Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
e3daf724
Commit
e3daf724
authored
Jan 03, 2022
by
Gustaf Ahdritz
Browse files
Overhaul transformation code for better parity w/ AlphaFold
parent
1f709b0d
Changes
15
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
1669 additions
and
178 deletions
+1669
-178
openfold/config.py
openfold/config.py
+4
-4
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+18
-12
openfold/data/mmcif_parsing.py
openfold/data/mmcif_parsing.py
+4
-2
openfold/data/templates.py
openfold/data/templates.py
+14
-3
openfold/model/embedders.py
openfold/model/embedders.py
+1
-0
openfold/model/structure_module.py
openfold/model/structure_module.py
+59
-51
openfold/utils/feats.py
openfold/utils/feats.py
+18
-18
openfold/utils/loss.py
openfold/utils/loss.py
+35
-22
openfold/utils/rigid_utils.py
openfold/utils/rigid_utils.py
+1380
-0
tests/test_data/alphafold/common/stereo_chemical_props.txt
tests/test_data/alphafold/common/stereo_chemical_props.txt
+1
-0
tests/test_feats.py
tests/test_feats.py
+10
-6
tests/test_loss.py
tests/test_loss.py
+18
-10
tests/test_model.py
tests/test_model.py
+2
-1
tests/test_structure_module.py
tests/test_structure_module.py
+9
-23
tests/test_utils.py
tests/test_utils.py
+96
-26
No files found.
openfold/config.py
View file @
e3daf724
...
...
@@ -89,8 +89,8 @@ config = mlc.ConfigDict(
"atom14_gt_exists"
:
[
NUM_RES
,
None
],
"atom14_gt_positions"
:
[
NUM_RES
,
None
,
None
],
"atom37_atom_exists"
:
[
NUM_RES
,
None
],
"backbone_
affine
_mask"
:
[
NUM_RES
],
"backbone_
affine
_tensor"
:
[
NUM_RES
,
None
,
None
],
"backbone_
rigid
_mask"
:
[
NUM_RES
],
"backbone_
rigid
_tensor"
:
[
NUM_RES
,
None
,
None
],
"bert_mask"
:
[
NUM_MSA_SEQ
,
NUM_RES
],
"chi_angles_sin_cos"
:
[
NUM_RES
,
None
,
None
],
"chi_mask"
:
[
NUM_RES
,
None
],
...
...
@@ -126,8 +126,8 @@ config = mlc.ConfigDict(
"template_alt_torsion_angles_sin_cos"
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
,
],
"template_backbone_
affine
_mask"
:
[
NUM_TEMPLATES
,
NUM_RES
],
"template_backbone_
affine
_tensor"
:
[
"template_backbone_
rigid
_mask"
:
[
NUM_TEMPLATES
,
NUM_RES
],
"template_backbone_
rigid
_tensor"
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
,
],
"template_mask"
:
[
NUM_TEMPLATES
],
...
...
openfold/data/data_transforms.py
View file @
e3daf724
...
...
@@ -22,7 +22,7 @@ import torch
from
openfold.config
import
NUM_RES
,
NUM_EXTRA_SEQ
,
NUM_TEMPLATES
,
NUM_MSA_SEQ
from
openfold.np
import
residue_constants
as
rc
from
openfold.utils.
affine
_utils
import
T
from
openfold.utils.
rigid
_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
tree_map
,
tensor_tree_map
,
...
...
@@ -752,7 +752,7 @@ def make_atom14_positions(protein):
return
protein
def
atom37_to_frames
(
protein
):
def
atom37_to_frames
(
protein
,
eps
=
1e-8
):
aatype
=
protein
[
"aatype"
]
all_atom_positions
=
protein
[
"all_atom_positions"
]
all_atom_mask
=
protein
[
"all_atom_mask"
]
...
...
@@ -810,11 +810,11 @@ def atom37_to_frames(protein):
no_batch_dims
=
len
(
all_atom_positions
.
shape
[:
-
2
]),
)
gt_frames
=
T
.
from_3_points
(
gt_frames
=
Rigid
.
from_3_points
(
p_neg_x_axis
=
base_atom_pos
[...,
0
,
:],
origin
=
base_atom_pos
[...,
1
,
:],
p_xy_plane
=
base_atom_pos
[...,
2
,
:],
eps
=
1e-8
,
eps
=
eps
,
)
group_exists
=
batched_gather
(
...
...
@@ -836,8 +836,9 @@ def atom37_to_frames(protein):
rots
=
torch
.
tile
(
rots
,
(
*
((
1
,)
*
batch_dims
),
8
,
1
,
1
))
rots
[...,
0
,
0
,
0
]
=
-
1
rots
[...,
0
,
2
,
2
]
=
-
1
rots
=
Rotation
(
rot_mats
=
rots
)
gt_frames
=
gt_frames
.
compose
(
T
(
rots
,
None
))
gt_frames
=
gt_frames
.
compose
(
Rigid
(
rots
,
None
))
restype_rigidgroup_is_ambiguous
=
all_atom_mask
.
new_zeros
(
*
((
1
,)
*
batch_dims
),
21
,
8
...
...
@@ -871,10 +872,15 @@ def atom37_to_frames(protein):
no_batch_dims
=
batch_dims
,
)
alt_gt_frames
=
gt_frames
.
compose
(
T
(
residx_rigidgroup_ambiguity_rot
,
None
))
residx_rigidgroup_ambiguity_rot
=
Rotation
(
rot_mats
=
residx_rigidgroup_ambiguity_rot
)
alt_gt_frames
=
gt_frames
.
compose
(
Rigid
(
residx_rigidgroup_ambiguity_rot
,
None
)
)
gt_frames_tensor
=
gt_frames
.
to_4x4
()
alt_gt_frames_tensor
=
alt_gt_frames
.
to_4x4
()
gt_frames_tensor
=
gt_frames
.
to_
tensor_
4x4
()
alt_gt_frames_tensor
=
alt_gt_frames
.
to_
tensor_
4x4
()
protein
[
"rigidgroups_gt_frames"
]
=
gt_frames_tensor
protein
[
"rigidgroups_gt_exists"
]
=
gt_exists
...
...
@@ -1028,7 +1034,7 @@ def atom37_to_torsion_angles(
dim
=-
1
,
)
torsion_frames
=
T
.
from_3_points
(
torsion_frames
=
Rigid
.
from_3_points
(
torsions_atom_pos
[...,
1
,
:],
torsions_atom_pos
[...,
2
,
:],
torsions_atom_pos
[...,
0
,
:],
...
...
@@ -1082,11 +1088,11 @@ def atom37_to_torsion_angles(
def
get_backbone_frames
(
protein
):
#
TODO: Verify that this is correct
protein
[
"backbone_
affine
_tensor"
]
=
protein
[
"rigidgroups_gt_frames"
][
#
DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
protein
[
"backbone_
rigid
_tensor"
]
=
protein
[
"rigidgroups_gt_frames"
][
...,
0
,
:,
:
]
protein
[
"backbone_
affine
_mask"
]
=
protein
[
"rigidgroups_gt_exists"
][...,
0
]
protein
[
"backbone_
rigid
_mask"
]
=
protein
[
"rigidgroups_gt_exists"
][...,
0
]
return
protein
...
...
openfold/data/mmcif_parsing.py
View file @
e3daf724
...
...
@@ -430,7 +430,9 @@ def _is_set(data: str) -> bool:
def
get_atom_coords
(
mmcif_object
:
MmcifObject
,
chain_id
:
str
,
zero_center
:
bool
=
True
mmcif_object
:
MmcifObject
,
chain_id
:
str
,
_zero_center_positions
:
bool
=
True
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
# Locate the right chain
chains
=
list
(
mmcif_object
.
structure
.
get_chains
())
...
...
@@ -475,7 +477,7 @@ def get_atom_coords(
all_atom_positions
[
res_index
]
=
pos
all_atom_mask
[
res_index
]
=
mask
if
zero_center
:
if
_
zero_center
_positions
:
binary_mask
=
all_atom_mask
.
astype
(
bool
)
translation_vec
=
all_atom_positions
[
binary_mask
].
mean
(
axis
=
0
)
all_atom_positions
[
binary_mask
]
-=
translation_vec
...
...
openfold/data/templates.py
View file @
e3daf724
...
...
@@ -503,10 +503,13 @@ def _get_atom_positions(
mmcif_object
:
mmcif_parsing
.
MmcifObject
,
auth_chain_id
:
str
,
max_ca_ca_distance
:
float
,
_zero_center_positions
:
bool
=
True
,
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask
=
mmcif_parsing
.
get_atom_coords
(
mmcif_object
=
mmcif_object
,
chain_id
=
auth_chain_id
mmcif_object
=
mmcif_object
,
chain_id
=
auth_chain_id
,
_zero_center_positions
=
_zero_center_positions
,
)
all_atom_positions
,
all_atom_mask
=
coords_with_mask
_check_residue_distances
(
...
...
@@ -523,6 +526,7 @@ def _extract_template_features(
query_sequence
:
str
,
template_chain_id
:
str
,
kalign_binary_path
:
str
,
_zero_center_positions
:
bool
=
True
,
)
->
Tuple
[
Dict
[
str
,
Any
],
Optional
[
str
]]:
"""Parses atom positions in the target structure and aligns with the query.
...
...
@@ -607,7 +611,10 @@ def _extract_template_features(
# Essentially set to infinity - we don't want to reject templates unless
# they're really really bad.
all_atom_positions
,
all_atom_mask
=
_get_atom_positions
(
mmcif_object
,
chain_id
,
max_ca_ca_distance
=
150.0
mmcif_object
,
chain_id
,
max_ca_ca_distance
=
150.0
,
_zero_center_positions
=
_zero_center_positions
,
)
except
(
CaDistanceError
,
KeyError
)
as
ex
:
raise
NoAtomDataInTemplateError
(
...
...
@@ -795,6 +802,7 @@ def _process_single_hit(
obsolete_pdbs
:
Mapping
[
str
,
str
],
kalign_binary_path
:
str
,
strict_error_check
:
bool
=
False
,
_zero_center_positions
:
bool
=
True
,
)
->
SingleHitResult
:
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
...
...
@@ -856,6 +864,7 @@ def _process_single_hit(
query_sequence
=
query_sequence
,
template_chain_id
=
hit_chain_id
,
kalign_binary_path
=
kalign_binary_path
,
_zero_center_positions
=
_zero_center_positions
,
)
features
[
"template_sum_probs"
]
=
[
hit
.
sum_probs
]
...
...
@@ -913,7 +922,6 @@ class TemplateSearchResult:
class
TemplateHitFeaturizer
:
"""A class for turning hhr hits to template features."""
def
__init__
(
self
,
mmcif_dir
:
str
,
...
...
@@ -924,6 +932,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs_path
:
Optional
[
str
]
=
None
,
strict_error_check
:
bool
=
False
,
_shuffle_top_k_prefiltered
:
Optional
[
int
]
=
None
,
_zero_center_positions
:
bool
=
True
,
):
"""Initializes the Template Search.
...
...
@@ -982,6 +991,7 @@ class TemplateHitFeaturizer:
self
.
_obsolete_pdbs
=
{}
self
.
_shuffle_top_k_prefiltered
=
_shuffle_top_k_prefiltered
self
.
_zero_center_positions
=
_zero_center_positions
def
get_templates
(
self
,
...
...
@@ -1057,6 +1067,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
kalign_binary_path
=
self
.
_kalign_binary_path
,
_zero_center_positions
=
self
.
_zero_center_positions
,
)
if
result
.
error
:
...
...
openfold/model/embedders.py
View file @
e3daf724
...
...
@@ -198,6 +198,7 @@ class RecyclingEmbedder(nn.Module):
self
.
no_bins
,
dtype
=
x
.
dtype
,
device
=
x
.
device
,
requires_grad
=
False
,
)
# [*, N, C_m]
...
...
openfold/model/structure_module.py
View file @
e3daf724
...
...
@@ -25,11 +25,11 @@ from openfold.np.residue_constants import (
restype_atom14_mask
,
restype_atom14_rigid_group_positions
,
)
from
openfold.utils.affine_utils
import
T
,
quat_to_rot
from
openfold.utils.feats
import
(
frames_and_literature_positions_to_atom14_pos
,
torsion_angles_to_frames
,
)
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
dict_multimap
,
permute_final_dims
,
...
...
@@ -225,7 +225,7 @@ class InvariantPointAttention(nn.Module):
self
,
s
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
t
:
T
,
r
:
Rigid
,
mask
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
...
...
@@ -234,8 +234,8 @@ class InvariantPointAttention(nn.Module):
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
t
:
[*, N_res]
affine
transformation object
r
:
[*, N_res] transformation object
mask:
[*, N_res] mask
Returns:
...
...
@@ -264,7 +264,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_q, 3]
q_pts
=
torch
.
split
(
q_pts
,
q_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
q_pts
=
torch
.
stack
(
q_pts
,
dim
=-
1
)
q_pts
=
t
[...,
None
].
apply
(
q_pts
)
q_pts
=
r
[...,
None
].
apply
(
q_pts
)
# [*, N_res, H, P_q, 3]
q_pts
=
q_pts
.
view
(
...
...
@@ -277,7 +277,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * (P_q + P_v), 3]
kv_pts
=
torch
.
split
(
kv_pts
,
kv_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
kv_pts
=
torch
.
stack
(
kv_pts
,
dim
=-
1
)
kv_pts
=
t
[...,
None
].
apply
(
kv_pts
)
kv_pts
=
r
[...,
None
].
apply
(
kv_pts
)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts
=
kv_pts
.
view
(
kv_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
-
1
,
3
))
...
...
@@ -349,7 +349,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, P_v, 3]
o_pt
=
permute_final_dims
(
o_pt
,
(
2
,
0
,
3
,
1
))
o_pt
=
t
[...,
None
,
None
].
invert_apply
(
o_pt
)
o_pt
=
r
[...,
None
,
None
].
invert_apply
(
o_pt
)
# [*, N_res, H * P_v]
o_pt_norm
=
flatten_final_dims
(
...
...
@@ -377,7 +377,7 @@ class InvariantPointAttention(nn.Module):
class
BackboneUpdate
(
nn
.
Module
):
"""
Implements Algorithm 23.
Implements
part of
Algorithm 23.
"""
def
__init__
(
self
,
c_s
):
...
...
@@ -392,36 +392,17 @@ class BackboneUpdate(nn.Module):
self
.
linear
=
Linear
(
self
.
c_s
,
6
,
init
=
"final"
)
def
forward
(
self
,
s
)
:
def
forward
(
self
,
s
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
"""
Args:
[*, N_res, C_s] single representation
Returns:
[*, N_res
] affine transformation object
[*, N_res
, 6] update vector
"""
# [*, 6]
params
=
self
.
linear
(
s
)
# [*, 3]
quats
,
trans
=
params
[...,
:
3
],
params
[...,
3
:]
# [*]
# norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
norm_denom
=
torch
.
sqrt
(
torch
.
sum
(
quats
**
2
,
dim
=-
1
)
+
1
)
# [*, 3]
ones
=
s
.
new_ones
((
1
,)
*
len
(
quats
.
shape
)).
expand
(
quats
.
shape
[:
-
1
]
+
(
1
,)
)
# [*, 4]
quats
=
torch
.
cat
([
ones
,
quats
],
dim
=-
1
)
quats
=
quats
/
norm_denom
[...,
None
]
update
=
self
.
linear
(
s
)
# [*, 3, 3]
rots
=
quat_to_rot
(
quats
)
return
T
(
rots
,
trans
)
return
update
class
StructureModuleTransitionLayer
(
nn
.
Module
):
...
...
@@ -592,7 +573,7 @@ class StructureModule(nn.Module):
self
,
s
,
z
,
f
,
aatype
,
mask
=
None
,
):
"""
...
...
@@ -601,7 +582,7 @@ class StructureModule(nn.Module):
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
f
:
aatype
:
[*, N_res] amino acid indices
mask:
Optional [*, N_res] sequence mask
...
...
@@ -623,44 +604,67 @@ class StructureModule(nn.Module):
s
=
self
.
linear_in
(
s
)
# [*, N]
t
=
T
.
identity
(
s
.
shape
[:
-
1
],
s
.
dtype
,
s
.
device
,
self
.
training
)
rigids
=
Rigid
.
identity
(
s
.
shape
[:
-
1
],
s
.
dtype
,
s
.
device
,
self
.
training
,
fmt
=
"quat"
,
)
outputs
=
[]
for
i
in
range
(
self
.
no_blocks
):
# [*, N, C_s]
s
=
s
+
self
.
ipa
(
s
,
z
,
t
,
mask
)
s
=
s
+
self
.
ipa
(
s
,
z
,
rigids
,
mask
)
s
=
self
.
ipa_dropout
(
s
)
s
=
self
.
layer_norm_ipa
(
s
)
s
=
self
.
transition
(
s
)
# [*, N]
t
=
t
.
compose
(
self
.
bb_update
(
s
))
rigids
=
rigids
.
compose_q_update_vec
(
self
.
bb_update
(
s
))
# To hew as closely as possible to AlphaFold, we convert our
# quaternion-based transformations to rotation-matrix ones
# here
backb_to_global
=
Rigid
(
Rotation
(
rot_mats
=
rigids
.
get_rots
().
get_rot_mats
(),
quats
=
None
),
rigids
.
get_trans
(),
)
backb_to_global
=
backb_to_global
.
scale_translation
(
self
.
trans_scale_factor
)
# [*, N, 7, 2]
unnormalized_a
,
a
=
self
.
angle_resnet
(
s
,
s_initial
)
unnormalized_a
ngles
,
angles
=
self
.
angle_resnet
(
s
,
s_initial
)
all_frames_to_global
=
self
.
torsion_angles_to_frames
(
t
.
scale_translation
(
self
.
trans_scale_factor
)
,
a
,
f
,
backb_to_global
,
a
ngles
,
aatype
,
)
pred_xyz
=
self
.
frames_and_literature_positions_to_atom14_pos
(
all_frames_to_global
,
f
,
aatype
,
)
scaled_rigids
=
rigids
.
scale_translation
(
self
.
trans_scale_factor
)
preds
=
{
"frames"
:
t
.
scale
_translation
(
self
.
trans_scale_factor
).
to_4x4
(),
"sidechain_frames"
:
all_frames_to_global
.
to_4x4
(),
"unnormalized_angles"
:
unnormalized_a
,
"angles"
:
a
,
"frames"
:
scale
d_rigids
.
to_tensor_7
(),
"sidechain_frames"
:
all_frames_to_global
.
to_
tensor_
4x4
(),
"unnormalized_angles"
:
unnormalized_a
ngles
,
"angles"
:
a
ngles
,
"positions"
:
pred_xyz
,
}
outputs
.
append
(
preds
)
if
i
<
(
self
.
no_blocks
-
1
):
t
=
t
.
stop_rot_gradient
()
rigids
=
rigids
.
stop_rot_gradient
()
outputs
=
dict_multimap
(
torch
.
stack
,
outputs
)
outputs
[
"single"
]
=
s
...
...
@@ -673,38 +677,42 @@ class StructureModule(nn.Module):
restype_rigid_group_default_frame
,
dtype
=
float_dtype
,
device
=
device
,
requires_grad
=
False
,
)
if
self
.
group_idx
is
None
:
self
.
group_idx
=
torch
.
tensor
(
restype_atom14_to_rigid_group
,
device
=
device
,
requires_grad
=
False
,
)
if
self
.
atom_mask
is
None
:
self
.
atom_mask
=
torch
.
tensor
(
restype_atom14_mask
,
dtype
=
float_dtype
,
device
=
device
,
requires_grad
=
False
,
)
if
self
.
lit_positions
is
None
:
self
.
lit_positions
=
torch
.
tensor
(
restype_atom14_rigid_group_positions
,
dtype
=
float_dtype
,
device
=
device
,
requires_grad
=
False
,
)
def
torsion_angles_to_frames
(
self
,
t
,
alpha
,
f
):
def
torsion_angles_to_frames
(
self
,
r
,
alpha
,
f
):
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
alpha
.
dtype
,
alpha
.
device
)
# Separated purely to make testing less annoying
return
torsion_angles_to_frames
(
t
,
alpha
,
f
,
self
.
default_frames
)
return
torsion_angles_to_frames
(
r
,
alpha
,
f
,
self
.
default_frames
)
def
frames_and_literature_positions_to_atom14_pos
(
self
,
t
,
f
# [*, N, 8] # [*, N]
self
,
r
,
f
# [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
t
.
rots
.
dtype
,
t
.
rots
.
device
)
self
.
_init_residue_constants
(
r
.
get_
rots
()
.
dtype
,
r
.
get_
rots
()
.
device
)
return
frames_and_literature_positions_to_atom14_pos
(
t
,
r
,
f
,
self
.
default_frames
,
self
.
group_idx
,
...
...
openfold/utils/feats.py
View file @
e3daf724
...
...
@@ -22,7 +22,7 @@ from typing import Dict
from
openfold.np
import
protein
import
openfold.np.residue_constants
as
rc
from
openfold.utils.
affine
_utils
import
T
from
openfold.utils.
rigid
_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
batched_gather
,
one_hot
,
...
...
@@ -124,18 +124,16 @@ def build_template_pair_feat(
)
n
,
ca
,
c
=
[
rc
.
atom_order
[
a
]
for
a
in
[
"N"
,
"CA"
,
"C"
]]
# TODO: Consider running this in double precision
affines
=
T
.
make_transform_from_reference
(
rigids
=
Rigid
.
make_transform_from_reference
(
n_xyz
=
batch
[
"template_all_atom_positions"
][...,
n
,
:],
ca_xyz
=
batch
[
"template_all_atom_positions"
][...,
ca
,
:],
c_xyz
=
batch
[
"template_all_atom_positions"
][...,
c
,
:],
eps
=
eps
,
)
points
=
rigids
.
get_trans
()[...,
None
,
:,
:]
rigid_vec
=
rigids
[...,
None
].
invert_apply
(
points
)
points
=
affines
.
get_trans
()[...,
None
,
:,
:]
affine_vec
=
affines
[...,
None
].
invert_apply
(
points
)
inv_distance_scalar
=
torch
.
rsqrt
(
eps
+
torch
.
sum
(
affine_vec
**
2
,
dim
=-
1
))
inv_distance_scalar
=
torch
.
rsqrt
(
eps
+
torch
.
sum
(
rigid_vec
**
2
,
dim
=-
1
))
t_aa_masks
=
batch
[
"template_all_atom_mask"
]
template_mask
=
(
...
...
@@ -144,7 +142,7 @@ def build_template_pair_feat(
template_mask_2d
=
template_mask
[...,
None
]
*
template_mask
[...,
None
,
:]
inv_distance_scalar
=
inv_distance_scalar
*
template_mask_2d
unit_vector
=
affine
_vec
*
inv_distance_scalar
[...,
None
]
unit_vector
=
rigid
_vec
*
inv_distance_scalar
[...,
None
]
to_concat
.
extend
(
torch
.
unbind
(
unit_vector
[...,
None
,
:],
dim
=-
1
))
to_concat
.
append
(
template_mask_2d
[...,
None
])
...
...
@@ -165,7 +163,7 @@ def build_extra_msa_feat(batch):
def
torsion_angles_to_frames
(
t
:
T
,
r
:
Rigid
,
alpha
:
torch
.
Tensor
,
aatype
:
torch
.
Tensor
,
rrgdf
:
torch
.
Tensor
,
...
...
@@ -176,13 +174,15 @@ def torsion_angles_to_frames(
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
default_
t
=
T
.
from_4x4
(
default_4x4
)
default_
r
=
r
.
from_
tensor_
4x4
(
default_4x4
)
bb_rot
=
alpha
.
new_zeros
((
*
((
1
,)
*
len
(
alpha
.
shape
[:
-
1
])),
2
))
bb_rot
[...,
1
]
=
1
# [*, N, 8, 2]
alpha
=
torch
.
cat
([
bb_rot
.
expand
(
*
alpha
.
shape
[:
-
2
],
-
1
,
-
1
),
alpha
],
dim
=-
2
)
alpha
=
torch
.
cat
(
[
bb_rot
.
expand
(
*
alpha
.
shape
[:
-
2
],
-
1
,
-
1
),
alpha
],
dim
=-
2
)
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
...
...
@@ -194,15 +194,15 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots
=
alpha
.
new_zeros
(
default_
t
.
rots
.
shape
)
all_rots
=
alpha
.
new_zeros
(
default_
r
.
get_rots
().
get_rot_mats
()
.
shape
)
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
2
,
1
:]
=
alpha
all_rots
=
T
(
all_rots
,
None
)
all_rots
=
Rigid
(
Rotation
(
rot_mats
=
all_rots
)
,
None
)
all_frames
=
default_
t
.
compose
(
all_rots
)
all_frames
=
default_
r
.
compose
(
all_rots
)
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
...
...
@@ -213,7 +213,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
all_frames_to_bb
=
T
.
con
cat
(
all_frames_to_bb
=
Rigid
.
cat
(
[
all_frames
[...,
:
5
],
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
...
...
@@ -223,13 +223,13 @@ def torsion_angles_to_frames(
dim
=-
1
,
)
all_frames_to_global
=
t
[...,
None
].
compose
(
all_frames_to_bb
)
all_frames_to_global
=
r
[...,
None
].
compose
(
all_frames_to_bb
)
return
all_frames_to_global
def
frames_and_literature_positions_to_atom14_pos
(
t
:
T
,
r
:
Rigid
,
aatype
:
torch
.
Tensor
,
default_frames
,
group_idx
,
...
...
@@ -249,7 +249,7 @@ def frames_and_literature_positions_to_atom14_pos(
)
# [*, N, 14, 8]
t_atoms_to_global
=
t
[...,
None
,
:]
*
group_mask
t_atoms_to_global
=
r
[...,
None
,
:]
*
group_mask
# [*, N, 14]
t_atoms_to_global
=
t_atoms_to_global
.
map_tensor_fn
(
...
...
openfold/utils/loss.py
View file @
e3daf724
...
...
@@ -24,7 +24,7 @@ from typing import Dict, Optional, Tuple
from
openfold.np
import
residue_constants
from
openfold.utils
import
feats
from
openfold.utils.
affine
_utils
import
T
from
openfold.utils.
rigid
_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
tree_map
,
tensor_tree_map
,
...
...
@@ -74,8 +74,8 @@ def torsion_angle_loss(
def
compute_fape
(
pred_frames
:
T
,
target_frames
:
T
,
pred_frames
:
Rigid
,
target_frames
:
Rigid
,
frames_mask
:
torch
.
Tensor
,
pred_positions
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
...
...
@@ -111,7 +111,7 @@ def compute_fape(
# )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
# ("roughly" because eps is necessarily duplicated in the latter
# ("roughly" because eps is necessarily duplicated in the latter
)
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
normed_error
=
(
normed_error
/
(
eps
+
torch
.
sum
(
frames_mask
,
dim
=-
1
))[...,
None
]
...
...
@@ -123,8 +123,8 @@ def compute_fape(
def
backbone_loss
(
backbone_
affine
_tensor
:
torch
.
Tensor
,
backbone_
affine
_mask
:
torch
.
Tensor
,
backbone_
rigid
_tensor
:
torch
.
Tensor
,
backbone_
rigid
_mask
:
torch
.
Tensor
,
traj
:
torch
.
Tensor
,
use_clamped_fape
:
Optional
[
torch
.
Tensor
]
=
None
,
clamp_distance
:
float
=
10.0
,
...
...
@@ -132,16 +132,27 @@ def backbone_loss(
eps
:
float
=
1e-4
,
**
kwargs
,
)
->
torch
.
Tensor
:
pred_aff
=
T
.
from_tensor
(
traj
)
gt_aff
=
T
.
from_tensor
(
backbone_affine_tensor
)
pred_aff
=
Rigid
.
from_tensor_7
(
traj
)
pred_aff
=
Rigid
(
Rotation
(
rot_mats
=
pred_aff
.
get_rots
().
get_rot_mats
(),
quats
=
None
),
pred_aff
.
get_trans
(),
)
# DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
# backbone tensor, normalizes it, and then turns it back to a rotation
# matrix. To avoid a potentially numerically unstable rotation matrix
# to quaternion conversion, we just use the original rotation matrix
# outright. This one hasn't been composed a bunch of times, though, so
# it might be fine.
gt_aff
=
Rigid
.
from_tensor_4x4
(
backbone_rigid_tensor
)
fape_loss
=
compute_fape
(
pred_aff
,
gt_aff
[
None
],
backbone_
affine
_mask
[
None
],
backbone_
rigid
_mask
[
None
],
pred_aff
.
get_trans
(),
gt_aff
[
None
].
get_trans
(),
backbone_
affine
_mask
[
None
],
backbone_
rigid
_mask
[
None
],
l1_clamp_distance
=
clamp_distance
,
length_scale
=
loss_unit_distance
,
eps
=
eps
,
...
...
@@ -150,10 +161,10 @@ def backbone_loss(
unclamped_fape_loss
=
compute_fape
(
pred_aff
,
gt_aff
[
None
],
backbone_
affine
_mask
[
None
],
backbone_
rigid
_mask
[
None
],
pred_aff
.
get_trans
(),
gt_aff
[
None
].
get_trans
(),
backbone_
affine
_mask
[
None
],
backbone_
rigid
_mask
[
None
],
l1_clamp_distance
=
None
,
length_scale
=
loss_unit_distance
,
eps
=
eps
,
...
...
@@ -193,9 +204,9 @@ def sidechain_loss(
sidechain_frames
=
sidechain_frames
[
-
1
]
batch_dims
=
sidechain_frames
.
shape
[:
-
4
]
sidechain_frames
=
sidechain_frames
.
view
(
*
batch_dims
,
-
1
,
4
,
4
)
sidechain_frames
=
T
.
from
_4x4
(
sidechain_frames
)
sidechain_frames
=
Rigid
.
from_tensor
_4x4
(
sidechain_frames
)
renamed_gt_frames
=
renamed_gt_frames
.
view
(
*
batch_dims
,
-
1
,
4
,
4
)
renamed_gt_frames
=
T
.
from
_4x4
(
renamed_gt_frames
)
renamed_gt_frames
=
Rigid
.
from_tensor
_4x4
(
renamed_gt_frames
)
rigidgroups_gt_exists
=
rigidgroups_gt_exists
.
reshape
(
*
batch_dims
,
-
1
)
sidechain_atom_pos
=
sidechain_atom_pos
[
-
1
]
sidechain_atom_pos
=
sidechain_atom_pos
.
view
(
*
batch_dims
,
-
1
,
3
)
...
...
@@ -422,7 +433,7 @@ def distogram_loss(
device
=
logits
.
device
,
)
boundaries
=
boundaries
**
2
dists
=
torch
.
sum
(
(
pseudo_beta
[...,
None
,
:]
-
pseudo_beta
[...,
None
,
:,
:])
**
2
,
dim
=-
1
,
...
...
@@ -550,8 +561,8 @@ def compute_tm(
def
tm_loss
(
logits
,
final_affine_tensor
,
backbone_
affine
_tensor
,
backbone_
affine
_mask
,
backbone_
rigid
_tensor
,
backbone_
rigid
_mask
,
resolution
,
max_bin
=
31
,
no_bins
=
64
,
...
...
@@ -560,16 +571,17 @@ def tm_loss(
eps
=
1e-8
,
**
kwargs
,
):
pred_affine
=
T
.
from_
4x4
(
final_affine_tensor
)
backbone_
affine
=
T
.
from
_4x4
(
backbone_
affine
_tensor
)
pred_affine
=
Rigid
.
from_
tensor_7
(
final_affine_tensor
)
backbone_
rigid
=
Rigid
.
from_tensor
_4x4
(
backbone_
rigid
_tensor
)
def
_points
(
affine
):
pts
=
affine
.
get_trans
()[...,
None
,
:,
:]
return
affine
.
invert
()[...,
None
].
apply
(
pts
)
sq_diff
=
torch
.
sum
(
(
_points
(
pred_affine
)
-
_points
(
backbone_
affine
))
**
2
,
dim
=-
1
(
_points
(
pred_affine
)
-
_points
(
backbone_
rigid
))
**
2
,
dim
=-
1
)
sq_diff
=
sq_diff
.
detach
()
boundaries
=
torch
.
linspace
(
...
...
@@ -583,7 +595,7 @@ def tm_loss(
)
square_mask
=
(
backbone_
affine
_mask
[...,
None
]
*
backbone_
affine
_mask
[...,
None
,
:]
backbone_
rigid
_mask
[...,
None
]
*
backbone_
rigid
_mask
[...,
None
,
:]
)
loss
=
torch
.
sum
(
errors
*
square_mask
,
dim
=-
1
)
...
...
@@ -1503,11 +1515,12 @@ class AlphaFoldLoss(nn.Module):
),
}
cum_loss
=
0
cum_loss
=
0
.
for
loss_name
,
loss_fn
in
loss_fns
.
items
():
weight
=
self
.
config
[
loss_name
].
weight
if
weight
:
loss
=
loss_fn
()
if
(
torch
.
isnan
(
loss
)
or
torch
.
isinf
(
loss
)):
logging
.
warning
(
f
"
{
loss_name
}
loss is NaN. Skipping..."
)
loss
=
loss
.
new_tensor
(
0.
,
requires_grad
=
True
)
...
...
openfold/utils/
affine
_utils.py
→
openfold/utils/
rigid
_utils.py
View file @
e3daf724
This diff is collapsed.
Click to expand it.
tests/test_data/alphafold/common/stereo_chemical_props.txt
0 → 120000
View file @
e3daf724
../../../../openfold/resources/stereo_chemical_props.txt
\ No newline at end of file
tests/test_feats.py
View file @
e3daf724
...
...
@@ -23,8 +23,8 @@ from openfold.np.residue_constants import (
restype_atom14_mask
,
restype_atom14_rigid_group_positions
,
)
from
openfold.utils.affine_utils
import
T
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
tree_map
,
tensor_tree_map
,
...
...
@@ -187,7 +187,7 @@ class TestFeats(unittest.TestCase):
n
=
5
rots
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n
,
3
))
ts
=
T
(
rots
,
trans
)
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
)
,
trans
)
angles
=
torch
.
rand
((
batch_size
,
n
,
7
,
2
))
...
...
@@ -222,7 +222,9 @@ class TestFeats(unittest.TestCase):
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
transformations
=
T
.
from_4x4
(
torch
.
as_tensor
(
affines
).
float
())
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
torsion_angles_sin_cos
=
np
.
random
.
rand
(
n_res
,
7
,
2
)
...
...
@@ -250,7 +252,7 @@ class TestFeats(unittest.TestCase):
bottom_row
[...,
3
]
=
1
transforms_gt
=
torch
.
cat
([
transforms_gt
,
bottom_row
],
dim
=-
2
)
transforms_repro
=
out
.
to_4x4
().
cpu
()
transforms_repro
=
out
.
to_
tensor_
4x4
().
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
transforms_gt
-
transforms_repro
)
<
consts
.
eps
)
...
...
@@ -262,7 +264,7 @@ class TestFeats(unittest.TestCase):
rots
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
))
ts
=
T
(
rots
,
trans
)
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
)
,
trans
)
f
=
torch
.
randint
(
low
=
0
,
high
=
21
,
size
=
(
batch_size
,
n_res
)).
long
()
...
...
@@ -293,7 +295,9 @@ class TestFeats(unittest.TestCase):
affines
=
random_affines_4x4
((
n_res
,
8
))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
transformations
=
T
.
from_4x4
(
torch
.
as_tensor
(
affines
).
float
())
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
out_gt
=
f
.
apply
({},
None
,
aatype
,
rigids
)
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
...
...
tests/test_loss.py
View file @
e3daf724
...
...
@@ -20,7 +20,10 @@ import unittest
import
ml_collections
as
mlc
from
openfold.data
import
data_transforms
from
openfold.utils.affine_utils
import
T
,
affine_vector_to_4x4
from
openfold.utils.rigid_utils
import
(
Rotation
,
Rigid
,
)
import
openfold.utils.feats
as
feats
from
openfold.utils.loss
import
(
torsion_angle_loss
,
...
...
@@ -55,6 +58,11 @@ if compare_utils.alphafold_is_installed():
import
haiku
as
hk
def
affine_vector_to_4x4
(
affine
):
r
=
Rigid
.
from_tensor_7
(
affine
)
return
r
.
to_tensor_4x4
()
class
TestLoss
(
unittest
.
TestCase
):
def
test_run_torsion_angle_loss
(
self
):
batch_size
=
consts
.
batch_size
...
...
@@ -77,8 +85,8 @@ class TestLoss(unittest.TestCase):
rots_gt
=
torch
.
rand
((
batch_size
,
n_frames
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_frames
,
3
))
trans_gt
=
torch
.
rand
((
batch_size
,
n_frames
,
3
))
t
=
T
(
rots
,
trans
)
t_gt
=
T
(
rots_gt
,
trans_gt
)
t
=
Rigid
(
Rotation
(
rot_mats
=
rots
)
,
trans
)
t_gt
=
Rigid
(
Rotation
(
rot_mats
=
rots_gt
)
,
trans_gt
)
frames_mask
=
torch
.
randint
(
0
,
2
,
(
batch_size
,
n_frames
)).
float
()
positions_mask
=
torch
.
randint
(
0
,
2
,
(
batch_size
,
n_atoms
)).
float
()
length_scale
=
10
...
...
@@ -686,11 +694,11 @@ class TestLoss(unittest.TestCase):
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
value
=
tree_map
(
to_tensor
,
value
,
np
.
ndarray
)
batch
[
"backbone_
affine
_tensor"
]
=
affine_vector_to_4x4
(
batch
[
"backbone_
rigid
_tensor"
]
=
affine_vector_to_4x4
(
batch
[
"backbone_affine_tensor"
]
)
value
[
"traj"
]
=
affine_vector_to_4x4
(
value
[
"traj
"
]
)
batch
[
"backbone_rigid_mask"
]
=
batch
[
"backbone_affine_mask
"
]
out_repro
=
backbone_loss
(
traj
=
value
[
"traj"
],
**
{
**
batch
,
**
c_sm
})
out_repro
=
out_repro
.
cpu
()
...
...
@@ -807,6 +815,8 @@ class TestLoss(unittest.TestCase):
f
=
hk
.
transform
(
run_tm_loss
)
np
.
random
.
seed
(
42
)
n_res
=
consts
.
n_res
representations
=
{
...
...
@@ -839,12 +849,10 @@ class TestLoss(unittest.TestCase):
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
value
=
tree_map
(
to_tensor
,
value
,
np
.
ndarray
)
batch
[
"backbone_
affine
_tensor"
]
=
affine_vector_to_4x4
(
batch
[
"backbone_
rigid
_tensor"
]
=
affine_vector_to_4x4
(
batch
[
"backbone_affine_tensor"
]
)
value
[
"structure_module"
][
"final_affines"
]
=
affine_vector_to_4x4
(
value
[
"structure_module"
][
"final_affines"
]
)
batch
[
"backbone_rigid_mask"
]
=
batch
[
"backbone_affine_mask"
]
model
=
compare_utils
.
get_global_pretrained_openfold
()
logits
=
model
.
aux_heads
.
tm
(
representations
[
"pair"
])
...
...
tests/test_model.py
View file @
e3daf724
...
...
@@ -130,4 +130,5 @@ class TestModel(unittest.TestCase):
out_repro
=
out_repro
[
"sm"
][
"positions"
][
-
1
]
out_repro
=
out_repro
.
squeeze
(
0
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
1e-3
))
print
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)))
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
1e-3
)
tests/test_structure_module.py
View file @
e3daf724
...
...
@@ -31,8 +31,8 @@ from openfold.model.structure_module import (
AngleResnet
,
InvariantPointAttention
,
)
from
openfold.utils.affine_utils
import
T
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.data_utils
import
(
...
...
@@ -89,7 +89,7 @@ class TestStructureModule(unittest.TestCase):
out
=
sm
(
s
,
z
,
f
)
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
4
,
4
))
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
7
))
self
.
assertTrue
(
out
[
"angles"
].
shape
==
(
no_layers
,
batch_size
,
n
,
no_angles
,
2
)
)
...
...
@@ -177,23 +177,6 @@ class TestStructureModule(unittest.TestCase):
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
0.05
)
class
TestBackboneUpdate
(
unittest
.
TestCase
):
def
test_shape
(
self
):
batch_size
=
2
n_res
=
3
c_in
=
5
bu
=
BackboneUpdate
(
c_in
)
s
=
torch
.
rand
((
batch_size
,
n_res
,
c_in
))
t
=
bu
(
s
)
rot
,
tra
=
t
.
rots
,
t
.
trans
self
.
assertTrue
(
rot
.
shape
==
(
batch_size
,
n_res
,
3
,
3
))
self
.
assertTrue
(
tra
.
shape
==
(
batch_size
,
n_res
,
3
))
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
def
test_shape
(
self
):
c_m
=
13
...
...
@@ -210,17 +193,18 @@ class TestInvariantPointAttention(unittest.TestCase):
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
))
mask
=
torch
.
ones
((
batch_size
,
n_res
))
rots
=
torch
.
rand
((
batch_size
,
n_res
,
3
,
3
))
rot_mats
=
torch
.
rand
((
batch_size
,
n_res
,
3
,
3
))
rots
=
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
trans
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
t
=
T
(
rots
,
trans
)
r
=
Rigid
(
rots
,
trans
)
ipa
=
InvariantPointAttention
(
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
)
shape_before
=
s
.
shape
s
=
ipa
(
s
,
z
,
t
,
mask
)
s
=
ipa
(
s
,
z
,
r
,
mask
)
self
.
assertTrue
(
s
.
shape
==
shape_before
)
...
...
@@ -253,7 +237,9 @@ class TestInvariantPointAttention(unittest.TestCase):
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
quats
=
alphafold
.
model
.
r3
.
rigids_to_quataffine
(
rigids
)
transformations
=
T
.
from_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
())
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
quats
...
...
tests/test_utils.py
View file @
e3daf724
...
...
@@ -13,11 +13,24 @@
# limitations under the License.
import
math
import
numpy
as
np
import
torch
import
unittest
from
openfold.utils.affine_utils
import
T
,
quat_to_rot
from
openfold.utils.rigid_utils
import
(
Rotation
,
Rigid
,
quat_to_rot
,
rot_to_quat
,
)
from
openfold.utils.tensor_utils
import
chunk_layer
,
_chunk_slice
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
if
compare_utils
.
alphafold_is_installed
():
alphafold
=
compare_utils
.
import_alphafold
()
import
jax
import
haiku
as
hk
X_90_ROT
=
torch
.
tensor
(
...
...
@@ -38,7 +51,7 @@ X_NEG_90_ROT = torch.tensor(
class
TestUtils
(
unittest
.
TestCase
):
def
test_
T
_from_3_points_shape
(
self
):
def
test_
rigid
_from_3_points_shape
(
self
):
batch_size
=
2
n_res
=
5
...
...
@@ -46,14 +59,14 @@ class TestUtils(unittest.TestCase):
x2
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
x3
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
t
=
T
.
from_3_points
(
x1
,
x2
,
x3
)
r
=
Rigid
.
from_3_points
(
x1
,
x2
,
x3
)
rot
,
tra
=
t
.
rots
,
t
.
trans
rot
,
tra
=
r
.
get_rots
().
get_rot_mats
(),
r
.
get_
trans
()
self
.
assertTrue
(
rot
.
shape
==
(
batch_size
,
n_res
,
3
,
3
))
self
.
assertTrue
(
torch
.
all
(
tra
==
x2
))
def
test_
T
_from_4x4
(
self
):
def
test_
rigid
_from_4x4
(
self
):
batch_size
=
2
transf
=
[
[
1
,
0
,
0
,
1
],
...
...
@@ -68,58 +81,79 @@ class TestUtils(unittest.TestCase):
transf
=
torch
.
stack
([
transf
for
_
in
range
(
batch_size
)],
dim
=
0
)
t
=
T
.
from
_4x4
(
transf
)
r
=
Rigid
.
from_tensor
_4x4
(
transf
)
rot
,
tra
=
t
.
rots
,
t
.
trans
rot
,
tra
=
r
.
get_rots
().
get_rot_mats
(),
r
.
get_
trans
()
self
.
assertTrue
(
torch
.
all
(
rot
==
true_rot
.
unsqueeze
(
0
)))
self
.
assertTrue
(
torch
.
all
(
tra
==
true_trans
.
unsqueeze
(
0
)))
def
test_
T
_shape
(
self
):
def
test_
rigid
_shape
(
self
):
batch_size
=
2
n
=
5
transf
=
T
(
torch
.
rand
((
batch_size
,
n
,
3
,
3
)),
torch
.
rand
((
batch_size
,
n
,
3
))
transf
=
Rigid
(
Rotation
(
rot_mats
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))),
torch
.
rand
((
batch_size
,
n
,
3
))
)
self
.
assertTrue
(
transf
.
shape
==
(
batch_size
,
n
))
def
test_
T_con
cat
(
self
):
def
test_
rigid_
cat
(
self
):
batch_size
=
2
n
=
5
transf
=
T
(
torch
.
rand
((
batch_size
,
n
,
3
,
3
)),
torch
.
rand
((
batch_size
,
n
,
3
))
transf
=
Rigid
(
Rotation
(
rot_mats
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))),
torch
.
rand
((
batch_size
,
n
,
3
))
)
transf_
con
cat
=
T
.
con
cat
([
transf
,
transf
],
dim
=
0
)
transf_cat
=
Rigid
.
cat
([
transf
,
transf
],
dim
=
0
)
self
.
assertTrue
(
transf_concat
.
rots
.
shape
==
(
batch_size
*
2
,
n
,
3
,
3
))
transf_rots
=
transf
.
get_rots
().
get_rot_mats
()
transf_cat_rots
=
transf_cat
.
get_rots
().
get_rot_mats
()
transf_concat
=
T
.
concat
([
transf
,
transf
],
dim
=
1
)
self
.
assertTrue
(
transf_cat_rots
.
shape
==
(
batch_size
*
2
,
n
,
3
,
3
)
)
self
.
assertTrue
(
transf_concat
.
rots
.
shape
==
(
batch_size
,
n
*
2
,
3
,
3
))
transf_cat
=
Rigid
.
cat
([
transf
,
transf
],
dim
=
1
)
transf_cat_rots
=
transf_cat
.
get_rots
().
get_rot_mats
()
self
.
assertTrue
(
torch
.
all
(
transf_concat
.
rots
[:,
:
n
]
==
transf
.
rots
))
self
.
assertTrue
(
torch
.
all
(
transf_concat
.
trans
[:,
:
n
]
==
transf
.
trans
))
self
.
assertTrue
(
transf_cat_rots
.
shape
==
(
batch_size
,
n
*
2
,
3
,
3
))
def
test_T_compose
(
self
):
self
.
assertTrue
(
torch
.
all
(
transf_cat_rots
[:,
:
n
]
==
transf_rots
))
self
.
assertTrue
(
torch
.
all
(
transf_cat
.
get_trans
()[:,
:
n
]
==
transf
.
get_trans
())
)
def
test_rigid_compose
(
self
):
trans_1
=
[
0
,
1
,
0
]
trans_2
=
[
0
,
0
,
1
]
t1
=
T
(
X_90_ROT
,
torch
.
tensor
(
trans_1
))
t2
=
T
(
X_NEG_90_ROT
,
torch
.
tensor
(
trans_2
))
r
=
Rotation
(
rot_mats
=
X_90_ROT
)
t
=
torch
.
tensor
(
trans_1
)
t1
=
Rigid
(
Rotation
(
rot_mats
=
X_90_ROT
),
torch
.
tensor
(
trans_1
)
)
t2
=
Rigid
(
Rotation
(
rot_mats
=
X_NEG_90_ROT
),
torch
.
tensor
(
trans_2
)
)
t3
=
t1
.
compose
(
t2
)
self
.
assertTrue
(
torch
.
all
(
t3
.
rots
==
torch
.
eye
(
3
)))
self
.
assertTrue
(
torch
.
all
(
t3
.
trans
==
0
))
self
.
assertTrue
(
torch
.
all
(
t3
.
get_rots
().
get_rot_mats
()
==
torch
.
eye
(
3
))
)
self
.
assertTrue
(
torch
.
all
(
t3
.
get_trans
()
==
0
)
)
def
test_
T
_apply
(
self
):
def
test_
rigid
_apply
(
self
):
rots
=
torch
.
stack
([
X_90_ROT
,
X_NEG_90_ROT
],
dim
=
0
)
trans
=
torch
.
tensor
([
1
,
1
,
1
])
trans
=
torch
.
stack
([
trans
,
trans
],
dim
=
0
)
t
=
T
(
rots
,
trans
)
t
=
Rigid
(
Rotation
(
rot_mats
=
rots
)
,
trans
)
x
=
torch
.
arange
(
30
)
x
=
torch
.
stack
([
x
,
x
],
dim
=
0
)
...
...
@@ -141,6 +175,12 @@ class TestUtils(unittest.TestCase):
eps
=
1e-07
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
rot
-
X_90_ROT
)
<
eps
))
def
test_rot_to_quat
(
self
):
quat
=
rot_to_quat
(
X_90_ROT
)
eps
=
1e-07
ans
=
torch
.
tensor
([
math
.
sqrt
(
0.5
),
math
.
sqrt
(
0.5
),
0.
,
0.
])
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
quat
-
ans
)
<
eps
))
def
test_chunk_layer_tensor
(
self
):
x
=
torch
.
rand
(
2
,
4
,
5
,
15
)
l
=
torch
.
nn
.
Linear
(
15
,
30
)
...
...
@@ -180,3 +220,33 @@ class TestUtils(unittest.TestCase):
chunked_flattened
=
x_flat
[
i
:
j
]
self
.
assertTrue
(
torch
.
all
(
chunked
==
chunked_flattened
))
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_pre_compose_compare
(
self
):
quat
=
np
.
random
.
rand
(
20
,
4
)
trans
=
[
np
.
random
.
rand
(
20
)
for
_
in
range
(
3
)]
quat_affine
=
alphafold
.
model
.
quat_affine
.
QuatAffine
(
quat
,
translation
=
trans
)
update_vec
=
np
.
random
.
rand
(
20
,
6
)
new_gt
=
quat_affine
.
pre_compose
(
update_vec
)
quat_t
=
torch
.
tensor
(
quat
)
trans_t
=
torch
.
stack
([
torch
.
tensor
(
t
)
for
t
in
trans
],
dim
=-
1
)
rigid
=
Rigid
(
Rotation
(
quats
=
quat_t
),
trans_t
)
new_repro
=
rigid
.
compose_q_update_vec
(
torch
.
tensor
(
update_vec
))
new_gt_q
=
torch
.
tensor
(
np
.
array
(
new_gt
.
quaternion
))
new_gt_t
=
torch
.
stack
(
[
torch
.
tensor
(
np
.
array
(
t
))
for
t
in
new_gt
.
translation
],
dim
=-
1
)
new_repro_q
=
new_repro
.
get_rots
().
get_quats
()
new_repro_t
=
new_repro
.
get_trans
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
new_gt_q
-
new_repro_q
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
new_gt_t
-
new_repro_t
))
<
consts
.
eps
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment