OpenDAS / OpenFold · Commits

Commit e3daf724, authored Jan 03, 2022 by Gustaf Ahdritz
Parent: 1f709b0d

    Overhaul transformation code for better parity w/ AlphaFold

Showing 15 changed files with 1669 additions and 178 deletions (+1669, -178)
openfold/config.py                                          +4     -4
openfold/data/data_transforms.py                            +18    -12
openfold/data/mmcif_parsing.py                              +4     -2
openfold/data/templates.py                                  +14    -3
openfold/model/embedders.py                                 +1     -0
openfold/model/structure_module.py                          +59    -51
openfold/utils/feats.py                                     +18    -18
openfold/utils/loss.py                                      +35    -22
openfold/utils/rigid_utils.py                               +1380  -0
tests/test_data/alphafold/common/stereo_chemical_props.txt  +1     -0
tests/test_feats.py                                         +10    -6
tests/test_loss.py                                          +18    -10
tests/test_model.py                                         +2     -1
tests/test_structure_module.py                              +9     -23
tests/test_utils.py                                         +96    -26
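Before the per-file diffs, a brief orientation: the commit retires the monolithic affine class T in favor of a Rotation object (matrix- or quaternion-backed) wrapped by a Rigid transformation, and renames the serializers accordingly. A minimal sketch of the new call sites, assuming a post-commit OpenFold checkout:

    # Orientation sketch (assumes a post-commit OpenFold checkout): the
    # monolithic affine class T is replaced by a Rotation wrapper plus a
    # Rigid transformation that holds one.
    import torch
    from openfold.utils.rigid_utils import Rotation, Rigid

    # Before: T(rots, trans) with raw [*, 3, 3] / [*, 3] tensors.
    # After: the rotation is its own object, in matrix or quaternion form.
    rot = Rotation(rot_mats=torch.eye(3).expand(10, 3, 3), quats=None)
    rigid = Rigid(rot, torch.zeros(10, 3))

    # Renamed serializers: T.from_4x4/to_4x4 become
    # Rigid.from_tensor_4x4/to_tensor_4x4, and a quaternion-based
    # 7-column encoding (to_tensor_7/from_tensor_7) is added.
    t7 = rigid.to_tensor_7()      # [10, 7]: 4 quaternion + 3 translation
    t44 = rigid.to_tensor_4x4()   # [10, 4, 4] homogenous matrices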
openfold/config.py

@@ -89,8 +89,8 @@ config = mlc.ConfigDict(
             "atom14_gt_exists": [NUM_RES, None],
             "atom14_gt_positions": [NUM_RES, None, None],
             "atom37_atom_exists": [NUM_RES, None],
-            "backbone_affine_mask": [NUM_RES],
-            "backbone_affine_tensor": [NUM_RES, None, None],
+            "backbone_rigid_mask": [NUM_RES],
+            "backbone_rigid_tensor": [NUM_RES, None, None],
             "bert_mask": [NUM_MSA_SEQ, NUM_RES],
             "chi_angles_sin_cos": [NUM_RES, None, None],
             "chi_mask": [NUM_RES, None],
@@ -126,8 +126,8 @@ config = mlc.ConfigDict(
             "template_alt_torsion_angles_sin_cos": [
                 NUM_TEMPLATES, NUM_RES, None, None,
             ],
-            "template_backbone_affine_mask": [NUM_TEMPLATES, NUM_RES],
-            "template_backbone_affine_tensor": [
+            "template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
+            "template_backbone_rigid_tensor": [
                 NUM_TEMPLATES, NUM_RES, None, None,
             ],
             "template_mask": [NUM_TEMPLATES],
openfold/data/data_transforms.py

@@ -22,7 +22,7 @@ import torch
 from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
 from openfold.np import residue_constants as rc
-from openfold.utils.affine_utils import T
+from openfold.utils.rigid_utils import Rotation, Rigid
 from openfold.utils.tensor_utils import (
     tree_map,
     tensor_tree_map,
@@ -752,7 +752,7 @@ def make_atom14_positions(protein):
     return protein


-def atom37_to_frames(protein):
+def atom37_to_frames(protein, eps=1e-8):
     aatype = protein["aatype"]
     all_atom_positions = protein["all_atom_positions"]
     all_atom_mask = protein["all_atom_mask"]
@@ -810,11 +810,11 @@ def atom37_to_frames(protein):
         no_batch_dims=len(all_atom_positions.shape[:-2]),
     )

-    gt_frames = T.from_3_points(
+    gt_frames = Rigid.from_3_points(
         p_neg_x_axis=base_atom_pos[..., 0, :],
         origin=base_atom_pos[..., 1, :],
         p_xy_plane=base_atom_pos[..., 2, :],
-        eps=1e-8,
+        eps=eps,
     )

     group_exists = batched_gather(
@@ -836,8 +836,9 @@ def atom37_to_frames(protein):
     rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
     rots[..., 0, 0, 0] = -1
     rots[..., 0, 2, 2] = -1
+    rots = Rotation(rot_mats=rots)

-    gt_frames = gt_frames.compose(T(rots, None))
+    gt_frames = gt_frames.compose(Rigid(rots, None))

     restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
         *((1,) * batch_dims), 21, 8
@@ -871,10 +872,15 @@ def atom37_to_frames(protein):
         no_batch_dims=batch_dims,
     )

-    alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None))
+    residx_rigidgroup_ambiguity_rot = Rotation(
+        rot_mats=residx_rigidgroup_ambiguity_rot
+    )
+    alt_gt_frames = gt_frames.compose(
+        Rigid(residx_rigidgroup_ambiguity_rot, None)
+    )

-    gt_frames_tensor = gt_frames.to_4x4()
-    alt_gt_frames_tensor = alt_gt_frames.to_4x4()
+    gt_frames_tensor = gt_frames.to_tensor_4x4()
+    alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()

     protein["rigidgroups_gt_frames"] = gt_frames_tensor
     protein["rigidgroups_gt_exists"] = gt_exists
@@ -1028,7 +1034,7 @@ def atom37_to_torsion_angles(
         dim=-1,
     )

-    torsion_frames = T.from_3_points(
+    torsion_frames = Rigid.from_3_points(
         torsions_atom_pos[..., 1, :],
         torsions_atom_pos[..., 2, :],
         torsions_atom_pos[..., 0, :],
@@ -1082,11 +1088,11 @@ def atom37_to_torsion_angles(
 def get_backbone_frames(protein):
-    # TODO: Verify that this is correct
-    protein["backbone_affine_tensor"] = protein["rigidgroups_gt_frames"][
+    # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
+    protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
         ..., 0, :, :
     ]
-    protein["backbone_affine_mask"] = protein["rigidgroups_gt_exists"][..., 0]
+    protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]

     return protein
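The flip applied to the backbone group above sets rots[..., 0, 0, 0] = -1 and rots[..., 0, 2, 2] = -1, i.e. it negates the x- and z-axes of the local frame. A standalone check that this is a proper rotation rather than a reflection:

    import torch

    # The backbone-group flip used in atom37_to_frames: diag(-1, 1, -1)
    # is a 180-degree rotation about the y-axis, so it has det = +1 and
    # is a proper rotation, not a reflection.
    flip = torch.eye(3)
    flip[0, 0] = -1
    flip[2, 2] = -1

    print(torch.det(flip))                  # tensor(1.)
    print(flip @ flip.transpose(-1, -2))    # identity: flip is orthogonal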
openfold/data/mmcif_parsing.py

@@ -430,7 +430,9 @@ def _is_set(data: str) -> bool:
 def get_atom_coords(
-    mmcif_object: MmcifObject, chain_id: str, zero_center: bool = True
+    mmcif_object: MmcifObject,
+    chain_id: str,
+    _zero_center_positions: bool = True
 ) -> Tuple[np.ndarray, np.ndarray]:
     # Locate the right chain
     chains = list(mmcif_object.structure.get_chains())
@@ -475,7 +477,7 @@ def get_atom_coords(
         all_atom_positions[res_index] = pos
         all_atom_mask[res_index] = mask

-    if zero_center:
+    if _zero_center_positions:
         binary_mask = all_atom_mask.astype(bool)
         translation_vec = all_atom_positions[binary_mask].mean(axis=0)
         all_atom_positions[binary_mask] -= translation_vec
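The renamed _zero_center_positions flag guards the centering step shown above. A standalone numpy sketch of that step, mirroring the diff's own code:

    import numpy as np

    # Sketch of what _zero_center_positions does in get_atom_coords:
    # subtract the mean of the *observed* atoms (mask == 1) from every
    # observed position. Shapes follow the mmCIF parser: [N_res, 37, 3]
    # positions and a [N_res, 37] mask.
    all_atom_positions = np.random.rand(10, 37, 3)
    all_atom_mask = (np.random.rand(10, 37) > 0.5).astype(np.float32)

    binary_mask = all_atom_mask.astype(bool)
    translation_vec = all_atom_positions[binary_mask].mean(axis=0)
    all_atom_positions[binary_mask] -= translation_vec

    # The centroid of the observed atoms is now (approximately) the origin
    assert np.allclose(all_atom_positions[binary_mask].mean(axis=0), 0.0)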
openfold/data/templates.py

@@ -503,10 +503,13 @@ def _get_atom_positions(
     mmcif_object: mmcif_parsing.MmcifObject,
     auth_chain_id: str,
     max_ca_ca_distance: float,
+    _zero_center_positions: bool = True,
 ) -> Tuple[np.ndarray, np.ndarray]:
     """Gets atom positions and mask from a list of Biopython Residues."""
     coords_with_mask = mmcif_parsing.get_atom_coords(
-        mmcif_object=mmcif_object, chain_id=auth_chain_id
+        mmcif_object=mmcif_object,
+        chain_id=auth_chain_id,
+        _zero_center_positions=_zero_center_positions,
     )
     all_atom_positions, all_atom_mask = coords_with_mask
     _check_residue_distances(
@@ -523,6 +526,7 @@ def _extract_template_features(
     query_sequence: str,
     template_chain_id: str,
     kalign_binary_path: str,
+    _zero_center_positions: bool = True,
 ) -> Tuple[Dict[str, Any], Optional[str]]:
     """Parses atom positions in the target structure and aligns with the query.
@@ -607,7 +611,10 @@ def _extract_template_features(
         # Essentially set to infinity - we don't want to reject templates unless
         # they're really really bad.
         all_atom_positions, all_atom_mask = _get_atom_positions(
-            mmcif_object, chain_id, max_ca_ca_distance=150.0
+            mmcif_object,
+            chain_id,
+            max_ca_ca_distance=150.0,
+            _zero_center_positions=_zero_center_positions,
         )
     except (CaDistanceError, KeyError) as ex:
         raise NoAtomDataInTemplateError(
@@ -795,6 +802,7 @@ def _process_single_hit(
     obsolete_pdbs: Mapping[str, str],
     kalign_binary_path: str,
     strict_error_check: bool = False,
+    _zero_center_positions: bool = True,
 ) -> SingleHitResult:
     """Tries to extract template features from a single HHSearch hit."""
     # Fail hard if we can't get the PDB ID and chain name from the hit.
@@ -856,6 +864,7 @@ def _process_single_hit(
             query_sequence=query_sequence,
             template_chain_id=hit_chain_id,
             kalign_binary_path=kalign_binary_path,
+            _zero_center_positions=_zero_center_positions,
         )
         features["template_sum_probs"] = [hit.sum_probs]
@@ -913,7 +922,6 @@ class TemplateSearchResult:
 class TemplateHitFeaturizer:
     """A class for turning hhr hits to template features."""
-
     def __init__(
         self,
         mmcif_dir: str,
@@ -924,6 +932,7 @@ class TemplateHitFeaturizer:
         obsolete_pdbs_path: Optional[str] = None,
         strict_error_check: bool = False,
         _shuffle_top_k_prefiltered: Optional[int] = None,
+        _zero_center_positions: bool = True,
     ):
         """Initializes the Template Search.
@@ -982,6 +991,7 @@ class TemplateHitFeaturizer:
             self._obsolete_pdbs = {}

         self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
+        self._zero_center_positions = _zero_center_positions

     def get_templates(
         self,
@@ -1057,6 +1067,7 @@ class TemplateHitFeaturizer:
                 obsolete_pdbs=self._obsolete_pdbs,
                 strict_error_check=self._strict_error_check,
                 kalign_binary_path=self._kalign_binary_path,
+                _zero_center_positions=self._zero_center_positions,
             )

             if result.error:
openfold/model/embedders.py

@@ -198,6 +198,7 @@ class RecyclingEmbedder(nn.Module):
             self.no_bins,
             dtype=x.dtype,
             device=x.device,
+            requires_grad=False,
         )

         # [*, N, C_m]
openfold/model/structure_module.py

@@ -25,11 +25,11 @@ from openfold.np.residue_constants import (
     restype_atom14_mask,
     restype_atom14_rigid_group_positions,
 )
-from openfold.utils.affine_utils import T, quat_to_rot
 from openfold.utils.feats import (
     frames_and_literature_positions_to_atom14_pos,
     torsion_angles_to_frames,
 )
+from openfold.utils.rigid_utils import Rotation, Rigid
 from openfold.utils.tensor_utils import (
     dict_multimap,
     permute_final_dims,
@@ -225,7 +225,7 @@ class InvariantPointAttention(nn.Module):
         self,
         s: torch.Tensor,
         z: torch.Tensor,
-        t: T,
+        r: Rigid,
         mask: torch.Tensor,
     ) -> torch.Tensor:
         """
@@ -234,8 +234,8 @@ class InvariantPointAttention(nn.Module):
                 [*, N_res, C_s] single representation
             z:
                 [*, N_res, N_res, C_z] pair representation
-            t:
-                [*, N_res] affine transformation object
+            r:
+                [*, N_res] transformation object
             mask:
                 [*, N_res] mask
         Returns:
@@ -264,7 +264,7 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * P_q, 3]
         q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
         q_pts = torch.stack(q_pts, dim=-1)
-        q_pts = t[..., None].apply(q_pts)
+        q_pts = r[..., None].apply(q_pts)

         # [*, N_res, H, P_q, 3]
         q_pts = q_pts.view(
@@ -277,7 +277,7 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * (P_q + P_v), 3]
         kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
         kv_pts = torch.stack(kv_pts, dim=-1)
-        kv_pts = t[..., None].apply(kv_pts)
+        kv_pts = r[..., None].apply(kv_pts)

         # [*, N_res, H, (P_q + P_v), 3]
         kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
@@ -349,7 +349,7 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H, P_v, 3]
         o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
-        o_pt = t[..., None, None].invert_apply(o_pt)
+        o_pt = r[..., None, None].invert_apply(o_pt)

         # [*, N_res, H * P_v]
         o_pt_norm = flatten_final_dims(
@@ -377,7 +377,7 @@ class InvariantPointAttention(nn.Module):
 class BackboneUpdate(nn.Module):
     """
-    Implements Algorithm 23.
+    Implements part of Algorithm 23.
     """

     def __init__(self, c_s):
@@ -392,36 +392,17 @@ class BackboneUpdate(nn.Module):
         self.linear = Linear(self.c_s, 6, init="final")

-    def forward(self, s):
+    def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             [*, N_res, C_s] single representation
         Returns:
-            [*, N_res] affine transformation object
+            [*, N_res, 6] update vector
         """
         # [*, 6]
-        params = self.linear(s)
-
-        # [*, 3]
-        quats, trans = params[..., :3], params[..., 3:]
-
-        # [*]
-        # norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
-        norm_denom = torch.sqrt(torch.sum(quats ** 2, dim=-1) + 1)
-
-        # [*, 3]
-        ones = s.new_ones((1,) * len(quats.shape)).expand(
-            quats.shape[:-1] + (1,)
-        )
-
-        # [*, 4]
-        quats = torch.cat([ones, quats], dim=-1)
-        quats = quats / norm_denom[..., None]
-
-        # [*, 3, 3]
-        rots = quat_to_rot(quats)
-
-        return T(rots, trans)
+        update = self.linear(s)
+
+        return update

 class StructureModuleTransitionLayer(nn.Module):
@@ -592,7 +573,7 @@ class StructureModule(nn.Module):
         self,
         s,
         z,
-        f,
+        aatype,
         mask=None,
     ):
         """
@@ -601,7 +582,7 @@ class StructureModule(nn.Module):
                 [*, N_res, C_s] single representation
             z:
                 [*, N_res, N_res, C_z] pair representation
-            f:
+            aatype:
                 [*, N_res] amino acid indices
             mask:
                 Optional [*, N_res] sequence mask
@@ -623,44 +604,67 @@ class StructureModule(nn.Module):
         s = self.linear_in(s)

         # [*, N]
-        t = T.identity(s.shape[:-1], s.dtype, s.device, self.training)
+        rigids = Rigid.identity(
+            s.shape[:-1],
+            s.dtype,
+            s.device,
+            self.training,
+            fmt="quat",
+        )

         outputs = []
         for i in range(self.no_blocks):
             # [*, N, C_s]
-            s = s + self.ipa(s, z, t, mask)
+            s = s + self.ipa(s, z, rigids, mask)
             s = self.ipa_dropout(s)
             s = self.layer_norm_ipa(s)
             s = self.transition(s)

             # [*, N]
-            t = t.compose(self.bb_update(s))
+            rigids = rigids.compose_q_update_vec(self.bb_update(s))
+
+            # To hew as closely as possible to AlphaFold, we convert our
+            # quaternion-based transformations to rotation-matrix ones
+            # here
+            backb_to_global = Rigid(
+                Rotation(
+                    rot_mats=rigids.get_rots().get_rot_mats(),
+                    quats=None
+                ),
+                rigids.get_trans(),
+            )
+
+            backb_to_global = backb_to_global.scale_translation(
+                self.trans_scale_factor
+            )

             # [*, N, 7, 2]
-            unnormalized_a, a = self.angle_resnet(s, s_initial)
+            unnormalized_angles, angles = self.angle_resnet(s, s_initial)

             all_frames_to_global = self.torsion_angles_to_frames(
-                t.scale_translation(self.trans_scale_factor),
-                a,
-                f,
+                backb_to_global,
+                angles,
+                aatype,
             )

             pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
                 all_frames_to_global,
-                f,
+                aatype,
             )

+            scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
+
             preds = {
-                "frames": t.scale_translation(self.trans_scale_factor).to_4x4(),
-                "sidechain_frames": all_frames_to_global.to_4x4(),
-                "unnormalized_angles": unnormalized_a,
-                "angles": a,
+                "frames": scaled_rigids.to_tensor_7(),
+                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
+                "unnormalized_angles": unnormalized_angles,
+                "angles": angles,
                 "positions": pred_xyz,
             }

             outputs.append(preds)

             if i < (self.no_blocks - 1):
-                t = t.stop_rot_gradient()
+                rigids = rigids.stop_rot_gradient()

         outputs = dict_multimap(torch.stack, outputs)
         outputs["single"] = s
@@ -673,38 +677,42 @@ class StructureModule(nn.Module):
                 restype_rigid_group_default_frame,
                 dtype=float_dtype,
                 device=device,
                 requires_grad=False,
             )
         if self.group_idx is None:
             self.group_idx = torch.tensor(
                 restype_atom14_to_rigid_group,
                 device=device,
                 requires_grad=False,
             )
         if self.atom_mask is None:
             self.atom_mask = torch.tensor(
                 restype_atom14_mask,
                 dtype=float_dtype,
                 device=device,
                 requires_grad=False,
             )
         if self.lit_positions is None:
             self.lit_positions = torch.tensor(
                 restype_atom14_rigid_group_positions,
                 dtype=float_dtype,
                 device=device,
                 requires_grad=False,
             )

-    def torsion_angles_to_frames(self, t, alpha, f):
+    def torsion_angles_to_frames(self, r, alpha, f):
         # Lazily initialize the residue constants on the correct device
         self._init_residue_constants(alpha.dtype, alpha.device)
         # Separated purely to make testing less annoying
-        return torsion_angles_to_frames(t, alpha, f, self.default_frames)
+        return torsion_angles_to_frames(r, alpha, f, self.default_frames)

     def frames_and_literature_positions_to_atom14_pos(
-        self, t, f  # [*, N, 8]  # [*, N]
+        self, r, f  # [*, N, 8]  # [*, N]
     ):
         # Lazily initialize the residue constants on the correct device
-        self._init_residue_constants(t.rots.dtype, t.rots.device)
+        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
         return frames_and_literature_positions_to_atom14_pos(
-            t,
+            r,
             f,
             self.default_frames,
             self.group_idx,
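BackboneUpdate now returns its raw [*, 6] linear output, and the quaternion math that used to live in its forward() moved into Rigid.compose_q_update_vec (defined in rigid_utils.py below). A standalone sketch of what the update vector encodes:

    import torch

    # Sketch of the [*, 6] update vector BackboneUpdate now returns. The
    # first three columns are the (x, y, z) of a non-unit quaternion
    # (1, x, y, z); the last three are a translation. Inside
    # Rigid.compose_q_update_vec the update is composed with the running
    # rotation by quaternion multiplication and renormalized.
    update = torch.randn(8, 6)               # stand-in for bb_update(s)
    q_vec, t_vec = update[..., :3], update[..., 3:]

    ones = q_vec.new_ones(q_vec.shape[:-1] + (1,))
    quat = torch.cat([ones, q_vec], dim=-1)  # [*, 4], form (1, x, y, z)
    quat = quat / torch.linalg.norm(quat, dim=-1, keepdim=True)

    assert torch.allclose(torch.linalg.norm(quat, dim=-1), torch.ones(8))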
openfold/utils/feats.py

@@ -22,7 +22,7 @@ from typing import Dict
 from openfold.np import protein
 import openfold.np.residue_constants as rc
-from openfold.utils.affine_utils import T
+from openfold.utils.rigid_utils import Rotation, Rigid
 from openfold.utils.tensor_utils import (
     batched_gather,
     one_hot,
@@ -124,18 +124,16 @@ def build_template_pair_feat(
     )

     n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]

     # TODO: Consider running this in double precision
-    affines = T.make_transform_from_reference(
+    rigids = Rigid.make_transform_from_reference(
         n_xyz=batch["template_all_atom_positions"][..., n, :],
         ca_xyz=batch["template_all_atom_positions"][..., ca, :],
         c_xyz=batch["template_all_atom_positions"][..., c, :],
         eps=eps,
     )
-    points = affines.get_trans()[..., None, :, :]
-    affine_vec = affines[..., None].invert_apply(points)
+    points = rigids.get_trans()[..., None, :, :]
+    rigid_vec = rigids[..., None].invert_apply(points)

-    inv_distance_scalar = torch.rsqrt(eps + torch.sum(affine_vec ** 2, dim=-1))
+    inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))

     t_aa_masks = batch["template_all_atom_mask"]
     template_mask = (
@@ -144,7 +142,7 @@ def build_template_pair_feat(
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

     inv_distance_scalar = inv_distance_scalar * template_mask_2d
-    unit_vector = affine_vec * inv_distance_scalar[..., None]
+    unit_vector = rigid_vec * inv_distance_scalar[..., None]
     to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
     to_concat.append(template_mask_2d[..., None])
@@ -165,7 +163,7 @@ def build_extra_msa_feat(batch):
 def torsion_angles_to_frames(
-    t: T,
+    r: Rigid,
     alpha: torch.Tensor,
     aatype: torch.Tensor,
     rrgdf: torch.Tensor,
@@ -176,13 +174,15 @@ def torsion_angles_to_frames(
     # [*, N, 8] transformations, i.e.
     #   One [*, N, 8, 3, 3] rotation matrix and
     #   One [*, N, 8, 3]    translation matrix
-    default_t = T.from_4x4(default_4x4)
+    default_r = r.from_tensor_4x4(default_4x4)

     bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
     bb_rot[..., 1] = 1

     # [*, N, 8, 2]
-    alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
+    alpha = torch.cat(
+        [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
+    )

     # [*, N, 8, 3, 3]
     # Produces rotation matrices of the form:
@@ -194,15 +194,15 @@ def torsion_angles_to_frames(
     # This follows the original code rather than the supplement, which uses
     # different indices.

-    all_rots = alpha.new_zeros(default_t.rots.shape)
+    all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
     all_rots[..., 0, 0] = 1
     all_rots[..., 1, 1] = alpha[..., 1]
     all_rots[..., 1, 2] = -alpha[..., 0]
     all_rots[..., 2, 1:] = alpha

-    all_rots = T(all_rots, None)
+    all_rots = Rigid(Rotation(rot_mats=all_rots), None)

-    all_frames = default_t.compose(all_rots)
+    all_frames = default_r.compose(all_rots)

     chi2_frame_to_frame = all_frames[..., 5]
     chi3_frame_to_frame = all_frames[..., 6]
@@ -213,7 +213,7 @@ def torsion_angles_to_frames(
     chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
     chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

-    all_frames_to_bb = T.concat(
+    all_frames_to_bb = Rigid.cat(
         [
             all_frames[..., :5],
             chi2_frame_to_bb.unsqueeze(-1),
@@ -223,13 +223,13 @@ def torsion_angles_to_frames(
         dim=-1,
     )

-    all_frames_to_global = t[..., None].compose(all_frames_to_bb)
+    all_frames_to_global = r[..., None].compose(all_frames_to_bb)

     return all_frames_to_global


 def frames_and_literature_positions_to_atom14_pos(
-    t: T,
+    r: Rigid,
     aatype: torch.Tensor,
     default_frames,
     group_idx,
@@ -249,7 +249,7 @@ def frames_and_literature_positions_to_atom14_pos(
     )

     # [*, N, 14, 8]
-    t_atoms_to_global = t[..., None, :] * group_mask
+    t_atoms_to_global = r[..., None, :] * group_mask

     # [*, N, 14]
     t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
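In frames_and_literature_positions_to_atom14_pos, r[..., None, :] * group_mask uses the new pointwise Rigid.__mul__ to select, for each of the 14 atoms, the frame of its rigid group via a one-hot mask. A standalone sketch with raw translation tensors standing in for the Rigid objects:

    import torch

    # One-hot frame selection: each of the 14 atoms belongs to one of 8
    # rigid groups; multiplying per-group translations by a one-hot
    # [N, 14, 8] mask and summing over the group dimension picks out each
    # atom's frame. (The real code does this on Rigid objects via
    # __mul__ and map_tensor_fn.)
    N = 5
    trans = torch.randn(N, 8, 3)               # per-group translations
    group_idx = torch.randint(0, 8, (N, 14))   # rigid group of each atom
    group_mask = torch.nn.functional.one_hot(group_idx, num_classes=8)

    # [N, 1, 8, 3] * [N, 14, 8, 1] -> sum over groups -> [N, 14, 3]
    selected = torch.sum(
        trans[..., None, :, :] * group_mask[..., None], dim=-2
    )
    assert torch.allclose(selected[0, 0], trans[0, group_idx[0, 0]])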
openfold/utils/loss.py

@@ -24,7 +24,7 @@ from typing import Dict, Optional, Tuple
 from openfold.np import residue_constants
 from openfold.utils import feats
-from openfold.utils.affine_utils import T
+from openfold.utils.rigid_utils import Rotation, Rigid
 from openfold.utils.tensor_utils import (
     tree_map,
     tensor_tree_map,
@@ -74,8 +74,8 @@ def torsion_angle_loss(
 def compute_fape(
-    pred_frames: T,
-    target_frames: T,
+    pred_frames: Rigid,
+    target_frames: Rigid,
     frames_mask: torch.Tensor,
     pred_positions: torch.Tensor,
     target_positions: torch.Tensor,
@@ -111,7 +111,7 @@ def compute_fape(
     # )
     # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
     #
-    # ("roughly" because eps is necessarily duplicated in the latter
+    # ("roughly" because eps is necessarily duplicated in the latter)

     normed_error = torch.sum(normed_error, dim=-1)
     normed_error = (
         normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
@@ -123,8 +123,8 @@ def compute_fape(
 def backbone_loss(
-    backbone_affine_tensor: torch.Tensor,
-    backbone_affine_mask: torch.Tensor,
+    backbone_rigid_tensor: torch.Tensor,
+    backbone_rigid_mask: torch.Tensor,
     traj: torch.Tensor,
     use_clamped_fape: Optional[torch.Tensor] = None,
     clamp_distance: float = 10.0,
@@ -132,16 +132,27 @@ def backbone_loss(
     eps: float = 1e-4,
     **kwargs,
 ) -> torch.Tensor:
-    pred_aff = T.from_tensor(traj)
-    gt_aff = T.from_tensor(backbone_affine_tensor)
+    pred_aff = Rigid.from_tensor_7(traj)
+    pred_aff = Rigid(
+        Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
+        pred_aff.get_trans(),
+    )
+
+    # DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
+    # backbone tensor, normalizes it, and then turns it back to a rotation
+    # matrix. To avoid a potentially numerically unstable rotation matrix
+    # to quaternion conversion, we just use the original rotation matrix
+    # outright. This one hasn't been composed a bunch of times, though, so
+    # it might be fine.
+    gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)

     fape_loss = compute_fape(
         pred_aff,
         gt_aff[None],
-        backbone_affine_mask[None],
+        backbone_rigid_mask[None],
         pred_aff.get_trans(),
         gt_aff[None].get_trans(),
-        backbone_affine_mask[None],
+        backbone_rigid_mask[None],
         l1_clamp_distance=clamp_distance,
         length_scale=loss_unit_distance,
         eps=eps,
@@ -150,10 +161,10 @@ def backbone_loss(
         unclamped_fape_loss = compute_fape(
             pred_aff,
             gt_aff[None],
-            backbone_affine_mask[None],
+            backbone_rigid_mask[None],
             pred_aff.get_trans(),
             gt_aff[None].get_trans(),
-            backbone_affine_mask[None],
+            backbone_rigid_mask[None],
             l1_clamp_distance=None,
             length_scale=loss_unit_distance,
             eps=eps,
@@ -193,9 +204,9 @@ def sidechain_loss(
     sidechain_frames = sidechain_frames[-1]
     batch_dims = sidechain_frames.shape[:-4]
     sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
-    sidechain_frames = T.from_4x4(sidechain_frames)
+    sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
     renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
-    renamed_gt_frames = T.from_4x4(renamed_gt_frames)
+    renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
     rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
     sidechain_atom_pos = sidechain_atom_pos[-1]
     sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
@@ -550,8 +561,8 @@ def compute_tm(
 def tm_loss(
     logits,
     final_affine_tensor,
-    backbone_affine_tensor,
-    backbone_affine_mask,
+    backbone_rigid_tensor,
+    backbone_rigid_mask,
     resolution,
     max_bin=31,
     no_bins=64,
@@ -560,16 +571,17 @@ def tm_loss(
     eps=1e-8,
     **kwargs,
 ):
-    pred_affine = T.from_4x4(final_affine_tensor)
-    backbone_affine = T.from_4x4(backbone_affine_tensor)
+    pred_affine = Rigid.from_tensor_7(final_affine_tensor)
+    backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)

     def _points(affine):
         pts = affine.get_trans()[..., None, :, :]
         return affine.invert()[..., None].apply(pts)

     sq_diff = torch.sum(
-        (_points(pred_affine) - _points(backbone_affine)) ** 2, dim=-1
+        (_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
     )

     sq_diff = sq_diff.detach()

     boundaries = torch.linspace(
@@ -583,7 +595,7 @@ def tm_loss(
     )

     square_mask = (
-        backbone_affine_mask[..., None] * backbone_affine_mask[..., None, :]
+        backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
     )

     loss = torch.sum(errors * square_mask, dim=-1)
@@ -1503,11 +1515,12 @@ class AlphaFoldLoss(nn.Module):
             ),
         }

-        cum_loss = 0
+        cum_loss = 0.
        for loss_name, loss_fn in loss_fns.items():
            weight = self.config[loss_name].weight
            if weight:
                loss = loss_fn()
                if(torch.isnan(loss) or torch.isinf(loss)):
                    logging.warning(f"{loss_name} loss is NaN. Skipping...")
                    loss = loss.new_tensor(0., requires_grad=True)
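A standalone sketch of the non-finite-loss guard shown in the last hunk: a term that comes back NaN or Inf is replaced by a fresh zero that still participates in autograd, so the total loss stays differentiable while the bad term contributes nothing:

    import torch

    # Replace a non-finite loss term with a zero that requires grad.
    loss = torch.tensor(float("nan"))
    if torch.isnan(loss) or torch.isinf(loss):
        loss = loss.new_tensor(0., requires_grad=True)
    print(loss)  # tensor(0., requires_grad=True)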
openfold/utils/affine_utils.py → openfold/utils/rigid_utils.py

@@ -107,52 +107,790 @@ def rot_vec_mul(
 )
-class T:
+def identity_rot_mats(
+    batch_dims: Tuple[int],
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[torch.device] = None,
+    requires_grad: bool = True,
+) -> torch.Tensor:
+    rots = torch.eye(
+        3, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
+    rots = rots.expand(*batch_dims, -1, -1)
+
+    return rots
+
+
+def identity_trans(
+    batch_dims: Tuple[int],
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[torch.device] = None,
+    requires_grad: bool = True,
+) -> torch.Tensor:
+    trans = torch.zeros(
+        (*batch_dims, 3),
+        dtype=dtype,
+        device=device,
+        requires_grad=requires_grad,
+    )
+    return trans
+
+
+def identity_quats(
+    batch_dims: Tuple[int],
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[torch.device] = None,
+    requires_grad: bool = True,
+) -> torch.Tensor:
+    quat = torch.zeros(
+        (*batch_dims, 4),
+        dtype=dtype,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    with torch.no_grad():
+        quat[..., 0] = 1
+
+    return quat
+_quat_elements = ["a", "b", "c", "d"]
+_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
+_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
+
+
+def _to_mat(pairs):
+    mat = np.zeros((4, 4))
+    for pair in pairs:
+        key, value = pair
+        ind = _qtr_ind_dict[key]
+        mat[ind // 4][ind % 4] = value
+
+    return mat
+
+
+_QTR_MAT = np.zeros((4, 4, 3, 3))
+_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
+_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
+_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
+_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
+_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
+_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
+_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
+_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
+_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
+
+
+def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
+    """
+    Converts a quaternion to a rotation matrix.
+
+    Args:
+        quat: [*, 4] quaternions
+    Returns:
+        [*, 3, 3] rotation matrices
+    """
+    # [*, 4, 4]
+    quat = quat[..., None] * quat[..., None, :]
+
+    # [4, 4, 3, 3]
+    mat = quat.new_tensor(_QTR_MAT, requires_grad=False)
+
+    # [*, 4, 4, 3, 3]
+    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
+    quat = quat[..., None, None] * shaped_qtr_mat
+
+    # [*, 3, 3]
+    return torch.sum(quat, dim=(-3, -4))
+def rot_to_quat(
+    rot: torch.Tensor,
+):
+    if(rot.shape[-2:] != (3, 3)):
+        raise ValueError("Input rotation is incorrectly shaped")
+
+    rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
+    [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
+
+    k = [
+        [xx + yy + zz, zy - yz, xz - zx, yx - xy,],
+        [zy - yz, xx - yy - zz, xy + yx, xz + zx,],
+        [xz - zx, xy + yx, yy - xx - zz, yz + zy,],
+        [yx - xy, xz + zx, yz + zy, zz - xx - yy,],
+    ]
+
+    k = (1. / 3.) * torch.stack(
+        [torch.stack(t, dim=-1) for t in k], dim=-2
+    )
+
+    _, vectors = torch.linalg.eigh(k)
+    return vectors[..., -1]
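rot_to_quat recovers the quaternion as the dominant eigenvector of a symmetric 4x4 matrix (via torch.linalg.eigh), so the result is defined only up to sign. A quick round-trip check using the two functions added above:

    import torch

    # Round-trip a unit quaternion through quat_to_rot / rot_to_quat;
    # compare up to an overall sign flip.
    q = torch.randn(4, dtype=torch.float64)
    q = q / torch.linalg.norm(q)

    R = quat_to_rot(q)      # [3, 3]
    q2 = rot_to_quat(R)     # [4], up to sign

    assert (torch.allclose(q, q2, atol=1e-6)
            or torch.allclose(q, -q2, atol=1e-6))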
+_QUAT_MULTIPLY = np.zeros((4, 4, 4))
+_QUAT_MULTIPLY[:, :, 0] = [[ 1,  0,  0,  0],
+                           [ 0, -1,  0,  0],
+                           [ 0,  0, -1,  0],
+                           [ 0,  0,  0, -1]]
+
+_QUAT_MULTIPLY[:, :, 1] = [[ 0,  1,  0,  0],
+                           [ 1,  0,  0,  0],
+                           [ 0,  0,  0,  1],
+                           [ 0,  0, -1,  0]]
+
+_QUAT_MULTIPLY[:, :, 2] = [[ 0,  0,  1,  0],
+                           [ 0,  0,  0, -1],
+                           [ 1,  0,  0,  0],
+                           [ 0,  1,  0,  0]]
+
+_QUAT_MULTIPLY[:, :, 3] = [[ 0,  0,  0,  1],
+                           [ 0,  0,  1,  0],
+                           [ 0, -1,  0,  0],
+                           [ 1,  0,  0,  0]]
+
+_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
+
+
+def quat_multiply(quat1, quat2):
+    """Multiply a quaternion by another quaternion."""
+    mat = quat1.new_tensor(_QUAT_MULTIPLY)
+    reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
+    return torch.sum(
+        reshaped_mat
+        * quat1[..., :, None, None]
+        * quat2[..., None, :, None],
+        dim=(-3, -2),
+    )
+
+
+def quat_multiply_by_vec(quat, vec):
+    """Multiply a quaternion by a pure-vector quaternion."""
+    mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
+    reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
+    return torch.sum(
+        reshaped_mat
+        * quat[..., :, None, None]
+        * vec[..., None, :, None],
+        dim=(-3, -2),
+    )
+def invert_rot_mat(rot_mat: torch.Tensor):
+    return rot_mat.transpose(-1, -2)
+
+
+def invert_quat(quat: torch.Tensor):
+    quat_prime = quat.clone()
+    quat_prime[..., 1:] *= -1
+    inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True)
+    return inv
+class Rotation:
+    """
+    A 3D rotation. Depending on how the object is initialized, the
+    rotation is represented by either a rotation matrix or a
+    quaternion, though both formats are made available by helper
+    functions. To simplify gradient computation, the underlying format
+    of the rotation cannot be changed in-place. Like Rigid, the class
+    is designed to mimic the behavior of a torch Tensor, almost as if
+    each Rotation object were a tensor of rotations, in one format or
+    another.
+    """
+    def __init__(
+        self,
+        rot_mats: Optional[torch.Tensor] = None,
+        quats: Optional[torch.Tensor] = None,
+        normalize_quats: bool = True,
+    ):
+        """
+        Args:
+            rot_mats:
+                A [*, 3, 3] rotation matrix tensor. Mutually exclusive
+                with quats
+            quats:
+                A [*, 4] quaternion. Mutually exclusive with rot_mats.
+                If normalize_quats is not True, must be a unit quaternion
+            normalize_quats:
+                If quats is specified, whether to normalize quats
+        """
+        if((rot_mats is None and quats is None) or
+            (rot_mats is not None and quats is not None)):
+            raise ValueError("Exactly one input argument must be specified")
+
+        if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
+            (quats is not None and quats.shape[-1] != 4)):
+            raise ValueError(
+                "Incorrectly shaped rotation matrix or quaternion"
+            )
+
+        if(quats is not None and normalize_quats):
+            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
+
+        self._rot_mats = rot_mats
+        self._quats = quats
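A minimal usage sketch of the Rotation class added above; the same rotation can be stored in either format and converted on demand:

    import torch

    q = torch.tensor([1., 0., 0., 0.])             # identity quaternion
    rot_q = Rotation(quats=q.expand(5, 4))         # shape [5], quat-backed
    rot_m = Rotation(rot_mats=torch.eye(3).expand(5, 3, 3))  # matrix-backed

    print(rot_q.shape, rot_m.shape)                # torch.Size([5]) twice
    assert torch.allclose(rot_q.get_rot_mats(), rot_m.get_rot_mats())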
+    @staticmethod
+    def identity(
+        shape,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        requires_grad: bool = True,
+        fmt: str = "quat",
+    ) -> Rotation:
+        """
+        Returns an identity Rotation.
+
+        Args:
+            shape:
+                The "shape" of the resulting Rotation object. See
+                documentation for the shape property
+            dtype:
+                The torch dtype for the rotation
+            device:
+                The torch device for the new rotation
+            requires_grad:
+                Whether the underlying tensors in the new rotation
+                object should require gradient computation
+            fmt:
+                One of "quat" or "rot_mat". Determines the underlying
+                format of the new object's rotation
+        Returns:
+            A new identity rotation
+        """
+        if(fmt == "rot_mat"):
+            rot_mats = identity_rot_mats(
+                shape, dtype, device, requires_grad,
+            )
+            return Rotation(rot_mats=rot_mats, quats=None)
+        elif(fmt == "quat"):
+            quats = identity_quats(shape, dtype, device, requires_grad)
+            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
+        else:
+            raise ValueError(f"Invalid format: {fmt}")
+    # Magic methods
+    def __getitem__(self, index: Any) -> Rotation:
+        """
+        Allows torch-style indexing over the virtual shape of the
+        rotation object. See documentation for the shape property.
+
+        Args:
+            index:
+                A torch index. E.g. (1, 3, 2), or (slice(None,))
+        Returns:
+            The indexed rotation
+        """
+        if type(index) != tuple:
+            index = (index,)
+
+        if(self._rot_mats is not None):
+            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
+            return Rotation(rot_mats=rot_mats)
+        elif(self._quats is not None):
+            quats = self._quats[index + (slice(None),)]
+            return Rotation(quats=quats, normalize_quats=False)
+        else:
+            raise ValueError("Both rotations are None")
+
+    def __mul__(
+        self,
+        right: torch.Tensor,
+    ) -> Rotation:
+        """
+        Pointwise left multiplication of the rotation with a tensor.
+        Can be used to e.g. mask the Rotation.
+
+        Args:
+            right:
+                The tensor multiplicand
+        Returns:
+            The product
+        """
+        if not (isinstance(right, torch.Tensor)):
+            raise TypeError("The other multiplicand must be a Tensor")
+
+        if(self._rot_mats is not None):
+            rot_mats = self._rot_mats * right[..., None, None]
+            return Rotation(rot_mats=rot_mats, quats=None)
+        elif(self._quats is not None):
+            quats = self._quats * right[..., None]
+            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
+        else:
+            raise ValueError("Both rotations are None")
+
+    def __rmul__(
+        self,
+        left: torch.Tensor,
+    ) -> Rotation:
+        """
+        Reverse pointwise multiplication of the rotation with a tensor.
+
+        Args:
+            left:
+                The left multiplicand
+        Returns:
+            The product
+        """
+        return self.__mul__(left)
+    # Properties
+    @property
+    def shape(self) -> torch.Size:
+        """
+        Returns the virtual shape of the rotation object. This shape is
+        defined as the batch dimensions of the underlying rotation
+        matrix or quaternion. If the Rotation was initialized with a
+        [10, 3, 3] rotation matrix tensor, for example, the resulting
+        shape would be [10].
+
+        Returns:
+            The virtual shape of the rotation object
+        """
+        s = None
+        if(self._quats is not None):
+            s = self._quats.shape[:-1]
+        else:
+            s = self._rot_mats.shape[:-2]
+
+        return s
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """
+        Returns the dtype of the underlying rotation.
+
+        Returns:
+            The dtype of the underlying rotation
+        """
+        if(self._rot_mats is not None):
+            return self._rot_mats.dtype
+        elif(self._quats is not None):
+            return self._quats.dtype
+        else:
+            raise ValueError("Both rotations are None")
+
+    @property
+    def device(self) -> torch.device:
+        """
+        The device of the underlying rotation
+
+        Returns:
+            The device of the underlying rotation
+        """
+        if(self._rot_mats is not None):
+            return self._rot_mats.device
+        elif(self._quats is not None):
+            return self._quats.device
+        else:
+            raise ValueError("Both rotations are None")
+
+    @property
+    def requires_grad(self) -> bool:
+        """
+        Returns the requires_grad property of the underlying rotation
+
+        Returns:
+            The requires_grad property of the underlying tensor
+        """
+        if(self._rot_mats is not None):
+            return self._rot_mats.requires_grad
+        elif(self._quats is not None):
+            return self._quats.requires_grad
+        else:
+            raise ValueError("Both rotations are None")
+
+    def get_rot_mats(self) -> torch.Tensor:
+        """
+        Returns the underlying rotation as a rotation matrix tensor.
+
+        Returns:
+            The rotation as a rotation matrix tensor
+        """
+        rot_mats = self._rot_mats
+        if(rot_mats is None):
+            if(self._quats is None):
+                raise ValueError("Both rotations are None")
+            else:
+                rot_mats = quat_to_rot(self._quats)
+
+        return rot_mats
+
+    def get_quats(self) -> torch.Tensor:
+        """
+        Returns the underlying rotation as a quaternion tensor.
+
+        Depending on whether the Rotation was initialized with a
+        quaternion, this function may call torch.linalg.eigh.
+
+        Returns:
+            The rotation as a quaternion tensor.
+        """
+        quats = self._quats
+        if(quats is None):
+            if(self._rot_mats is None):
+                raise ValueError("Both rotations are None")
+            else:
+                quats = rot_to_quat(self._rot_mats)
+
+        return quats
+
+    def get_cur_rot(self) -> torch.Tensor:
+        """
+        Return the underlying rotation in its current form
+
+        Returns:
+            The stored rotation
+        """
+        if(self._rot_mats is not None):
+            return self._rot_mats
+        elif(self._quats is not None):
+            return self._quats
+        else:
+            raise ValueError("Both rotations are None")
+    # Rotation functions
+    def compose_q_update_vec(
+        self,
+        q_update_vec: torch.Tensor,
+        normalize_quats: bool = True
+    ) -> Rotation:
+        """
+        Returns a new quaternion Rotation after updating the current
+        object's underlying rotation with a quaternion update, formatted
+        as a [*, 3] tensor whose final three columns represent x, y, z
+        such that (1, x, y, z) is the desired (not necessarily unit)
+        quaternion update.
+
+        Args:
+            q_update_vec:
+                A [*, 3] quaternion update tensor
+            normalize_quats:
+                Whether to normalize the output quaternion
+        Returns:
+            An updated Rotation
+        """
+        quats = self.get_quats()
+        new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
+        return Rotation(
+            rot_mats=None,
+            quats=new_quats,
+            normalize_quats=normalize_quats,
+        )
+
+    def compose_r(self, r: Rotation) -> Rotation:
+        """
+        Compose the rotation matrices of the current Rotation object
+        with those of another.
+
+        Args:
+            r:
+                An update rotation object
+        Returns:
+            An updated rotation object
+        """
+        r1 = self.get_rot_mats()
+        r2 = r.get_rot_mats()
+        new_rot_mats = rot_matmul(r1, r2)
+        return Rotation(rot_mats=new_rot_mats, quats=None)
+
+    def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
+        """
+        Compose the quaternions of the current Rotation object with
+        those of another.
+
+        Depending on whether either Rotation was initialized with
+        quaternions, this function may call torch.linalg.eigh.
+
+        Args:
+            r:
+                An update rotation object
+        Returns:
+            An updated rotation object
+        """
+        q1 = self.get_quats()
+        q2 = r.get_quats()
+        new_quats = quat_multiply(q1, q2)
+        return Rotation(
+            rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
+        )
+
+    def apply(self, pts: torch.Tensor) -> torch.Tensor:
+        """
+        Apply the current Rotation as a rotation matrix to a set of 3D
+        coordinates.
+
+        Args:
+            pts:
+                A [*, 3] set of points
+        Returns:
+            [*, 3] rotated points
+        """
+        rot_mats = self.get_rot_mats()
+        return rot_vec_mul(rot_mats, pts)
+
+    def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
+        """
+        The inverse of the apply() method.
+
+        Args:
+            pts:
+                A [*, 3] set of points
+        Returns:
+            [*, 3] inverse-rotated points
+        """
+        rot_mats = self.get_rot_mats()
+        inv_rot_mats = invert_rot_mat(rot_mats)
+        return rot_vec_mul(inv_rot_mats, pts)
+
+    def invert(self) -> Rotation:
+        """
+        Returns the inverse of the current Rotation.
+
+        Returns:
+            The inverse of the current Rotation
+        """
+        if(self._rot_mats is not None):
+            return Rotation(
+                rot_mats=invert_rot_mat(self._rot_mats),
+                quats=None
+            )
+        elif(self._quats is not None):
+            return Rotation(
+                rot_mats=None,
+                quats=invert_quat(self._quats),
+                normalize_quats=False,
+            )
+        else:
+            raise ValueError("Both rotations are None")
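A quick standalone check, using the Rotation class added above, that invert() really inverts:

    import torch

    # Composing a random rotation with its inverse gives the identity.
    q = torch.randn(6, 4, dtype=torch.float64)
    rot = Rotation(quats=q)                 # normalized on construction

    eye = rot.compose_r(rot.invert()).get_rot_mats()
    target = torch.eye(3, dtype=torch.float64).expand(6, 3, 3)
    assert torch.allclose(eye, target, atol=1e-8)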
# "Tensor" stuff
def
unsqueeze
(
self
,
dim
:
int
,
)
->
Rigid
:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shape of the Rotation object.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed Rotation.
"""
if
dim
>=
len
(
self
.
shape
):
raise
ValueError
(
"Invalid dimension"
)
if
(
self
.
_rot_mats
is
not
None
):
rot_mats
=
self
.
_rot_mats
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
2
)
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
quats
=
self
.
_quats
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
1
)
return
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
False
)
else
:
raise
ValueError
(
"Both rotations are None"
)
@
staticmethod
def
cat
(
rs
:
Sequence
[
Rotation
],
dim
:
int
,
)
->
Rigid
:
"""
Concatenates rotations along one of the batch dimensions. Analogous
to torch.cat().
Note that the output of this operation is always a rotation matrix,
regardless of the format of input rotations.
Args:
rs:
A list of rotation objects
dim:
The dimension along which the rotations should be
concatenated
Returns:
A concatenated Rotation object in rotation matrix format
"""
rot_mats
=
[
r
.
get_rot_mats
()
for
r
in
rs
]
rot_mats
=
torch
.
cat
(
rot_mats
,
dim
=
dim
if
dim
>=
0
else
dim
-
2
)
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
def
map_tensor_fn
(
self
,
fn
:
Callable
[
tensor
.
Tensor
,
tensor
.
Tensor
]
)
->
Rotation
:
"""
Apply a Tensor -> Tensor function to underlying rotation tensors,
mapping over the rotation dimension(s). Can be used e.g. to sum out
a one-hot batch dimension.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rotation
Returns:
The transformed Rotation object
"""
if
(
self
.
_rot_mats
is
not
None
):
rot_mats
=
self
.
_rot_mats
.
view
(
self
.
_rot_mats
.
shape
[:
-
2
]
+
(
9
,))
rot_mats
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
rot_mats
,
dim
=-
1
))),
dim
=-
1
)
rot_mats
=
rot_mats
.
view
(
rot_mats
.
shape
[:
-
1
]
+
(
3
,
3
))
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
quats
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
self
.
_quats
,
dim
=-
1
))),
dim
=-
1
)
return
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
False
)
else
:
raise
ValueError
(
"Both rotations are None"
)
+    def cuda(self) -> Rotation:
+        """
+        Analogous to the cuda() method of torch Tensors
+
+        Returns:
+            A copy of the Rotation in CUDA memory
+        """
+        if(self._rot_mats is not None):
+            return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
+        elif(self._quats is not None):
+            return Rotation(
+                rot_mats=None,
+                quats=self._quats.cuda(),
+                normalize_quats=False
+            )
+        else:
+            raise ValueError("Both rotations are None")
+
+    def to(
+        self,
+        device: Optional[torch.device],
+        dtype: Optional[torch.dtype]
+    ) -> Rotation:
+        """
+        Analogous to the to() method of torch Tensors
+
+        Args:
+            device:
+                A torch device
+            dtype:
+                A torch dtype
+        Returns:
+            A copy of the Rotation using the new device and dtype
+        """
+        if(self._rot_mats is not None):
+            return Rotation(
+                rot_mats=self._rot_mats.to(device=device, dtype=dtype),
+                quats=None,
+            )
+        elif(self._quats is not None):
+            return Rotation(
+                rot_mats=None,
+                quats=self._quats.to(device=device, dtype=dtype),
+                normalize_quats=False,
+            )
+        else:
+            raise ValueError("Both rotations are None")
+
+    def detach(self) -> Rotation:
+        """
+        Returns a copy of the Rotation whose underlying Tensor has been
+        detached from its torch graph.
+
+        Returns:
+            A copy of the Rotation whose underlying Tensor has been
+            detached from its torch graph
+        """
+        if(self._rot_mats is not None):
+            return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
+        elif(self._quats is not None):
+            return Rotation(
+                rot_mats=None,
+                quats=self._quats.detach(),
+                normalize_quats=False,
+            )
+        else:
+            raise ValueError("Both rotations are None")
+class Rigid:
+    """
+    A class representing a rigid transformation. Little more than a
+    wrapper around two objects: a Rotation object and a [*, 3]
+    translation. Designed to behave approximately like a single torch
+    tensor with the shape of the shared batch dimensions of its
+    component parts.
+    """
     def __init__(
         self,
-        rots: torch.Tensor,
-        trans: torch.Tensor
+        rots: Optional[Rotation],
+        trans: Optional[torch.Tensor],
     ):
         """
         Args:
             rots: A [*, 3, 3] rotation tensor
             trans: A corresponding [*, 3] translation tensor
         """
-        self.rots = rots
-        self.trans = trans
-
-        if self.rots is None and self.trans is None:
-            raise ValueError("Only one of rots and trans can be None")
-        elif self.rots is None:
-            self.rots = T._identity_rot(
-                self.trans.shape[:-1],
-                self.trans.dtype,
-                self.trans.device,
-                self.trans.requires_grad,
-            )
-        elif self.trans is None:
-            self.trans = T._identity_trans(
-                self.rots.shape[:-2],
-                self.rots.dtype,
-                self.rots.device,
-                self.rots.requires_grad,
-            )
-
-        if (
-            self.rots.shape[-2:] != (3, 3)
-            or self.trans.shape[-1] != 3
-            or self.rots.shape[:-2] != self.trans.shape[:-1]
-        ):
-            raise ValueError("Incorrectly shaped input")
+        # (we need device, dtype, etc. from at least one input)
+        batch_dims, dtype, device, requires_grad = None, None, None, None
+        if(trans is not None):
+            batch_dims = trans.shape[:-1]
+            dtype = trans.dtype
+            device = trans.device
+            requires_grad = trans.requires_grad
+        elif(rots is not None):
+            batch_dims = rots.shape
+            dtype = rots.dtype
+            device = rots.device
+            requires_grad = rots.requires_grad
+        else:
+            raise ValueError("At least one input argument must be specified")
+
+        if(rots is None):
+            rots = Rotation.identity(
+                batch_dims, dtype, device, requires_grad,
+            )
+        elif(trans is None):
+            trans = identity_trans(
+                batch_dims, dtype, device, requires_grad,
+            )
+
+        if((rots.shape != trans.shape[:-1]) or
+           (rots.device != trans.device)):
+            raise ValueError("Rots and trans incompatible")
+
+        self._rots = rots
+        self._trans = trans
+    @staticmethod
+    def identity(
+        shape: Tuple[int],
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        requires_grad: bool = True,
+        fmt: str = "quat",
+    ) -> Rigid:
+        """
+        Constructs an identity transformation.
+
+        Args:
+            shape:
+                The desired shape
+            dtype:
+                The dtype of both internal tensors
+            device:
+                The device of both internal tensors
+            requires_grad:
+                Whether grad should be enabled for the internal tensors
+        Returns:
+            The identity transformation
+        """
+        return Rigid(
+            Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
+            identity_trans(shape, dtype, device, requires_grad),
+        )
     def __getitem__(
         self,
         index: Any,
-    ) -> T:
+    ) -> Rigid:
         """
         Indexes the affine transformation with PyTorch-style indices.
         The index is applied to the shared dimensions of both the
         rotation and the translation.

         E.g.::

-            t = T(torch.rand(10, 10, 3, 3), torch.rand(10, 10, 3))
+            r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
+            t = Rigid(r, torch.rand(10, 10, 3))
             indexed = t[3, 4:6]
             assert(indexed.shape == (2,))
-            assert(indexed.rots.shape == (2, 3, 3))
-            assert(indexed.trans.shape == (2, 3))
+            assert(indexed.get_rots().shape == (2,))
+            assert(indexed.get_trans().shape == (2, 3))

         Args:
             index: A standard torch tensor index. E.g. 8, (10, None, 3),
@@ -174,54 +913,45 @@ class T:
         """
         if type(index) != tuple:
             index = (index,)
-
-        return T(
-            self.rots[index + (slice(None), slice(None))],
-            self.trans[index + (slice(None),)],
-        )
-
-    def __eq__(
-        self,
-        obj: T,
-    ) -> bool:
-        """
-        Compares two affine transformations. Returns true iff the
-        transformations are pointwise identical. Does not account for
-        floating point imprecision.
-        """
-        return bool(
-            torch.all(self.rots == obj.rots)
-            and torch.all(self.trans == obj.trans)
-        )
+
+        return Rigid(
+            self._rots[index],
+            self._trans[index + (slice(None),)],
+        )

     def __mul__(
         self,
         right: torch.Tensor,
-    ) -> T:
+    ) -> Rigid:
         """
-        Pointwise right multiplication of the affine transformation with a
-        tensor. Multiplication is broadcast over the rotation/translation
-        dimensions.
+        Pointwise left multiplication of the transformation with a tensor.
+        Can be used to e.g. mask the Rigid.

         Args:
-            right: The right multiplicand
+            right:
+                The tensor multiplicand
         Returns:
-            The product transformation
+            The product
         """
-        rots = self.rots * right[..., None, None]
-        trans = self.trans * right[..., None]
+        if not (isinstance(right, torch.Tensor)):
+            raise TypeError("The other multiplicand must be a Tensor")
+
+        new_rots = self._rots * right
+        new_trans = self._trans * right[..., None]

-        return T(rots, trans)
+        return Rigid(new_rots, new_trans)

     def __rmul__(
         self,
         left: torch.Tensor,
-    ) -> T:
+    ) -> Rigid:
         """
-        Pointwise left multiplication of the affine transformation with a
-        tensor. Multiplication is broadcast over the rotation/translation
-        dimensions.
+        Reverse pointwise multiplication of the transformation with a
+        tensor.

         Args:
-            left: The left multiplicand
+            left:
+                The left multiplicand
         Returns:
-            The product transformation
+            The product
         """
         return self.__mul__(left)
@@ -234,45 +964,74 @@ class T:

         Returns:
             The shape of the transformation
         """
-        s = self.rots.shape[:-2]
-        return s if len(s) > 0 else torch.Size([1])
+        s = self._trans.shape[:-1]
+        return s
+
+    @property
+    def device(self) -> torch.device:
+        """
+        Returns the device on which the Rigid's tensors are located.
+
+        Returns:
+            The device on which the Rigid's tensors are located
+        """
+        return self._trans.device

-    def get_rots(self):
+    def get_rots(self) -> Rotation:
         """
         Getter for the rotation.

         Returns:
-            The stored rotation.
+            The rotation object
         """
-        return self.rots
+        return self._rots

     def get_trans(self) -> torch.Tensor:
         """
         Getter for the translation.

         Returns:
-            The stored translation.
+            The stored translation
         """
-        return self.trans
+        return self._trans

-    def compose(
-        self,
-        t: T,
-    ) -> T:
+    def compose_q_update_vec(
+        self,
+        q_update_vec: torch.Tensor,
+    ) -> Rigid:
         """
-        Composes the transformation with another.
+        Composes the transformation with a quaternion update vector of
+        shape [*, 6], where the final 6 columns represent the x, y, and
+        z values of a quaternion of form (1, x, y, z) followed by a 3D
+        translation.

         Args:
-            t: The inner transformation.
+            q_vec: The quaternion update vector.
         Returns:
             The composed transformation.
         """
-        rot_1, trn_1 = self.rots, self.trans
-        rot_2, trn_2 = t.rots, t.trans
-
-        rot = rot_matmul(rot_1, rot_2)
-        trn = rot_vec_mul(rot_1, trn_2) + trn_1
-
-        return T(rot, trn)
+        q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
+        new_rots = self._rots.compose_q_update_vec(q_vec)
+
+        trans_update = self._rots.apply(t_vec)
+        new_translation = self._trans + trans_update
+
+        return Rigid(new_rots, new_translation)
+
+    def compose(
+        self,
+        r: Rigid,
+    ) -> Rigid:
+        """
+        Composes the current rigid object with another.
+
+        Args:
+            r:
+                Another Rigid object
+        Returns:
+            The composition of the two transformations
+        """
+        new_rot = self._rots.compose_r(r._rots)
+        new_trans = self._rots.apply(r._trans) + self._trans
+        return Rigid(new_rot, new_trans)

     def apply(
         self,
         pts: torch.Tensor,
@@ -285,9 +1044,8 @@ class T:
         Returns:
             The transformed points.
         """
-        r, t = self.rots, self.trans
-        rotated = rot_vec_mul(r, pts)
-        return rotated + t
+        rotated = self._rots.apply(pts)
+        return rotated + self._trans

     def invert_apply(
         self,
         pts: torch.Tensor
@@ -300,99 +1058,60 @@ class T:
         Returns:
             The transformed points.
         """
-        r, t = self.rots, self.trans
-        pts = pts - t
-        return rot_vec_mul(r.transpose(-1, -2), pts)
+        pts = pts - self._trans
+        return self._rots.invert_apply(pts)
-    def invert(self) -> T:
+    def invert(self) -> Rigid:
         """
         Inverts the transformation.

         Returns:
             The inverse transformation.
         """
-        rot_inv = self.rots.transpose(-1, -2)
-        trn_inv = rot_vec_mul(rot_inv, self.trans)
+        rot_inv = self._rots.invert()
+        trn_inv = rot_inv.apply(self._trans)

-        return T(rot_inv, -1 * trn_inv)
+        return Rigid(rot_inv, -1 * trn_inv)
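A quick standalone check, using the classes added in this file, that compose() matches sequential application and that invert() undoes a transformation:

    import torch

    # (a.compose(b)).apply(x) == a.apply(b.apply(x)), and
    # a.invert().apply(a.apply(x)) == x.
    def random_rigid():
        q = torch.randn(4, dtype=torch.float64)
        t = torch.randn(3, dtype=torch.float64)
        return Rigid(Rotation(quats=q), t)  # quats normalized on init

    a, b = random_rigid(), random_rigid()
    x = torch.randn(3, dtype=torch.float64)

    assert torch.allclose(
        a.compose(b).apply(x), a.apply(b.apply(x)), atol=1e-8
    )
    assert torch.allclose(a.invert().apply(a.apply(x)), x, atol=1e-8)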
-    def unsqueeze(
-        self,
-        dim: int,
-    ) -> T:
+    def map_tensor_fn(
+        self, fn: Callable[[torch.Tensor], torch.Tensor]
+    ) -> Rigid:
         """
-        Analogous to torch.unsqueeze. The dimension is relative to the
-        shared dimensions of the rotation/translation.
+        Apply a Tensor -> Tensor function to underlying translation and
+        rotation tensors, mapping over the translation/rotation
+        dimensions respectively.

         Args:
-            dim: A positive or negative dimension index.
+            fn:
+                A Tensor -> Tensor function to be mapped over the Rigid
         Returns:
-            The unsqueezed transformation.
+            The transformed Rigid object
         """
-        if dim >= len(self.shape):
-            raise ValueError("Invalid dimension")
-        rots = self.rots.unsqueeze(dim if dim >= 0 else dim - 2)
-        trans = self.trans.unsqueeze(dim if dim >= 0 else dim - 1)
-
-        return T(rots, trans)
-
-    @staticmethod
-    def _identity_rot(
-        shape: Tuple[int],
-        dtype: torch.dtype,
-        device: torch.device,
-        requires_grad: bool,
-    ) -> torch.Tensor:
-        rots = torch.eye(
-            3, dtype=dtype, device=device, requires_grad=requires_grad
-        )
-        rots = rots.view(*((1,) * len(shape)), 3, 3)
-        rots = rots.expand(*shape, -1, -1)
-
-        return rots
-
-    @staticmethod
-    def _identity_trans(
-        shape: Tuple[int],
-        dtype: torch.dtype,
-        device: torch.device,
-        requires_grad: bool
-    ) -> torch.Tensor:
-        trans = torch.zeros(
-            (*shape, 3),
-            dtype=dtype,
-            device=device,
-            requires_grad=requires_grad
-        )
-        return trans
+        new_rots = self._rots.map_tensor_fn(fn)
+        new_trans = torch.stack(
+            list(map(fn, torch.unbind(self._trans, dim=-1))), dim=-1
+        )
+
+        return Rigid(new_rots, new_trans)
-    @staticmethod
-    def identity(
-        shape: Tuple[int],
-        dtype: torch.dtype,
-        device: torch.device,
-        requires_grad: bool = True,
-    ) -> T:
+    def to_tensor_4x4(self) -> torch.Tensor:
         """
-        Constructs an identity transformation.
-
-        Args:
-            shape:
-                The desired shape
-            dtype:
-                The dtype of both internal tensors
-            device:
-                The device of both internal tensors
-            requires_grad:
-                Whether grad should be enabled for the internal tensors
-        Returns:
-            The identity transformation
+        Converts a transformation to a homogenous transformation tensor.
+
+        Returns:
+            A [*, 4, 4] homogenous transformation tensor
         """
-        return T(
-            T._identity_rot(shape, dtype, device, requires_grad),
-            T._identity_trans(shape, dtype, device, requires_grad),
-        )
+        tensor = self._trans.new_zeros((*self.shape, 4, 4))
+        tensor[..., :3, :3] = self._rots.get_rot_mats()
+        tensor[..., :3, 3] = self._trans
+        tensor[..., 3, 3] = 1
+        return tensor
     @staticmethod
-    def from_4x4(
+    def from_tensor_4x4(
         t: torch.Tensor
-    ) -> T:
+    ) -> Rigid:
         """
         Constructs a transformation from a homogenous transformation
         tensor.

         Args:
             t: A [*, 4, 4] homogenous transformation tensor
         Returns:
             T object with shape [*]
         """
-        rots = t[..., :3, :3]
+        if(t.shape[-2:] != (4, 4)):
+            raise ValueError("Incorrectly shaped input tensor")
+
+        rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
         trans = t[..., :3, 3]
-        return T(rots, trans)
+
+        return Rigid(rots, trans)
def
to_tensor_7
(
self
)
->
torch
.
Tensor
:
"""
Converts a transformation to a homogenous transformation tensor.
Converts a transformation to a tensor with 7 final columns, four
for the quaternion followed by three for the translation.
Returns:
A [*,
4, 4] homogenous
transformation
tensor
A [*,
7] tensor representation of the
transformation
"""
tensor
=
self
.
rot
s
.
new_zeros
((
*
self
.
shape
,
4
,
4
))
tensor
[...,
:
3
,
:
3
]
=
self
.
rots
tensor
[...,
:
3
,
3
]
=
self
.
trans
tensor
[...,
3
,
3
]
=
1
tensor
=
self
.
_tran
s
.
new_zeros
((
*
self
.
shape
,
7
))
tensor
[...,
:
4
]
=
self
.
_
rots
.
get_quats
()
tensor
[...,
4
:]
=
self
.
_
trans
return
tensor
@
staticmethod
def
from_tensor
(
t
:
torch
.
Tensor
)
->
T
:
"""
Constructs a transformation from a homogenous transformation
tensor.
def
from_tensor_7
(
t
:
torch
.
Tensor
,
normalize_quats
:
bool
=
False
,
)
->
Rigid
:
if
(
t
.
shape
[
-
1
]
!=
7
):
raise
ValueError
(
"Incorrectly shaped input tensor"
)
quats
,
trans
=
t
[...,
:
4
],
t
[...,
4
:]
rots
=
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
normalize_quats
)
Args:
t: A [*, 4, 4] homogenous transformation tensor
Returns:
A transformation object with shape [*]
"""
return
T
.
from_4x4
(
t
)
return
Rigid
(
rots
,
trans
)
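With both tensor codecs on the same class, converting between the 7-component and 4x4 representations is a two-step round trip. A minimal sketch, assuming the rigid_utils module added by this commit is importable:

import torch
from openfold.utils.rigid_utils import Rigid

# A [*, 7] tensor: the identity quaternion [1, 0, 0, 0] plus a translation.
t7 = torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0]])

r = Rigid.from_tensor_7(t7, normalize_quats=True)
t44 = r.to_tensor_4x4()

print(t44[0, :3, :3])  # identity rotation block
print(t44[0, :3, 3])   # tensor([1., 2., 3.])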
     @staticmethod
     def from_3_points(
...
@@ -438,7 +1167,7 @@ class T:
         origin: torch.Tensor,
         p_xy_plane: torch.Tensor,
         eps: float = 1e-8
-    ) -> T:
+    ) -> Rigid:
         """
             Implements algorithm 21. Constructs transformations from sets of 3
             points using the Gram-Schmidt algorithm.
...
@@ -473,13 +1202,34 @@ class T:
         rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
         rots = rots.reshape(rots.shape[:-1] + (3, 3))

-        return T(rots, torch.stack(origin, dim=-1))
+        rot_obj = Rotation(rot_mats=rots, quats=None)
+
+        return Rigid(rot_obj, torch.stack(origin, dim=-1))
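The translation of the resulting frame is just the origin argument, which is exactly what the updated tests below verify. A small sketch of frame construction from three backbone atoms, assuming the from_3_points signature above:

import torch
from openfold.utils.rigid_utils import Rigid

batch, n_res = 2, 5
n = torch.rand((batch, n_res, 3))    # stand-in for N coordinates
ca = torch.rand((batch, n_res, 3))   # CA becomes the frame origin
c = torch.rand((batch, n_res, 3))    # fixes the xy-plane

frames = Rigid.from_3_points(n, ca, c)
assert torch.all(frames.get_trans() == ca)
print(frames.get_rots().get_rot_mats().shape)  # torch.Size([2, 5, 3, 3])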
+    def unsqueeze(
+        self,
+        dim: int,
+    ) -> Rigid:
+        """
+            Analogous to torch.unsqueeze. The dimension is relative to the
+            shared dimensions of the rotation/translation.
+
+            Args:
+                dim: A positive or negative dimension index.
+            Returns:
+                The unsqueezed transformation.
+        """
+        if dim >= len(self.shape):
+            raise ValueError("Invalid dimension")
+        rots = self._rots.unsqueeze(dim)
+        trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
+
+        return Rigid(rots, trans)
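A small sketch of the shared-dimension semantics, assuming the Rigid/Rotation constructors shown elsewhere in this commit: unsqueeze indexes into the batch dimensions common to the rotation and translation, so the two underlying tensors grow in lock-step.

import torch
from openfold.utils.rigid_utils import Rotation, Rigid

r = Rigid(
    Rotation(rot_mats=torch.rand((4, 3, 3))),
    torch.rand((4, 3)),
)
r2 = r.unsqueeze(0)
print(r2.shape)                            # torch.Size([1, 4])
print(r2.get_rots().get_rot_mats().shape)  # torch.Size([1, 4, 3, 3])
print(r2.get_trans().shape)                # torch.Size([1, 4, 3])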
     @staticmethod
-    def concat(
-        ts: Sequence[T],
+    def cat(
+        ts: Sequence[Rigid],
         dim: int,
-    ) -> T:
+    ) -> Rigid:
         """
             Concatenates transformations along a new dimension.
...
@@ -492,57 +1242,60 @@ class T:
             Returns:
                 A concatenated transformation object
         """
-        rots = torch.cat(
-            [t.rots for t in ts], dim=dim if dim >= 0 else dim - 2
-        )
+        rots = Rotation.cat([t._rots for t in ts], dim)
         trans = torch.cat(
-            [t.trans for t in ts], dim=dim if dim >= 0 else dim - 1
+            [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
         )

-        return T(rots, trans)
+        return Rigid(rots, trans)
-    def map_tensor_fn(
-        self, fn: Callable[[torch.Tensor], torch.Tensor]
-    ) -> T:
+    def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid:
         """
-            Apply a function that takes a tensor as its only argument to the
-            rotations and translations, treating the final two/one
-            dimension(s), respectively, as batch dimensions.
-
-            E.g.: Given t, an instance of T of shape [N, M], this function can
-            be used to sum out the second dimension thereof as follows::
-
-                t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
-
-            The resulting object has rotations of shape [N, 3, 3] and
-            translations of shape [N, 3]
+            Applies a Rotation -> Rotation function to the stored rotation
+            object.

             Args:
-                fn: A function that takes only a tensor as its argument
+                fn: A function of type Rotation -> Rotation
             Returns:
-                The transformed transformation object.
+                A transformation object with a transformed rotation.
         """
-        rots = self.rots.view(*self.rots.shape[:-2], 9)
-        rots = torch.stack(list(map(fn, torch.unbind(rots, -1))), dim=-1)
-        rots = rots.view(*rots.shape[:-1], 3, 3)
-        trans = torch.stack(
-            list(map(fn, torch.unbind(self.trans, -1))), dim=-1
-        )
-
-        return T(rots, trans)
+        return Rigid(fn(self._rots), self._trans)
+
+    def apply_trans_fn(
+        self, fn: Callable[torch.Tensor, torch.Tensor]
+    ) -> Rigid:
+        """
+            Applies a Tensor -> Tensor function to the stored translation.
+
+            Args:
+                fn: A function of type Tensor -> Tensor to be applied to the
+                    translation
+            Returns:
+                A transformation object with a transformed translation.
+        """
+        return Rigid(self._rots, fn(self._trans))
-    def stop_rot_gradient(self) -> T:
+    def scale_translation(self, trans_scale_factor: float) -> Rigid:
         """
-            Detaches the contained rotation tensor.
+            Scales the translation by a constant factor.

+            Args:
+                trans_scale_factor: The constant factor
             Returns:
-                A version of the transformation with detached rotations
+                A transformation object with a scaled translation.
         """
-        return T(self.rots.detach(), self.trans)
+        fn = lambda t: t * trans_scale_factor
+        return self.apply_trans_fn(fn)

-    def scale_translation(self, factor: int) -> T:
+    def stop_rot_gradient(self) -> Rigid:
         """
-            Scales the contained translation tensor by a constant factor.
+            Detaches the underlying rotation object

             Returns:
-                A version of the transformation with scaled translations
+                A transformation object with detached rotations
         """
-        return T(self.rots, self.trans * factor)
+        fn = lambda r: r.detach()
+        return self.apply_rot_fn(fn)
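Both helpers are now thin wrappers over the apply_*_fn hooks above. A brief sketch of how they chain, assuming the API in this commit (the factor 10 is an arbitrary example, echoing the length scale used in the loss tests below):

import torch
from openfold.utils.rigid_utils import Rotation, Rigid

r = Rigid(
    Rotation(rot_mats=torch.eye(3).expand(4, 3, 3)),
    torch.rand((4, 3)),
)

scaled = r.scale_translation(10.0)         # multiplies only the translation
frozen = scaled.stop_rot_gradient()        # rotations no longer receive gradients

print(scaled.get_trans() / r.get_trans())  # all entries ~10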
     @staticmethod
     def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
...
@@ -613,87 +1366,15 @@ class T:
         rots = rots.transpose(-1, -2)
         translation = -1 * translation

-        return T(rots, translation)
+        rot_obj = Rotation(rot_mats=rots, quats=None)
+
+        return Rigid(rot_obj, translation)

-    def cuda(self) -> T:
+    def cuda(self) -> Rigid:
         """
             Moves the transformation object to GPU memory

             Returns:
                 A version of the transformation on GPU
         """
-        return T(self.rots.cuda(), self.trans.cuda())
-_quat_elements = ["a", "b", "c", "d"]
-_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
-_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
-
-
-def _to_mat(pairs):
-    mat = np.zeros((4, 4))
-    for pair in pairs:
-        key, value = pair
-        ind = _qtr_ind_dict[key]
-        mat[ind // 4][ind % 4] = value
-    return mat
-
-
-_qtr_mat = np.zeros((4, 4, 3, 3))
-_qtr_mat[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
-_qtr_mat[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
-_qtr_mat[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
-_qtr_mat[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
-_qtr_mat[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
-_qtr_mat[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
-_qtr_mat[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
-_qtr_mat[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
-_qtr_mat[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
-
-
-def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
-    """
-        Converts a quaternion to a rotation matrix.
-
-        Args:
-            quat: [*, 4] quaternions
-        Returns:
-            [*, 3, 3] rotation matrices
-    """
-    # [*, 4, 4]
-    quat = quat[..., None] * quat[..., None, :]
-
-    # [4, 4, 3, 3]
-    mat = quat.new_tensor(_qtr_mat)
-
-    # [*, 4, 4, 3, 3]
-    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
-    quat = quat[..., None, None] * shaped_qtr_mat
-
-    # [*, 3, 3]
-    return torch.sum(quat, dim=(-3, -4))
-
-
-def affine_vector_to_4x4(vector: torch.Tensor) -> torch.Tensor:
-    """
-        Transforms a tensor whose final dimension has the form:
-
-            [*quaternion, *translation]
-
-        into a homogenous transformation tensor.
-
-        Args:
-            vector: [*, 7] input tensor
-        Returns:
-            [*, 4, 4] homogenous transformation tensor
-    """
-    quats = vector[..., :4]
-    trans = vector[..., 4:]
-
-    four_by_four = vector.new_zeros((*vector.shape[:-1], 4, 4))
-
-    four_by_four[..., :3, :3] = quat_to_rot(quats)
-    four_by_four[..., :3, 3] = trans
-    four_by_four[..., 3, 3] = 1
-
-    return four_by_four
+        return Rigid(self._rots.cuda(), self._trans.cuda())
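The removed module-level helpers are not gone for good; equivalents live in the new rigid_utils module (the updated tests/test_utils.py below imports quat_to_rot and rot_to_quat from there). A quick sanity sketch of the quaternion codec, assuming those exports:

import torch
from openfold.utils.rigid_utils import quat_to_rot, rot_to_quat

# The identity quaternion [1, 0, 0, 0] must map to the identity matrix,
# and converting back should recover it (up to sign and numerical noise).
quat = torch.tensor([1.0, 0.0, 0.0, 0.0])
rot = quat_to_rot(quat)
assert torch.allclose(rot, torch.eye(3))
print(rot_to_quat(rot))  # ~tensor([1., 0., 0., 0.])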
tests/test_data/alphafold/common/stereo_chemical_props.txt 0 → 120000 (new file; mode 120000 marks a symlink)
View file @ e3daf724
../../../../openfold/resources/stereo_chemical_props.txt
\ No newline at end of file
tests/test_feats.py
View file @ e3daf724
...
@@ -23,8 +23,8 @@ from openfold.np.residue_constants import (
     restype_atom14_mask,
     restype_atom14_rigid_group_positions,
 )
-from openfold.utils.affine_utils import T
 import openfold.utils.feats as feats
+from openfold.utils.rigid_utils import Rotation, Rigid
 from openfold.utils.tensor_utils import (
     tree_map,
     tensor_tree_map,
...
@@ -187,7 +187,7 @@ class TestFeats(unittest.TestCase):
         n = 5
         rots = torch.rand((batch_size, n, 3, 3))
         trans = torch.rand((batch_size, n, 3))
-        ts = T(rots, trans)
+        ts = Rigid(Rotation(rot_mats=rots), trans)
         angles = torch.rand((batch_size, n, 7, 2))
...
@@ -222,7 +222,9 @@ class TestFeats(unittest.TestCase):
         affines = random_affines_4x4((n_res,))
         rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
-        transformations = T.from_4x4(torch.as_tensor(affines).float())
+        transformations = Rigid.from_tensor_4x4(
+            torch.as_tensor(affines).float()
+        )
         torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)
...
@@ -250,7 +252,7 @@ class TestFeats(unittest.TestCase):
         bottom_row[..., 3] = 1
         transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)
-        transforms_repro = out.to_4x4().cpu()
+        transforms_repro = out.to_tensor_4x4().cpu()
         self.assertTrue(
             torch.max(torch.abs(transforms_gt - transforms_repro) < consts.eps)
...
@@ -262,7 +264,7 @@ class TestFeats(unittest.TestCase):
         rots = torch.rand((batch_size, n_res, 8, 3, 3))
         trans = torch.rand((batch_size, n_res, 8, 3))
-        ts = T(rots, trans)
+        ts = Rigid(Rotation(rot_mats=rots), trans)
         f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()
...
@@ -293,7 +295,9 @@ class TestFeats(unittest.TestCase):
         affines = random_affines_4x4((n_res, 8))
         rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
-        transformations = T.from_4x4(torch.as_tensor(affines).float())
+        transformations = Rigid.from_tensor_4x4(
+            torch.as_tensor(affines).float()
+        )
         out_gt = f.apply({}, None, aatype, rigids)
         jax.tree_map(lambda x: x.block_until_ready(), out_gt)
...
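For reviewers, the test updates in this commit follow a mechanical old-to-new mapping. A comment-only reference sketch, with all names taken from the diff itself:

# Construction:      T(rots, trans)            ->  Rigid(Rotation(rot_mats=rots), trans)
# 4x4 codecs:        T.from_4x4(t) / t.to_4x4() -> Rigid.from_tensor_4x4(t) / r.to_tensor_4x4()
# Component access:  t.rots, t.trans           ->  r.get_rots().get_rot_mats(), r.get_trans()
# Concatenation:     T.concat(ts, dim)         ->  Rigid.cat(ts, dim)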
tests/test_loss.py
View file @ e3daf724
...
@@ -20,7 +20,10 @@ import unittest
 import ml_collections as mlc
 from openfold.data import data_transforms
-from openfold.utils.affine_utils import T, affine_vector_to_4x4
+from openfold.utils.rigid_utils import (
+    Rotation,
+    Rigid,
+)
 import openfold.utils.feats as feats
 from openfold.utils.loss import (
     torsion_angle_loss,
...
@@ -55,6 +58,11 @@ if compare_utils.alphafold_is_installed():
     import haiku as hk

+def affine_vector_to_4x4(affine):
+    r = Rigid.from_tensor_7(affine)
+    return r.to_tensor_4x4()

 class TestLoss(unittest.TestCase):
     def test_run_torsion_angle_loss(self):
         batch_size = consts.batch_size
...
@@ -77,8 +85,8 @@ class TestLoss(unittest.TestCase):
         rots_gt = torch.rand((batch_size, n_frames, 3, 3))
         trans = torch.rand((batch_size, n_frames, 3))
         trans_gt = torch.rand((batch_size, n_frames, 3))
-        t = T(rots, trans)
-        t_gt = T(rots_gt, trans_gt)
+        t = Rigid(Rotation(rot_mats=rots), trans)
+        t_gt = Rigid(Rotation(rot_mats=rots_gt), trans_gt)
         frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float()
         positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float()
         length_scale = 10
...
@@ -686,10 +694,10 @@ class TestLoss(unittest.TestCase):
         batch = tree_map(to_tensor, batch, np.ndarray)
         value = tree_map(to_tensor, value, np.ndarray)
-        batch["backbone_affine_tensor"] = affine_vector_to_4x4(
+        batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
             batch["backbone_affine_tensor"]
         )
         value["traj"] = affine_vector_to_4x4(value["traj"])
+        batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
         out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
         out_repro = out_repro.cpu()
...
@@ -807,6 +815,8 @@ class TestLoss(unittest.TestCase):
         f = hk.transform(run_tm_loss)
+        np.random.seed(42)
         n_res = consts.n_res
         representations = {
...
@@ -839,12 +849,10 @@ class TestLoss(unittest.TestCase):
         batch = tree_map(to_tensor, batch, np.ndarray)
         value = tree_map(to_tensor, value, np.ndarray)
-        batch["backbone_affine_tensor"] = affine_vector_to_4x4(
+        batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
             batch["backbone_affine_tensor"]
         )
         value["structure_module"]["final_affines"] = affine_vector_to_4x4(
             value["structure_module"]["final_affines"]
         )
+        batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
         model = compare_utils.get_global_pretrained_openfold()
         logits = model.aux_heads.tm(representations["pair"])
...
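The tests keep a local affine_vector_to_4x4 shim because the module-level helper was removed from the library; it simply routes through the new Rigid codecs. A sketch of the conversion it performs, assuming the API above:

import torch
from openfold.utils.rigid_utils import Rigid

# backbone_affine_tensor rows are [*quaternion, *translation]; the shim turns
# them into [*, 4, 4] homogeneous matrices for the loss comparison.
affine = torch.tensor([[1.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5]])
four_by_four = Rigid.from_tensor_7(affine).to_tensor_4x4()
print(four_by_four.shape)  # torch.Size([1, 4, 4])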
tests/test_model.py
View file @ e3daf724
...
@@ -130,4 +130,5 @@ class TestModel(unittest.TestCase):
         out_repro = out_repro["sm"]["positions"][-1]
         out_repro = out_repro.squeeze(0)
-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 1e-3))
+        print(torch.max(torch.abs(out_gt - out_repro)))
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
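The one-line change above also fixes a real bug: the old assertion took torch.max of a boolean mask, which is true whenever any element falls within tolerance. A sketch of the failure mode in plain PyTorch:

import torch

diff = torch.tensor([1e-4, 5.0])          # one element close, one far off
print(torch.max(torch.abs(diff) < 1e-3))  # tensor(True)  -- old form passes incorrectly
print(torch.max(torch.abs(diff)) < 1e-3)  # tensor(False) -- new form catches the error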
tests/test_structure_module.py
View file @ e3daf724
...
@@ -31,8 +31,8 @@ from openfold.model.structure_module import (
     AngleResnet,
     InvariantPointAttention,
 )
-from openfold.utils.affine_utils import T
 import openfold.utils.feats as feats
+from openfold.utils.rigid_utils import Rotation, Rigid
 import tests.compare_utils as compare_utils
 from tests.config import consts
 from tests.data_utils import (
...
@@ -89,7 +89,7 @@ class TestStructureModule(unittest.TestCase):
         out = sm(s, z, f)
-        self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
+        self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
         self.assertTrue(
             out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
         )
...
@@ -177,23 +177,6 @@ class TestStructureModule(unittest.TestCase):
         self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)

-class TestBackboneUpdate(unittest.TestCase):
-    def test_shape(self):
-        batch_size = 2
-        n_res = 3
-        c_in = 5
-
-        bu = BackboneUpdate(c_in)
-
-        s = torch.rand((batch_size, n_res, c_in))
-
-        t = bu(s)
-        rot, tra = t.rots, t.trans
-
-        self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
-        self.assertTrue(tra.shape == (batch_size, n_res, 3))

 class TestInvariantPointAttention(unittest.TestCase):
     def test_shape(self):
         c_m = 13
...
@@ -210,17 +193,18 @@ class TestInvariantPointAttention(unittest.TestCase):
         z = torch.rand((batch_size, n_res, n_res, c_z))
         mask = torch.ones((batch_size, n_res))
-        rots = torch.rand((batch_size, n_res, 3, 3))
+        rot_mats = torch.rand((batch_size, n_res, 3, 3))
+        rots = Rotation(rot_mats=rot_mats, quats=None)
         trans = torch.rand((batch_size, n_res, 3))
-        t = T(rots, trans)
+        r = Rigid(rots, trans)
         ipa = InvariantPointAttention(c_m, c_z, c_hidden, no_heads, no_qp, no_vp)
         shape_before = s.shape
-        s = ipa(s, z, t, mask)
+        s = ipa(s, z, r, mask)
         self.assertTrue(s.shape == shape_before)
...
@@ -253,7 +237,9 @@ class TestInvariantPointAttention(unittest.TestCase):
         affines = random_affines_4x4((n_res,))
         rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
         quats = alphafold.model.r3.rigids_to_quataffine(rigids)
-        transformations = T.from_4x4(torch.as_tensor(affines).float().cuda())
+        transformations = Rigid.from_tensor_4x4(
+            torch.as_tensor(affines).float().cuda()
+        )
         sample_affine = quats
...
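Note the changed expectation for out["frames"]: the structure module now emits frames in the 7-component quaternion-plus-translation encoding rather than as 4x4 matrices. Converting back for downstream consumers is a one-liner, assuming the codecs added by this commit:

import torch
from openfold.utils.rigid_utils import Rigid

frames = torch.zeros((8, 2, 10, 7))  # e.g. [no_layers, batch, n_res, 7]
frames[..., 0] = 1.0                 # identity quaternions so the frames are valid
frames_4x4 = Rigid.from_tensor_7(frames).to_tensor_4x4()
print(frames_4x4.shape)              # torch.Size([8, 2, 10, 4, 4])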
tests/test_utils.py
View file @ e3daf724
...
@@ -13,11 +13,24 @@
 # limitations under the License.
 import math
+import numpy as np
 import torch
 import unittest
-from openfold.utils.affine_utils import T, quat_to_rot
+from openfold.utils.rigid_utils import (
+    Rotation,
+    Rigid,
+    quat_to_rot,
+    rot_to_quat,
+)
 from openfold.utils.tensor_utils import chunk_layer, _chunk_slice
+import tests.compare_utils as compare_utils
 from tests.config import consts
+
+if compare_utils.alphafold_is_installed():
+    alphafold = compare_utils.import_alphafold()
+    import jax
+    import haiku as hk

 X_90_ROT = torch.tensor(
...
@@ -38,7 +51,7 @@ X_NEG_90_ROT = torch.tensor(
 class TestUtils(unittest.TestCase):
-    def test_T_from_3_points_shape(self):
+    def test_rigid_from_3_points_shape(self):
         batch_size = 2
         n_res = 5
...
@@ -46,14 +59,14 @@ class TestUtils(unittest.TestCase):
         x2 = torch.rand((batch_size, n_res, 3))
         x3 = torch.rand((batch_size, n_res, 3))
-        t = T.from_3_points(x1, x2, x3)
-        rot, tra = t.rots, t.trans
+        r = Rigid.from_3_points(x1, x2, x3)
+        rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
         self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
         self.assertTrue(torch.all(tra == x2))

-    def test_T_from_4x4(self):
+    def test_rigid_from_4x4(self):
         batch_size = 2
         transf = [
             [1, 0, 0, 1],
...
@@ -68,58 +81,79 @@ class TestUtils(unittest.TestCase):
         transf = torch.stack([transf for _ in range(batch_size)], dim=0)
-        t = T.from_4x4(transf)
-        rot, tra = t.rots, t.trans
+        r = Rigid.from_tensor_4x4(transf)
+        rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
         self.assertTrue(torch.all(rot == true_rot.unsqueeze(0)))
         self.assertTrue(torch.all(tra == true_trans.unsqueeze(0)))

-    def test_T_shape(self):
+    def test_rigid_shape(self):
         batch_size = 2
         n = 5
-        transf = T(
-            torch.rand((batch_size, n, 3, 3)),
-            torch.rand((batch_size, n, 3)),
-        )
+        transf = Rigid(
+            Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
+            torch.rand((batch_size, n, 3)),
+        )
         self.assertTrue(transf.shape == (batch_size, n))

-    def test_T_concat(self):
+    def test_rigid_cat(self):
         batch_size = 2
         n = 5
-        transf = T(
-            torch.rand((batch_size, n, 3, 3)),
-            torch.rand((batch_size, n, 3)),
-        )
+        transf = Rigid(
+            Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
+            torch.rand((batch_size, n, 3)),
+        )

-        transf_concat = T.concat([transf, transf], dim=0)
-        self.assertTrue(transf_concat.rots.shape == (batch_size * 2, n, 3, 3))
+        transf_cat = Rigid.cat([transf, transf], dim=0)
+        transf_rots = transf.get_rots().get_rot_mats()
+        transf_cat_rots = transf_cat.get_rots().get_rot_mats()
+        self.assertTrue(
+            transf_cat_rots.shape == (batch_size * 2, n, 3, 3)
+        )

-        transf_concat = T.concat([transf, transf], dim=1)
-        self.assertTrue(transf_concat.rots.shape == (batch_size, n * 2, 3, 3))
-        self.assertTrue(torch.all(transf_concat.rots[:, :n] == transf.rots))
-        self.assertTrue(torch.all(transf_concat.trans[:, :n] == transf.trans))
+        transf_cat = Rigid.cat([transf, transf], dim=1)
+        transf_cat_rots = transf_cat.get_rots().get_rot_mats()
+        self.assertTrue(transf_cat_rots.shape == (batch_size, n * 2, 3, 3))
+        self.assertTrue(torch.all(transf_cat_rots[:, :n] == transf_rots))
+        self.assertTrue(
+            torch.all(transf_cat.get_trans()[:, :n] == transf.get_trans())
+        )

-    def test_T_compose(self):
+    def test_rigid_compose(self):
         trans_1 = [0, 1, 0]
         trans_2 = [0, 0, 1]
+        r = Rotation(rot_mats=X_90_ROT)
+        t = torch.tensor(trans_1)
-        t1 = T(X_90_ROT, torch.tensor(trans_1))
-        t2 = T(X_NEG_90_ROT, torch.tensor(trans_2))
+        t1 = Rigid(Rotation(rot_mats=X_90_ROT), torch.tensor(trans_1))
+        t2 = Rigid(Rotation(rot_mats=X_NEG_90_ROT), torch.tensor(trans_2))
         t3 = t1.compose(t2)
-        self.assertTrue(torch.all(t3.rots == torch.eye(3)))
-        self.assertTrue(torch.all(t3.trans == 0))
+        self.assertTrue(
+            torch.all(t3.get_rots().get_rot_mats() == torch.eye(3))
+        )
+        self.assertTrue(torch.all(t3.get_trans() == 0))

-    def test_T_apply(self):
+    def test_rigid_apply(self):
         rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0)
         trans = torch.tensor([1, 1, 1])
         trans = torch.stack([trans, trans], dim=0)
-        t = T(rots, trans)
+        t = Rigid(Rotation(rot_mats=rots), trans)
         x = torch.arange(30)
         x = torch.stack([x, x], dim=0)
...
@@ -141,6 +175,12 @@ class TestUtils(unittest.TestCase):
         eps = 1e-07
         self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps))

+    def test_rot_to_quat(self):
+        quat = rot_to_quat(X_90_ROT)
+        eps = 1e-07
+        ans = torch.tensor([math.sqrt(0.5), math.sqrt(0.5), 0., 0.])
+        self.assertTrue(torch.all(torch.abs(quat - ans) < eps))

     def test_chunk_layer_tensor(self):
         x = torch.rand(2, 4, 5, 15)
         l = torch.nn.Linear(15, 30)
...
@@ -180,3 +220,33 @@ class TestUtils(unittest.TestCase):
             chunked_flattened = x_flat[i:j]
             self.assertTrue(torch.all(chunked == chunked_flattened))

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_pre_compose_compare(self):
+        quat = np.random.rand(20, 4)
+        trans = [np.random.rand(20) for _ in range(3)]
+        quat_affine = alphafold.model.quat_affine.QuatAffine(
+            quat, translation=trans
+        )
+        update_vec = np.random.rand(20, 6)
+        new_gt = quat_affine.pre_compose(update_vec)
+
+        quat_t = torch.tensor(quat)
+        trans_t = torch.stack([torch.tensor(t) for t in trans], dim=-1)
+        rigid = Rigid(Rotation(quats=quat_t), trans_t)
+        new_repro = rigid.compose_q_update_vec(torch.tensor(update_vec))
+
+        new_gt_q = torch.tensor(np.array(new_gt.quaternion))
+        new_gt_t = torch.stack(
+            [torch.tensor(np.array(t)) for t in new_gt.translation], dim=-1
+        )
+        new_repro_q = new_repro.get_rots().get_quats()
+        new_repro_t = new_repro.get_trans()
+
+        self.assertTrue(torch.max(torch.abs(new_gt_q - new_repro_q)) < consts.eps)
+        self.assertTrue(torch.max(torch.abs(new_gt_t - new_repro_t)) < consts.eps)
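The new comparison test above exercises the same quaternion-update path AlphaFold's QuatAffine.pre_compose uses. A minimal sketch of that path on its own, assuming the compose_q_update_vec signature shown in this commit: a [*, 6] update vector (three unnormalized quaternion components followed by three translation components) composed onto an existing frame.

import torch
from openfold.utils.rigid_utils import Rotation, Rigid

rigid = Rigid(
    Rotation(quats=torch.tensor([[1.0, 0.0, 0.0, 0.0]])),  # identity frame
    torch.zeros((1, 3)),
)
update = torch.zeros((1, 6))  # the zero update should leave the frame unchanged
updated = rigid.compose_q_update_vec(update)
print(updated.get_trans())    # tensor([[0., 0., 0.]])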