Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
e3daf724
Commit
e3daf724
authored
Jan 03, 2022
by
Gustaf Ahdritz
Browse files
Overhaul transformation code for better parity w/ AlphaFold
parent
1f709b0d
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
1669 additions
and
178 deletions
+1669
-178
openfold/config.py
openfold/config.py
+4
-4
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+18
-12
openfold/data/mmcif_parsing.py
openfold/data/mmcif_parsing.py
+4
-2
openfold/data/templates.py
openfold/data/templates.py
+14
-3
openfold/model/embedders.py
openfold/model/embedders.py
+1
-0
openfold/model/structure_module.py
openfold/model/structure_module.py
+59
-51
openfold/utils/feats.py
openfold/utils/feats.py
+18
-18
openfold/utils/loss.py
openfold/utils/loss.py
+35
-22
openfold/utils/rigid_utils.py
openfold/utils/rigid_utils.py
+1380
-0
tests/test_data/alphafold/common/stereo_chemical_props.txt
tests/test_data/alphafold/common/stereo_chemical_props.txt
+1
-0
tests/test_feats.py
tests/test_feats.py
+10
-6
tests/test_loss.py
tests/test_loss.py
+18
-10
tests/test_model.py
tests/test_model.py
+2
-1
tests/test_structure_module.py
tests/test_structure_module.py
+9
-23
tests/test_utils.py
tests/test_utils.py
+96
-26
No files found.
openfold/config.py
View file @
e3daf724
...
@@ -89,8 +89,8 @@ config = mlc.ConfigDict(
...
@@ -89,8 +89,8 @@ config = mlc.ConfigDict(
"atom14_gt_exists"
:
[
NUM_RES
,
None
],
"atom14_gt_exists"
:
[
NUM_RES
,
None
],
"atom14_gt_positions"
:
[
NUM_RES
,
None
,
None
],
"atom14_gt_positions"
:
[
NUM_RES
,
None
,
None
],
"atom37_atom_exists"
:
[
NUM_RES
,
None
],
"atom37_atom_exists"
:
[
NUM_RES
,
None
],
"backbone_
affine
_mask"
:
[
NUM_RES
],
"backbone_
rigid
_mask"
:
[
NUM_RES
],
"backbone_
affine
_tensor"
:
[
NUM_RES
,
None
,
None
],
"backbone_
rigid
_tensor"
:
[
NUM_RES
,
None
,
None
],
"bert_mask"
:
[
NUM_MSA_SEQ
,
NUM_RES
],
"bert_mask"
:
[
NUM_MSA_SEQ
,
NUM_RES
],
"chi_angles_sin_cos"
:
[
NUM_RES
,
None
,
None
],
"chi_angles_sin_cos"
:
[
NUM_RES
,
None
,
None
],
"chi_mask"
:
[
NUM_RES
,
None
],
"chi_mask"
:
[
NUM_RES
,
None
],
...
@@ -126,8 +126,8 @@ config = mlc.ConfigDict(
...
@@ -126,8 +126,8 @@ config = mlc.ConfigDict(
"template_alt_torsion_angles_sin_cos"
:
[
"template_alt_torsion_angles_sin_cos"
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
,
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
,
],
],
"template_backbone_
affine
_mask"
:
[
NUM_TEMPLATES
,
NUM_RES
],
"template_backbone_
rigid
_mask"
:
[
NUM_TEMPLATES
,
NUM_RES
],
"template_backbone_
affine
_tensor"
:
[
"template_backbone_
rigid
_tensor"
:
[
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
,
NUM_TEMPLATES
,
NUM_RES
,
None
,
None
,
],
],
"template_mask"
:
[
NUM_TEMPLATES
],
"template_mask"
:
[
NUM_TEMPLATES
],
...
...
openfold/data/data_transforms.py
View file @
e3daf724
...
@@ -22,7 +22,7 @@ import torch
...
@@ -22,7 +22,7 @@ import torch
from
openfold.config
import
NUM_RES
,
NUM_EXTRA_SEQ
,
NUM_TEMPLATES
,
NUM_MSA_SEQ
from
openfold.config
import
NUM_RES
,
NUM_EXTRA_SEQ
,
NUM_TEMPLATES
,
NUM_MSA_SEQ
from
openfold.np
import
residue_constants
as
rc
from
openfold.np
import
residue_constants
as
rc
from
openfold.utils.
affine
_utils
import
T
from
openfold.utils.
rigid
_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
tree_map
,
tree_map
,
tensor_tree_map
,
tensor_tree_map
,
...
@@ -752,7 +752,7 @@ def make_atom14_positions(protein):
...
@@ -752,7 +752,7 @@ def make_atom14_positions(protein):
return
protein
return
protein
def
atom37_to_frames
(
protein
):
def
atom37_to_frames
(
protein
,
eps
=
1e-8
):
aatype
=
protein
[
"aatype"
]
aatype
=
protein
[
"aatype"
]
all_atom_positions
=
protein
[
"all_atom_positions"
]
all_atom_positions
=
protein
[
"all_atom_positions"
]
all_atom_mask
=
protein
[
"all_atom_mask"
]
all_atom_mask
=
protein
[
"all_atom_mask"
]
...
@@ -810,11 +810,11 @@ def atom37_to_frames(protein):
...
@@ -810,11 +810,11 @@ def atom37_to_frames(protein):
no_batch_dims
=
len
(
all_atom_positions
.
shape
[:
-
2
]),
no_batch_dims
=
len
(
all_atom_positions
.
shape
[:
-
2
]),
)
)
gt_frames
=
T
.
from_3_points
(
gt_frames
=
Rigid
.
from_3_points
(
p_neg_x_axis
=
base_atom_pos
[...,
0
,
:],
p_neg_x_axis
=
base_atom_pos
[...,
0
,
:],
origin
=
base_atom_pos
[...,
1
,
:],
origin
=
base_atom_pos
[...,
1
,
:],
p_xy_plane
=
base_atom_pos
[...,
2
,
:],
p_xy_plane
=
base_atom_pos
[...,
2
,
:],
eps
=
1e-8
,
eps
=
eps
,
)
)
group_exists
=
batched_gather
(
group_exists
=
batched_gather
(
...
@@ -836,8 +836,9 @@ def atom37_to_frames(protein):
...
@@ -836,8 +836,9 @@ def atom37_to_frames(protein):
rots
=
torch
.
tile
(
rots
,
(
*
((
1
,)
*
batch_dims
),
8
,
1
,
1
))
rots
=
torch
.
tile
(
rots
,
(
*
((
1
,)
*
batch_dims
),
8
,
1
,
1
))
rots
[...,
0
,
0
,
0
]
=
-
1
rots
[...,
0
,
0
,
0
]
=
-
1
rots
[...,
0
,
2
,
2
]
=
-
1
rots
[...,
0
,
2
,
2
]
=
-
1
rots
=
Rotation
(
rot_mats
=
rots
)
gt_frames
=
gt_frames
.
compose
(
T
(
rots
,
None
))
gt_frames
=
gt_frames
.
compose
(
Rigid
(
rots
,
None
))
restype_rigidgroup_is_ambiguous
=
all_atom_mask
.
new_zeros
(
restype_rigidgroup_is_ambiguous
=
all_atom_mask
.
new_zeros
(
*
((
1
,)
*
batch_dims
),
21
,
8
*
((
1
,)
*
batch_dims
),
21
,
8
...
@@ -871,10 +872,15 @@ def atom37_to_frames(protein):
...
@@ -871,10 +872,15 @@ def atom37_to_frames(protein):
no_batch_dims
=
batch_dims
,
no_batch_dims
=
batch_dims
,
)
)
alt_gt_frames
=
gt_frames
.
compose
(
T
(
residx_rigidgroup_ambiguity_rot
,
None
))
residx_rigidgroup_ambiguity_rot
=
Rotation
(
rot_mats
=
residx_rigidgroup_ambiguity_rot
)
alt_gt_frames
=
gt_frames
.
compose
(
Rigid
(
residx_rigidgroup_ambiguity_rot
,
None
)
)
gt_frames_tensor
=
gt_frames
.
to_4x4
()
gt_frames_tensor
=
gt_frames
.
to_
tensor_
4x4
()
alt_gt_frames_tensor
=
alt_gt_frames
.
to_4x4
()
alt_gt_frames_tensor
=
alt_gt_frames
.
to_
tensor_
4x4
()
protein
[
"rigidgroups_gt_frames"
]
=
gt_frames_tensor
protein
[
"rigidgroups_gt_frames"
]
=
gt_frames_tensor
protein
[
"rigidgroups_gt_exists"
]
=
gt_exists
protein
[
"rigidgroups_gt_exists"
]
=
gt_exists
...
@@ -1028,7 +1034,7 @@ def atom37_to_torsion_angles(
...
@@ -1028,7 +1034,7 @@ def atom37_to_torsion_angles(
dim
=-
1
,
dim
=-
1
,
)
)
torsion_frames
=
T
.
from_3_points
(
torsion_frames
=
Rigid
.
from_3_points
(
torsions_atom_pos
[...,
1
,
:],
torsions_atom_pos
[...,
1
,
:],
torsions_atom_pos
[...,
2
,
:],
torsions_atom_pos
[...,
2
,
:],
torsions_atom_pos
[...,
0
,
:],
torsions_atom_pos
[...,
0
,
:],
...
@@ -1082,11 +1088,11 @@ def atom37_to_torsion_angles(
...
@@ -1082,11 +1088,11 @@ def atom37_to_torsion_angles(
def
get_backbone_frames
(
protein
):
def
get_backbone_frames
(
protein
):
#
TODO: Verify that this is correct
#
DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
protein
[
"backbone_
affine
_tensor"
]
=
protein
[
"rigidgroups_gt_frames"
][
protein
[
"backbone_
rigid
_tensor"
]
=
protein
[
"rigidgroups_gt_frames"
][
...,
0
,
:,
:
...,
0
,
:,
:
]
]
protein
[
"backbone_
affine
_mask"
]
=
protein
[
"rigidgroups_gt_exists"
][...,
0
]
protein
[
"backbone_
rigid
_mask"
]
=
protein
[
"rigidgroups_gt_exists"
][...,
0
]
return
protein
return
protein
...
...
openfold/data/mmcif_parsing.py
View file @
e3daf724
...
@@ -430,7 +430,9 @@ def _is_set(data: str) -> bool:
...
@@ -430,7 +430,9 @@ def _is_set(data: str) -> bool:
def
get_atom_coords
(
def
get_atom_coords
(
mmcif_object
:
MmcifObject
,
chain_id
:
str
,
zero_center
:
bool
=
True
mmcif_object
:
MmcifObject
,
chain_id
:
str
,
_zero_center_positions
:
bool
=
True
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
# Locate the right chain
# Locate the right chain
chains
=
list
(
mmcif_object
.
structure
.
get_chains
())
chains
=
list
(
mmcif_object
.
structure
.
get_chains
())
...
@@ -475,7 +477,7 @@ def get_atom_coords(
...
@@ -475,7 +477,7 @@ def get_atom_coords(
all_atom_positions
[
res_index
]
=
pos
all_atom_positions
[
res_index
]
=
pos
all_atom_mask
[
res_index
]
=
mask
all_atom_mask
[
res_index
]
=
mask
if
zero_center
:
if
_
zero_center
_positions
:
binary_mask
=
all_atom_mask
.
astype
(
bool
)
binary_mask
=
all_atom_mask
.
astype
(
bool
)
translation_vec
=
all_atom_positions
[
binary_mask
].
mean
(
axis
=
0
)
translation_vec
=
all_atom_positions
[
binary_mask
].
mean
(
axis
=
0
)
all_atom_positions
[
binary_mask
]
-=
translation_vec
all_atom_positions
[
binary_mask
]
-=
translation_vec
...
...
openfold/data/templates.py
View file @
e3daf724
...
@@ -503,10 +503,13 @@ def _get_atom_positions(
...
@@ -503,10 +503,13 @@ def _get_atom_positions(
mmcif_object
:
mmcif_parsing
.
MmcifObject
,
mmcif_object
:
mmcif_parsing
.
MmcifObject
,
auth_chain_id
:
str
,
auth_chain_id
:
str
,
max_ca_ca_distance
:
float
,
max_ca_ca_distance
:
float
,
_zero_center_positions
:
bool
=
True
,
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""Gets atom positions and mask from a list of Biopython Residues."""
"""Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask
=
mmcif_parsing
.
get_atom_coords
(
coords_with_mask
=
mmcif_parsing
.
get_atom_coords
(
mmcif_object
=
mmcif_object
,
chain_id
=
auth_chain_id
mmcif_object
=
mmcif_object
,
chain_id
=
auth_chain_id
,
_zero_center_positions
=
_zero_center_positions
,
)
)
all_atom_positions
,
all_atom_mask
=
coords_with_mask
all_atom_positions
,
all_atom_mask
=
coords_with_mask
_check_residue_distances
(
_check_residue_distances
(
...
@@ -523,6 +526,7 @@ def _extract_template_features(
...
@@ -523,6 +526,7 @@ def _extract_template_features(
query_sequence
:
str
,
query_sequence
:
str
,
template_chain_id
:
str
,
template_chain_id
:
str
,
kalign_binary_path
:
str
,
kalign_binary_path
:
str
,
_zero_center_positions
:
bool
=
True
,
)
->
Tuple
[
Dict
[
str
,
Any
],
Optional
[
str
]]:
)
->
Tuple
[
Dict
[
str
,
Any
],
Optional
[
str
]]:
"""Parses atom positions in the target structure and aligns with the query.
"""Parses atom positions in the target structure and aligns with the query.
...
@@ -607,7 +611,10 @@ def _extract_template_features(
...
@@ -607,7 +611,10 @@ def _extract_template_features(
# Essentially set to infinity - we don't want to reject templates unless
# Essentially set to infinity - we don't want to reject templates unless
# they're really really bad.
# they're really really bad.
all_atom_positions
,
all_atom_mask
=
_get_atom_positions
(
all_atom_positions
,
all_atom_mask
=
_get_atom_positions
(
mmcif_object
,
chain_id
,
max_ca_ca_distance
=
150.0
mmcif_object
,
chain_id
,
max_ca_ca_distance
=
150.0
,
_zero_center_positions
=
_zero_center_positions
,
)
)
except
(
CaDistanceError
,
KeyError
)
as
ex
:
except
(
CaDistanceError
,
KeyError
)
as
ex
:
raise
NoAtomDataInTemplateError
(
raise
NoAtomDataInTemplateError
(
...
@@ -795,6 +802,7 @@ def _process_single_hit(
...
@@ -795,6 +802,7 @@ def _process_single_hit(
obsolete_pdbs
:
Mapping
[
str
,
str
],
obsolete_pdbs
:
Mapping
[
str
,
str
],
kalign_binary_path
:
str
,
kalign_binary_path
:
str
,
strict_error_check
:
bool
=
False
,
strict_error_check
:
bool
=
False
,
_zero_center_positions
:
bool
=
True
,
)
->
SingleHitResult
:
)
->
SingleHitResult
:
"""Tries to extract template features from a single HHSearch hit."""
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
# Fail hard if we can't get the PDB ID and chain name from the hit.
...
@@ -856,6 +864,7 @@ def _process_single_hit(
...
@@ -856,6 +864,7 @@ def _process_single_hit(
query_sequence
=
query_sequence
,
query_sequence
=
query_sequence
,
template_chain_id
=
hit_chain_id
,
template_chain_id
=
hit_chain_id
,
kalign_binary_path
=
kalign_binary_path
,
kalign_binary_path
=
kalign_binary_path
,
_zero_center_positions
=
_zero_center_positions
,
)
)
features
[
"template_sum_probs"
]
=
[
hit
.
sum_probs
]
features
[
"template_sum_probs"
]
=
[
hit
.
sum_probs
]
...
@@ -913,7 +922,6 @@ class TemplateSearchResult:
...
@@ -913,7 +922,6 @@ class TemplateSearchResult:
class
TemplateHitFeaturizer
:
class
TemplateHitFeaturizer
:
"""A class for turning hhr hits to template features."""
"""A class for turning hhr hits to template features."""
def
__init__
(
def
__init__
(
self
,
self
,
mmcif_dir
:
str
,
mmcif_dir
:
str
,
...
@@ -924,6 +932,7 @@ class TemplateHitFeaturizer:
...
@@ -924,6 +932,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs_path
:
Optional
[
str
]
=
None
,
obsolete_pdbs_path
:
Optional
[
str
]
=
None
,
strict_error_check
:
bool
=
False
,
strict_error_check
:
bool
=
False
,
_shuffle_top_k_prefiltered
:
Optional
[
int
]
=
None
,
_shuffle_top_k_prefiltered
:
Optional
[
int
]
=
None
,
_zero_center_positions
:
bool
=
True
,
):
):
"""Initializes the Template Search.
"""Initializes the Template Search.
...
@@ -982,6 +991,7 @@ class TemplateHitFeaturizer:
...
@@ -982,6 +991,7 @@ class TemplateHitFeaturizer:
self
.
_obsolete_pdbs
=
{}
self
.
_obsolete_pdbs
=
{}
self
.
_shuffle_top_k_prefiltered
=
_shuffle_top_k_prefiltered
self
.
_shuffle_top_k_prefiltered
=
_shuffle_top_k_prefiltered
self
.
_zero_center_positions
=
_zero_center_positions
def
get_templates
(
def
get_templates
(
self
,
self
,
...
@@ -1057,6 +1067,7 @@ class TemplateHitFeaturizer:
...
@@ -1057,6 +1067,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
strict_error_check
=
self
.
_strict_error_check
,
kalign_binary_path
=
self
.
_kalign_binary_path
,
kalign_binary_path
=
self
.
_kalign_binary_path
,
_zero_center_positions
=
self
.
_zero_center_positions
,
)
)
if
result
.
error
:
if
result
.
error
:
...
...
openfold/model/embedders.py
View file @
e3daf724
...
@@ -198,6 +198,7 @@ class RecyclingEmbedder(nn.Module):
...
@@ -198,6 +198,7 @@ class RecyclingEmbedder(nn.Module):
self
.
no_bins
,
self
.
no_bins
,
dtype
=
x
.
dtype
,
dtype
=
x
.
dtype
,
device
=
x
.
device
,
device
=
x
.
device
,
requires_grad
=
False
,
)
)
# [*, N, C_m]
# [*, N, C_m]
...
...
openfold/model/structure_module.py
View file @
e3daf724
...
@@ -25,11 +25,11 @@ from openfold.np.residue_constants import (
...
@@ -25,11 +25,11 @@ from openfold.np.residue_constants import (
restype_atom14_mask
,
restype_atom14_mask
,
restype_atom14_rigid_group_positions
,
restype_atom14_rigid_group_positions
,
)
)
from
openfold.utils.affine_utils
import
T
,
quat_to_rot
from
openfold.utils.feats
import
(
from
openfold.utils.feats
import
(
frames_and_literature_positions_to_atom14_pos
,
frames_and_literature_positions_to_atom14_pos
,
torsion_angles_to_frames
,
torsion_angles_to_frames
,
)
)
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
dict_multimap
,
dict_multimap
,
permute_final_dims
,
permute_final_dims
,
...
@@ -225,7 +225,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -225,7 +225,7 @@ class InvariantPointAttention(nn.Module):
self
,
self
,
s
:
torch
.
Tensor
,
s
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
t
:
T
,
r
:
Rigid
,
mask
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
...
@@ -234,8 +234,8 @@ class InvariantPointAttention(nn.Module):
...
@@ -234,8 +234,8 @@ class InvariantPointAttention(nn.Module):
[*, N_res, C_s] single representation
[*, N_res, C_s] single representation
z:
z:
[*, N_res, N_res, C_z] pair representation
[*, N_res, N_res, C_z] pair representation
t
:
r
:
[*, N_res]
affine
transformation object
[*, N_res] transformation object
mask:
mask:
[*, N_res] mask
[*, N_res] mask
Returns:
Returns:
...
@@ -264,7 +264,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -264,7 +264,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_q, 3]
# [*, N_res, H * P_q, 3]
q_pts
=
torch
.
split
(
q_pts
,
q_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
q_pts
=
torch
.
split
(
q_pts
,
q_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
q_pts
=
torch
.
stack
(
q_pts
,
dim
=-
1
)
q_pts
=
torch
.
stack
(
q_pts
,
dim
=-
1
)
q_pts
=
t
[...,
None
].
apply
(
q_pts
)
q_pts
=
r
[...,
None
].
apply
(
q_pts
)
# [*, N_res, H, P_q, 3]
# [*, N_res, H, P_q, 3]
q_pts
=
q_pts
.
view
(
q_pts
=
q_pts
.
view
(
...
@@ -277,7 +277,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -277,7 +277,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * (P_q + P_v), 3]
# [*, N_res, H * (P_q + P_v), 3]
kv_pts
=
torch
.
split
(
kv_pts
,
kv_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
kv_pts
=
torch
.
split
(
kv_pts
,
kv_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
kv_pts
=
torch
.
stack
(
kv_pts
,
dim
=-
1
)
kv_pts
=
torch
.
stack
(
kv_pts
,
dim
=-
1
)
kv_pts
=
t
[...,
None
].
apply
(
kv_pts
)
kv_pts
=
r
[...,
None
].
apply
(
kv_pts
)
# [*, N_res, H, (P_q + P_v), 3]
# [*, N_res, H, (P_q + P_v), 3]
kv_pts
=
kv_pts
.
view
(
kv_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
-
1
,
3
))
kv_pts
=
kv_pts
.
view
(
kv_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
-
1
,
3
))
...
@@ -349,7 +349,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -349,7 +349,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, P_v, 3]
# [*, N_res, H, P_v, 3]
o_pt
=
permute_final_dims
(
o_pt
,
(
2
,
0
,
3
,
1
))
o_pt
=
permute_final_dims
(
o_pt
,
(
2
,
0
,
3
,
1
))
o_pt
=
t
[...,
None
,
None
].
invert_apply
(
o_pt
)
o_pt
=
r
[...,
None
,
None
].
invert_apply
(
o_pt
)
# [*, N_res, H * P_v]
# [*, N_res, H * P_v]
o_pt_norm
=
flatten_final_dims
(
o_pt_norm
=
flatten_final_dims
(
...
@@ -377,7 +377,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -377,7 +377,7 @@ class InvariantPointAttention(nn.Module):
class
BackboneUpdate
(
nn
.
Module
):
class
BackboneUpdate
(
nn
.
Module
):
"""
"""
Implements Algorithm 23.
Implements
part of
Algorithm 23.
"""
"""
def
__init__
(
self
,
c_s
):
def
__init__
(
self
,
c_s
):
...
@@ -392,36 +392,17 @@ class BackboneUpdate(nn.Module):
...
@@ -392,36 +392,17 @@ class BackboneUpdate(nn.Module):
self
.
linear
=
Linear
(
self
.
c_s
,
6
,
init
=
"final"
)
self
.
linear
=
Linear
(
self
.
c_s
,
6
,
init
=
"final"
)
def
forward
(
self
,
s
)
:
def
forward
(
self
,
s
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
"""
"""
Args:
Args:
[*, N_res, C_s] single representation
[*, N_res, C_s] single representation
Returns:
Returns:
[*, N_res
] affine transformation object
[*, N_res
, 6] update vector
"""
"""
# [*, 6]
# [*, 6]
params
=
self
.
linear
(
s
)
update
=
self
.
linear
(
s
)
# [*, 3]
quats
,
trans
=
params
[...,
:
3
],
params
[...,
3
:]
# [*]
# norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
norm_denom
=
torch
.
sqrt
(
torch
.
sum
(
quats
**
2
,
dim
=-
1
)
+
1
)
# [*, 3]
ones
=
s
.
new_ones
((
1
,)
*
len
(
quats
.
shape
)).
expand
(
quats
.
shape
[:
-
1
]
+
(
1
,)
)
# [*, 4]
quats
=
torch
.
cat
([
ones
,
quats
],
dim
=-
1
)
quats
=
quats
/
norm_denom
[...,
None
]
# [*, 3, 3]
return
update
rots
=
quat_to_rot
(
quats
)
return
T
(
rots
,
trans
)
class
StructureModuleTransitionLayer
(
nn
.
Module
):
class
StructureModuleTransitionLayer
(
nn
.
Module
):
...
@@ -592,7 +573,7 @@ class StructureModule(nn.Module):
...
@@ -592,7 +573,7 @@ class StructureModule(nn.Module):
self
,
self
,
s
,
s
,
z
,
z
,
f
,
aatype
,
mask
=
None
,
mask
=
None
,
):
):
"""
"""
...
@@ -601,7 +582,7 @@ class StructureModule(nn.Module):
...
@@ -601,7 +582,7 @@ class StructureModule(nn.Module):
[*, N_res, C_s] single representation
[*, N_res, C_s] single representation
z:
z:
[*, N_res, N_res, C_z] pair representation
[*, N_res, N_res, C_z] pair representation
f
:
aatype
:
[*, N_res] amino acid indices
[*, N_res] amino acid indices
mask:
mask:
Optional [*, N_res] sequence mask
Optional [*, N_res] sequence mask
...
@@ -623,44 +604,67 @@ class StructureModule(nn.Module):
...
@@ -623,44 +604,67 @@ class StructureModule(nn.Module):
s
=
self
.
linear_in
(
s
)
s
=
self
.
linear_in
(
s
)
# [*, N]
# [*, N]
t
=
T
.
identity
(
s
.
shape
[:
-
1
],
s
.
dtype
,
s
.
device
,
self
.
training
)
rigids
=
Rigid
.
identity
(
s
.
shape
[:
-
1
],
s
.
dtype
,
s
.
device
,
self
.
training
,
fmt
=
"quat"
,
)
outputs
=
[]
outputs
=
[]
for
i
in
range
(
self
.
no_blocks
):
for
i
in
range
(
self
.
no_blocks
):
# [*, N, C_s]
# [*, N, C_s]
s
=
s
+
self
.
ipa
(
s
,
z
,
t
,
mask
)
s
=
s
+
self
.
ipa
(
s
,
z
,
rigids
,
mask
)
s
=
self
.
ipa_dropout
(
s
)
s
=
self
.
ipa_dropout
(
s
)
s
=
self
.
layer_norm_ipa
(
s
)
s
=
self
.
layer_norm_ipa
(
s
)
s
=
self
.
transition
(
s
)
s
=
self
.
transition
(
s
)
# [*, N]
# [*, N]
t
=
t
.
compose
(
self
.
bb_update
(
s
))
rigids
=
rigids
.
compose_q_update_vec
(
self
.
bb_update
(
s
))
# To hew as closely as possible to AlphaFold, we convert our
# quaternion-based transformations to rotation-matrix ones
# here
backb_to_global
=
Rigid
(
Rotation
(
rot_mats
=
rigids
.
get_rots
().
get_rot_mats
(),
quats
=
None
),
rigids
.
get_trans
(),
)
backb_to_global
=
backb_to_global
.
scale_translation
(
self
.
trans_scale_factor
)
# [*, N, 7, 2]
# [*, N, 7, 2]
unnormalized_a
,
a
=
self
.
angle_resnet
(
s
,
s_initial
)
unnormalized_a
ngles
,
angles
=
self
.
angle_resnet
(
s
,
s_initial
)
all_frames_to_global
=
self
.
torsion_angles_to_frames
(
all_frames_to_global
=
self
.
torsion_angles_to_frames
(
t
.
scale_translation
(
self
.
trans_scale_factor
)
,
backb_to_global
,
a
,
a
ngles
,
f
,
aatype
,
)
)
pred_xyz
=
self
.
frames_and_literature_positions_to_atom14_pos
(
pred_xyz
=
self
.
frames_and_literature_positions_to_atom14_pos
(
all_frames_to_global
,
all_frames_to_global
,
f
,
aatype
,
)
)
scaled_rigids
=
rigids
.
scale_translation
(
self
.
trans_scale_factor
)
preds
=
{
preds
=
{
"frames"
:
t
.
scale
_translation
(
self
.
trans_scale_factor
).
to_4x4
(),
"frames"
:
scale
d_rigids
.
to_tensor_7
(),
"sidechain_frames"
:
all_frames_to_global
.
to_4x4
(),
"sidechain_frames"
:
all_frames_to_global
.
to_
tensor_
4x4
(),
"unnormalized_angles"
:
unnormalized_a
,
"unnormalized_angles"
:
unnormalized_a
ngles
,
"angles"
:
a
,
"angles"
:
a
ngles
,
"positions"
:
pred_xyz
,
"positions"
:
pred_xyz
,
}
}
outputs
.
append
(
preds
)
outputs
.
append
(
preds
)
if
i
<
(
self
.
no_blocks
-
1
):
if
i
<
(
self
.
no_blocks
-
1
):
t
=
t
.
stop_rot_gradient
()
rigids
=
rigids
.
stop_rot_gradient
()
outputs
=
dict_multimap
(
torch
.
stack
,
outputs
)
outputs
=
dict_multimap
(
torch
.
stack
,
outputs
)
outputs
[
"single"
]
=
s
outputs
[
"single"
]
=
s
...
@@ -673,38 +677,42 @@ class StructureModule(nn.Module):
...
@@ -673,38 +677,42 @@ class StructureModule(nn.Module):
restype_rigid_group_default_frame
,
restype_rigid_group_default_frame
,
dtype
=
float_dtype
,
dtype
=
float_dtype
,
device
=
device
,
device
=
device
,
requires_grad
=
False
,
)
)
if
self
.
group_idx
is
None
:
if
self
.
group_idx
is
None
:
self
.
group_idx
=
torch
.
tensor
(
self
.
group_idx
=
torch
.
tensor
(
restype_atom14_to_rigid_group
,
restype_atom14_to_rigid_group
,
device
=
device
,
device
=
device
,
requires_grad
=
False
,
)
)
if
self
.
atom_mask
is
None
:
if
self
.
atom_mask
is
None
:
self
.
atom_mask
=
torch
.
tensor
(
self
.
atom_mask
=
torch
.
tensor
(
restype_atom14_mask
,
restype_atom14_mask
,
dtype
=
float_dtype
,
dtype
=
float_dtype
,
device
=
device
,
device
=
device
,
requires_grad
=
False
,
)
)
if
self
.
lit_positions
is
None
:
if
self
.
lit_positions
is
None
:
self
.
lit_positions
=
torch
.
tensor
(
self
.
lit_positions
=
torch
.
tensor
(
restype_atom14_rigid_group_positions
,
restype_atom14_rigid_group_positions
,
dtype
=
float_dtype
,
dtype
=
float_dtype
,
device
=
device
,
device
=
device
,
requires_grad
=
False
,
)
)
def
torsion_angles_to_frames
(
self
,
t
,
alpha
,
f
):
def
torsion_angles_to_frames
(
self
,
r
,
alpha
,
f
):
# Lazily initialize the residue constants on the correct device
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
alpha
.
dtype
,
alpha
.
device
)
self
.
_init_residue_constants
(
alpha
.
dtype
,
alpha
.
device
)
# Separated purely to make testing less annoying
# Separated purely to make testing less annoying
return
torsion_angles_to_frames
(
t
,
alpha
,
f
,
self
.
default_frames
)
return
torsion_angles_to_frames
(
r
,
alpha
,
f
,
self
.
default_frames
)
def
frames_and_literature_positions_to_atom14_pos
(
def
frames_and_literature_positions_to_atom14_pos
(
self
,
t
,
f
# [*, N, 8] # [*, N]
self
,
r
,
f
# [*, N, 8] # [*, N]
):
):
# Lazily initialize the residue constants on the correct device
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
t
.
rots
.
dtype
,
t
.
rots
.
device
)
self
.
_init_residue_constants
(
r
.
get_
rots
()
.
dtype
,
r
.
get_
rots
()
.
device
)
return
frames_and_literature_positions_to_atom14_pos
(
return
frames_and_literature_positions_to_atom14_pos
(
t
,
r
,
f
,
f
,
self
.
default_frames
,
self
.
default_frames
,
self
.
group_idx
,
self
.
group_idx
,
...
...
openfold/utils/feats.py
View file @
e3daf724
...
@@ -22,7 +22,7 @@ from typing import Dict
...
@@ -22,7 +22,7 @@ from typing import Dict
from
openfold.np
import
protein
from
openfold.np
import
protein
import
openfold.np.residue_constants
as
rc
import
openfold.np.residue_constants
as
rc
from
openfold.utils.
affine
_utils
import
T
from
openfold.utils.
rigid
_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
batched_gather
,
batched_gather
,
one_hot
,
one_hot
,
...
@@ -124,18 +124,16 @@ def build_template_pair_feat(
...
@@ -124,18 +124,16 @@ def build_template_pair_feat(
)
)
n
,
ca
,
c
=
[
rc
.
atom_order
[
a
]
for
a
in
[
"N"
,
"CA"
,
"C"
]]
n
,
ca
,
c
=
[
rc
.
atom_order
[
a
]
for
a
in
[
"N"
,
"CA"
,
"C"
]]
# TODO: Consider running this in double precision
rigids
=
Rigid
.
make_transform_from_reference
(
affines
=
T
.
make_transform_from_reference
(
n_xyz
=
batch
[
"template_all_atom_positions"
][...,
n
,
:],
n_xyz
=
batch
[
"template_all_atom_positions"
][...,
n
,
:],
ca_xyz
=
batch
[
"template_all_atom_positions"
][...,
ca
,
:],
ca_xyz
=
batch
[
"template_all_atom_positions"
][...,
ca
,
:],
c_xyz
=
batch
[
"template_all_atom_positions"
][...,
c
,
:],
c_xyz
=
batch
[
"template_all_atom_positions"
][...,
c
,
:],
eps
=
eps
,
eps
=
eps
,
)
)
points
=
rigids
.
get_trans
()[...,
None
,
:,
:]
rigid_vec
=
rigids
[...,
None
].
invert_apply
(
points
)
points
=
affines
.
get_trans
()[...,
None
,
:,
:]
inv_distance_scalar
=
torch
.
rsqrt
(
eps
+
torch
.
sum
(
rigid_vec
**
2
,
dim
=-
1
))
affine_vec
=
affines
[...,
None
].
invert_apply
(
points
)
inv_distance_scalar
=
torch
.
rsqrt
(
eps
+
torch
.
sum
(
affine_vec
**
2
,
dim
=-
1
))
t_aa_masks
=
batch
[
"template_all_atom_mask"
]
t_aa_masks
=
batch
[
"template_all_atom_mask"
]
template_mask
=
(
template_mask
=
(
...
@@ -144,7 +142,7 @@ def build_template_pair_feat(
...
@@ -144,7 +142,7 @@ def build_template_pair_feat(
template_mask_2d
=
template_mask
[...,
None
]
*
template_mask
[...,
None
,
:]
template_mask_2d
=
template_mask
[...,
None
]
*
template_mask
[...,
None
,
:]
inv_distance_scalar
=
inv_distance_scalar
*
template_mask_2d
inv_distance_scalar
=
inv_distance_scalar
*
template_mask_2d
unit_vector
=
affine
_vec
*
inv_distance_scalar
[...,
None
]
unit_vector
=
rigid
_vec
*
inv_distance_scalar
[...,
None
]
to_concat
.
extend
(
torch
.
unbind
(
unit_vector
[...,
None
,
:],
dim
=-
1
))
to_concat
.
extend
(
torch
.
unbind
(
unit_vector
[...,
None
,
:],
dim
=-
1
))
to_concat
.
append
(
template_mask_2d
[...,
None
])
to_concat
.
append
(
template_mask_2d
[...,
None
])
...
@@ -165,7 +163,7 @@ def build_extra_msa_feat(batch):
...
@@ -165,7 +163,7 @@ def build_extra_msa_feat(batch):
def
torsion_angles_to_frames
(
def
torsion_angles_to_frames
(
t
:
T
,
r
:
Rigid
,
alpha
:
torch
.
Tensor
,
alpha
:
torch
.
Tensor
,
aatype
:
torch
.
Tensor
,
aatype
:
torch
.
Tensor
,
rrgdf
:
torch
.
Tensor
,
rrgdf
:
torch
.
Tensor
,
...
@@ -176,13 +174,15 @@ def torsion_angles_to_frames(
...
@@ -176,13 +174,15 @@ def torsion_angles_to_frames(
# [*, N, 8] transformations, i.e.
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
# One [*, N, 8, 3] translation matrix
default_
t
=
T
.
from_4x4
(
default_4x4
)
default_
r
=
r
.
from_
tensor_
4x4
(
default_4x4
)
bb_rot
=
alpha
.
new_zeros
((
*
((
1
,)
*
len
(
alpha
.
shape
[:
-
1
])),
2
))
bb_rot
=
alpha
.
new_zeros
((
*
((
1
,)
*
len
(
alpha
.
shape
[:
-
1
])),
2
))
bb_rot
[...,
1
]
=
1
bb_rot
[...,
1
]
=
1
# [*, N, 8, 2]
# [*, N, 8, 2]
alpha
=
torch
.
cat
([
bb_rot
.
expand
(
*
alpha
.
shape
[:
-
2
],
-
1
,
-
1
),
alpha
],
dim
=-
2
)
alpha
=
torch
.
cat
(
[
bb_rot
.
expand
(
*
alpha
.
shape
[:
-
2
],
-
1
,
-
1
),
alpha
],
dim
=-
2
)
# [*, N, 8, 3, 3]
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
# Produces rotation matrices of the form:
...
@@ -194,15 +194,15 @@ def torsion_angles_to_frames(
...
@@ -194,15 +194,15 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# This follows the original code rather than the supplement, which uses
# different indices.
# different indices.
all_rots
=
alpha
.
new_zeros
(
default_
t
.
rots
.
shape
)
all_rots
=
alpha
.
new_zeros
(
default_
r
.
get_rots
().
get_rot_mats
()
.
shape
)
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
2
,
1
:]
=
alpha
all_rots
[...,
2
,
1
:]
=
alpha
all_rots
=
T
(
all_rots
,
None
)
all_rots
=
Rigid
(
Rotation
(
rot_mats
=
all_rots
)
,
None
)
all_frames
=
default_
t
.
compose
(
all_rots
)
all_frames
=
default_
r
.
compose
(
all_rots
)
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
...
@@ -213,7 +213,7 @@ def torsion_angles_to_frames(
...
@@ -213,7 +213,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
all_frames_to_bb
=
T
.
con
cat
(
all_frames_to_bb
=
Rigid
.
cat
(
[
[
all_frames
[...,
:
5
],
all_frames
[...,
:
5
],
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
...
@@ -223,13 +223,13 @@ def torsion_angles_to_frames(
...
@@ -223,13 +223,13 @@ def torsion_angles_to_frames(
dim
=-
1
,
dim
=-
1
,
)
)
all_frames_to_global
=
t
[...,
None
].
compose
(
all_frames_to_bb
)
all_frames_to_global
=
r
[...,
None
].
compose
(
all_frames_to_bb
)
return
all_frames_to_global
return
all_frames_to_global
def
frames_and_literature_positions_to_atom14_pos
(
def
frames_and_literature_positions_to_atom14_pos
(
t
:
T
,
r
:
Rigid
,
aatype
:
torch
.
Tensor
,
aatype
:
torch
.
Tensor
,
default_frames
,
default_frames
,
group_idx
,
group_idx
,
...
@@ -249,7 +249,7 @@ def frames_and_literature_positions_to_atom14_pos(
...
@@ -249,7 +249,7 @@ def frames_and_literature_positions_to_atom14_pos(
)
)
# [*, N, 14, 8]
# [*, N, 14, 8]
t_atoms_to_global
=
t
[...,
None
,
:]
*
group_mask
t_atoms_to_global
=
r
[...,
None
,
:]
*
group_mask
# [*, N, 14]
# [*, N, 14]
t_atoms_to_global
=
t_atoms_to_global
.
map_tensor_fn
(
t_atoms_to_global
=
t_atoms_to_global
.
map_tensor_fn
(
...
...
openfold/utils/loss.py
View file @
e3daf724
...
@@ -24,7 +24,7 @@ from typing import Dict, Optional, Tuple
...
@@ -24,7 +24,7 @@ from typing import Dict, Optional, Tuple
from
openfold.np
import
residue_constants
from
openfold.np
import
residue_constants
from
openfold.utils
import
feats
from
openfold.utils
import
feats
from
openfold.utils.
affine
_utils
import
T
from
openfold.utils.
rigid
_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
tree_map
,
tree_map
,
tensor_tree_map
,
tensor_tree_map
,
...
@@ -74,8 +74,8 @@ def torsion_angle_loss(
...
@@ -74,8 +74,8 @@ def torsion_angle_loss(
def
compute_fape
(
def
compute_fape
(
pred_frames
:
T
,
pred_frames
:
Rigid
,
target_frames
:
T
,
target_frames
:
Rigid
,
frames_mask
:
torch
.
Tensor
,
frames_mask
:
torch
.
Tensor
,
pred_positions
:
torch
.
Tensor
,
pred_positions
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
...
@@ -111,7 +111,7 @@ def compute_fape(
...
@@ -111,7 +111,7 @@ def compute_fape(
# )
# )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
#
# ("roughly" because eps is necessarily duplicated in the latter
# ("roughly" because eps is necessarily duplicated in the latter
)
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
normed_error
=
(
normed_error
=
(
normed_error
/
(
eps
+
torch
.
sum
(
frames_mask
,
dim
=-
1
))[...,
None
]
normed_error
/
(
eps
+
torch
.
sum
(
frames_mask
,
dim
=-
1
))[...,
None
]
...
@@ -123,8 +123,8 @@ def compute_fape(
...
@@ -123,8 +123,8 @@ def compute_fape(
def
backbone_loss
(
def
backbone_loss
(
backbone_
affine
_tensor
:
torch
.
Tensor
,
backbone_
rigid
_tensor
:
torch
.
Tensor
,
backbone_
affine
_mask
:
torch
.
Tensor
,
backbone_
rigid
_mask
:
torch
.
Tensor
,
traj
:
torch
.
Tensor
,
traj
:
torch
.
Tensor
,
use_clamped_fape
:
Optional
[
torch
.
Tensor
]
=
None
,
use_clamped_fape
:
Optional
[
torch
.
Tensor
]
=
None
,
clamp_distance
:
float
=
10.0
,
clamp_distance
:
float
=
10.0
,
...
@@ -132,16 +132,27 @@ def backbone_loss(
...
@@ -132,16 +132,27 @@ def backbone_loss(
eps
:
float
=
1e-4
,
eps
:
float
=
1e-4
,
**
kwargs
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
pred_aff
=
T
.
from_tensor
(
traj
)
pred_aff
=
Rigid
.
from_tensor_7
(
traj
)
gt_aff
=
T
.
from_tensor
(
backbone_affine_tensor
)
pred_aff
=
Rigid
(
Rotation
(
rot_mats
=
pred_aff
.
get_rots
().
get_rot_mats
(),
quats
=
None
),
pred_aff
.
get_trans
(),
)
# DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
# backbone tensor, normalizes it, and then turns it back to a rotation
# matrix. To avoid a potentially numerically unstable rotation matrix
# to quaternion conversion, we just use the original rotation matrix
# outright. This one hasn't been composed a bunch of times, though, so
# it might be fine.
gt_aff
=
Rigid
.
from_tensor_4x4
(
backbone_rigid_tensor
)
fape_loss
=
compute_fape
(
fape_loss
=
compute_fape
(
pred_aff
,
pred_aff
,
gt_aff
[
None
],
gt_aff
[
None
],
backbone_
affine
_mask
[
None
],
backbone_
rigid
_mask
[
None
],
pred_aff
.
get_trans
(),
pred_aff
.
get_trans
(),
gt_aff
[
None
].
get_trans
(),
gt_aff
[
None
].
get_trans
(),
backbone_
affine
_mask
[
None
],
backbone_
rigid
_mask
[
None
],
l1_clamp_distance
=
clamp_distance
,
l1_clamp_distance
=
clamp_distance
,
length_scale
=
loss_unit_distance
,
length_scale
=
loss_unit_distance
,
eps
=
eps
,
eps
=
eps
,
...
@@ -150,10 +161,10 @@ def backbone_loss(
...
@@ -150,10 +161,10 @@ def backbone_loss(
unclamped_fape_loss
=
compute_fape
(
unclamped_fape_loss
=
compute_fape
(
pred_aff
,
pred_aff
,
gt_aff
[
None
],
gt_aff
[
None
],
backbone_
affine
_mask
[
None
],
backbone_
rigid
_mask
[
None
],
pred_aff
.
get_trans
(),
pred_aff
.
get_trans
(),
gt_aff
[
None
].
get_trans
(),
gt_aff
[
None
].
get_trans
(),
backbone_
affine
_mask
[
None
],
backbone_
rigid
_mask
[
None
],
l1_clamp_distance
=
None
,
l1_clamp_distance
=
None
,
length_scale
=
loss_unit_distance
,
length_scale
=
loss_unit_distance
,
eps
=
eps
,
eps
=
eps
,
...
@@ -193,9 +204,9 @@ def sidechain_loss(
...
@@ -193,9 +204,9 @@ def sidechain_loss(
sidechain_frames
=
sidechain_frames
[
-
1
]
sidechain_frames
=
sidechain_frames
[
-
1
]
batch_dims
=
sidechain_frames
.
shape
[:
-
4
]
batch_dims
=
sidechain_frames
.
shape
[:
-
4
]
sidechain_frames
=
sidechain_frames
.
view
(
*
batch_dims
,
-
1
,
4
,
4
)
sidechain_frames
=
sidechain_frames
.
view
(
*
batch_dims
,
-
1
,
4
,
4
)
sidechain_frames
=
T
.
from
_4x4
(
sidechain_frames
)
sidechain_frames
=
Rigid
.
from_tensor
_4x4
(
sidechain_frames
)
renamed_gt_frames
=
renamed_gt_frames
.
view
(
*
batch_dims
,
-
1
,
4
,
4
)
renamed_gt_frames
=
renamed_gt_frames
.
view
(
*
batch_dims
,
-
1
,
4
,
4
)
renamed_gt_frames
=
T
.
from
_4x4
(
renamed_gt_frames
)
renamed_gt_frames
=
Rigid
.
from_tensor
_4x4
(
renamed_gt_frames
)
rigidgroups_gt_exists
=
rigidgroups_gt_exists
.
reshape
(
*
batch_dims
,
-
1
)
rigidgroups_gt_exists
=
rigidgroups_gt_exists
.
reshape
(
*
batch_dims
,
-
1
)
sidechain_atom_pos
=
sidechain_atom_pos
[
-
1
]
sidechain_atom_pos
=
sidechain_atom_pos
[
-
1
]
sidechain_atom_pos
=
sidechain_atom_pos
.
view
(
*
batch_dims
,
-
1
,
3
)
sidechain_atom_pos
=
sidechain_atom_pos
.
view
(
*
batch_dims
,
-
1
,
3
)
...
@@ -422,7 +433,7 @@ def distogram_loss(
...
@@ -422,7 +433,7 @@ def distogram_loss(
device
=
logits
.
device
,
device
=
logits
.
device
,
)
)
boundaries
=
boundaries
**
2
boundaries
=
boundaries
**
2
dists
=
torch
.
sum
(
dists
=
torch
.
sum
(
(
pseudo_beta
[...,
None
,
:]
-
pseudo_beta
[...,
None
,
:,
:])
**
2
,
(
pseudo_beta
[...,
None
,
:]
-
pseudo_beta
[...,
None
,
:,
:])
**
2
,
dim
=-
1
,
dim
=-
1
,
...
@@ -550,8 +561,8 @@ def compute_tm(
...
@@ -550,8 +561,8 @@ def compute_tm(
def
tm_loss
(
def
tm_loss
(
logits
,
logits
,
final_affine_tensor
,
final_affine_tensor
,
backbone_
affine
_tensor
,
backbone_
rigid
_tensor
,
backbone_
affine
_mask
,
backbone_
rigid
_mask
,
resolution
,
resolution
,
max_bin
=
31
,
max_bin
=
31
,
no_bins
=
64
,
no_bins
=
64
,
...
@@ -560,16 +571,17 @@ def tm_loss(
...
@@ -560,16 +571,17 @@ def tm_loss(
eps
=
1e-8
,
eps
=
1e-8
,
**
kwargs
,
**
kwargs
,
):
):
pred_affine
=
T
.
from_
4x4
(
final_affine_tensor
)
pred_affine
=
Rigid
.
from_
tensor_7
(
final_affine_tensor
)
backbone_
affine
=
T
.
from
_4x4
(
backbone_
affine
_tensor
)
backbone_
rigid
=
Rigid
.
from_tensor
_4x4
(
backbone_
rigid
_tensor
)
def
_points
(
affine
):
def
_points
(
affine
):
pts
=
affine
.
get_trans
()[...,
None
,
:,
:]
pts
=
affine
.
get_trans
()[...,
None
,
:,
:]
return
affine
.
invert
()[...,
None
].
apply
(
pts
)
return
affine
.
invert
()[...,
None
].
apply
(
pts
)
sq_diff
=
torch
.
sum
(
sq_diff
=
torch
.
sum
(
(
_points
(
pred_affine
)
-
_points
(
backbone_
affine
))
**
2
,
dim
=-
1
(
_points
(
pred_affine
)
-
_points
(
backbone_
rigid
))
**
2
,
dim
=-
1
)
)
sq_diff
=
sq_diff
.
detach
()
sq_diff
=
sq_diff
.
detach
()
boundaries
=
torch
.
linspace
(
boundaries
=
torch
.
linspace
(
...
@@ -583,7 +595,7 @@ def tm_loss(
...
@@ -583,7 +595,7 @@ def tm_loss(
)
)
square_mask
=
(
square_mask
=
(
backbone_
affine
_mask
[...,
None
]
*
backbone_
affine
_mask
[...,
None
,
:]
backbone_
rigid
_mask
[...,
None
]
*
backbone_
rigid
_mask
[...,
None
,
:]
)
)
loss
=
torch
.
sum
(
errors
*
square_mask
,
dim
=-
1
)
loss
=
torch
.
sum
(
errors
*
square_mask
,
dim
=-
1
)
...
@@ -1503,11 +1515,12 @@ class AlphaFoldLoss(nn.Module):
...
@@ -1503,11 +1515,12 @@ class AlphaFoldLoss(nn.Module):
),
),
}
}
cum_loss
=
0
cum_loss
=
0
.
for
loss_name
,
loss_fn
in
loss_fns
.
items
():
for
loss_name
,
loss_fn
in
loss_fns
.
items
():
weight
=
self
.
config
[
loss_name
].
weight
weight
=
self
.
config
[
loss_name
].
weight
if
weight
:
if
weight
:
loss
=
loss_fn
()
loss
=
loss_fn
()
if
(
torch
.
isnan
(
loss
)
or
torch
.
isinf
(
loss
)):
if
(
torch
.
isnan
(
loss
)
or
torch
.
isinf
(
loss
)):
logging
.
warning
(
f
"
{
loss_name
}
loss is NaN. Skipping..."
)
logging
.
warning
(
f
"
{
loss_name
}
loss is NaN. Skipping..."
)
loss
=
loss
.
new_tensor
(
0.
,
requires_grad
=
True
)
loss
=
loss
.
new_tensor
(
0.
,
requires_grad
=
True
)
...
...
openfold/utils/
affine
_utils.py
→
openfold/utils/
rigid
_utils.py
View file @
e3daf724
...
@@ -106,53 +106,791 @@ def rot_vec_mul(
...
@@ -106,53 +106,791 @@ def rot_vec_mul(
dim
=-
1
,
dim
=-
1
,
)
)
def
identity_rot_mats
(
batch_dims
:
Tuple
[
int
],
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
True
,
)
->
torch
.
Tensor
:
rots
=
torch
.
eye
(
3
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
rots
=
rots
.
view
(
*
((
1
,)
*
len
(
batch_dims
)),
3
,
3
)
rots
=
rots
.
expand
(
*
batch_dims
,
-
1
,
-
1
)
return
rots
def
identity_trans
(
batch_dims
:
Tuple
[
int
],
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
True
,
)
->
torch
.
Tensor
:
trans
=
torch
.
zeros
(
(
*
batch_dims
,
3
),
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
return
trans
def
identity_quats
(
batch_dims
:
Tuple
[
int
],
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
True
,
)
->
torch
.
Tensor
:
quat
=
torch
.
zeros
(
(
*
batch_dims
,
4
),
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
with
torch
.
no_grad
():
quat
[...,
0
]
=
1
return
quat
_quat_elements
=
[
"a"
,
"b"
,
"c"
,
"d"
]
_qtr_keys
=
[
l1
+
l2
for
l1
in
_quat_elements
for
l2
in
_quat_elements
]
_qtr_ind_dict
=
{
key
:
ind
for
ind
,
key
in
enumerate
(
_qtr_keys
)}
def
_to_mat
(
pairs
):
mat
=
np
.
zeros
((
4
,
4
))
for
pair
in
pairs
:
key
,
value
=
pair
ind
=
_qtr_ind_dict
[
key
]
mat
[
ind
//
4
][
ind
%
4
]
=
value
return
mat
_QTR_MAT
=
np
.
zeros
((
4
,
4
,
3
,
3
))
_QTR_MAT
[...,
0
,
0
]
=
_to_mat
([(
"aa"
,
1
),
(
"bb"
,
1
),
(
"cc"
,
-
1
),
(
"dd"
,
-
1
)])
_QTR_MAT
[...,
0
,
1
]
=
_to_mat
([(
"bc"
,
2
),
(
"ad"
,
-
2
)])
_QTR_MAT
[...,
0
,
2
]
=
_to_mat
([(
"bd"
,
2
),
(
"ac"
,
2
)])
_QTR_MAT
[...,
1
,
0
]
=
_to_mat
([(
"bc"
,
2
),
(
"ad"
,
2
)])
_QTR_MAT
[...,
1
,
1
]
=
_to_mat
([(
"aa"
,
1
),
(
"bb"
,
-
1
),
(
"cc"
,
1
),
(
"dd"
,
-
1
)])
_QTR_MAT
[...,
1
,
2
]
=
_to_mat
([(
"cd"
,
2
),
(
"ab"
,
-
2
)])
_QTR_MAT
[...,
2
,
0
]
=
_to_mat
([(
"bd"
,
2
),
(
"ac"
,
-
2
)])
_QTR_MAT
[...,
2
,
1
]
=
_to_mat
([(
"cd"
,
2
),
(
"ab"
,
2
)])
_QTR_MAT
[...,
2
,
2
]
=
_to_mat
([(
"aa"
,
1
),
(
"bb"
,
-
1
),
(
"cc"
,
-
1
),
(
"dd"
,
1
)])
def
quat_to_rot
(
quat
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Converts a quaternion to a rotation matrix.
Args:
quat: [*, 4] quaternions
Returns:
[*, 3, 3] rotation matrices
"""
# [*, 4, 4]
quat
=
quat
[...,
None
]
*
quat
[...,
None
,
:]
# [4, 4, 3, 3]
mat
=
quat
.
new_tensor
(
_QTR_MAT
,
requires_grad
=
False
)
# [*, 4, 4, 3, 3]
shaped_qtr_mat
=
mat
.
view
((
1
,)
*
len
(
quat
.
shape
[:
-
2
])
+
mat
.
shape
)
quat
=
quat
[...,
None
,
None
]
*
shaped_qtr_mat
# [*, 3, 3]
return
torch
.
sum
(
quat
,
dim
=
(
-
3
,
-
4
))
def
rot_to_quat
(
rot
:
torch
.
Tensor
,
):
if
(
rot
.
shape
[
-
2
:]
!=
(
3
,
3
)):
raise
ValueError
(
"Input rotation is incorrectly shaped"
)
rot
=
[[
rot
[...,
i
,
j
]
for
j
in
range
(
3
)]
for
i
in
range
(
3
)]
[[
xx
,
xy
,
xz
],
[
yx
,
yy
,
yz
],
[
zx
,
zy
,
zz
]]
=
rot
k
=
[
[
xx
+
yy
+
zz
,
zy
-
yz
,
xz
-
zx
,
yx
-
xy
,],
[
zy
-
yz
,
xx
-
yy
-
zz
,
xy
+
yx
,
xz
+
zx
,],
[
xz
-
zx
,
xy
+
yx
,
yy
-
xx
-
zz
,
yz
+
zy
,],
[
yx
-
xy
,
xz
+
zx
,
yz
+
zy
,
zz
-
xx
-
yy
,]
]
k
=
(
1.
/
3.
)
*
torch
.
stack
([
torch
.
stack
(
t
,
dim
=-
1
)
for
t
in
k
],
dim
=-
2
)
_
,
vectors
=
torch
.
linalg
.
eigh
(
k
)
return
vectors
[...,
-
1
]
_QUAT_MULTIPLY
=
np
.
zeros
((
4
,
4
,
4
))
_QUAT_MULTIPLY
[:,
:,
0
]
=
[[
1
,
0
,
0
,
0
],
[
0
,
-
1
,
0
,
0
],
[
0
,
0
,
-
1
,
0
],
[
0
,
0
,
0
,
-
1
]]
_QUAT_MULTIPLY
[:,
:,
1
]
=
[[
0
,
1
,
0
,
0
],
[
1
,
0
,
0
,
0
],
[
0
,
0
,
0
,
1
],
[
0
,
0
,
-
1
,
0
]]
_QUAT_MULTIPLY
[:,
:,
2
]
=
[[
0
,
0
,
1
,
0
],
[
0
,
0
,
0
,
-
1
],
[
1
,
0
,
0
,
0
],
[
0
,
1
,
0
,
0
]]
_QUAT_MULTIPLY
[:,
:,
3
]
=
[[
0
,
0
,
0
,
1
],
[
0
,
0
,
1
,
0
],
[
0
,
-
1
,
0
,
0
],
[
1
,
0
,
0
,
0
]]
_QUAT_MULTIPLY_BY_VEC
=
_QUAT_MULTIPLY
[:,
1
:,
:]
def
quat_multiply
(
quat1
,
quat2
):
"""Multiply a quaternion by another quaternion."""
mat
=
quat1
.
new_tensor
(
_QUAT_MULTIPLY
)
reshaped_mat
=
mat
.
view
((
1
,)
*
len
(
quat1
.
shape
[:
-
1
])
+
mat
.
shape
)
return
torch
.
sum
(
reshaped_mat
*
quat1
[...,
:,
None
,
None
]
*
quat2
[...,
None
,
:,
None
],
dim
=
(
-
3
,
-
2
)
)
def
quat_multiply_by_vec
(
quat
,
vec
):
"""Multiply a quaternion by a pure-vector quaternion."""
mat
=
quat
.
new_tensor
(
_QUAT_MULTIPLY_BY_VEC
)
reshaped_mat
=
mat
.
view
((
1
,)
*
len
(
quat
.
shape
[:
-
1
])
+
mat
.
shape
)
return
torch
.
sum
(
reshaped_mat
*
quat
[...,
:,
None
,
None
]
*
vec
[...,
None
,
:,
None
],
dim
=
(
-
3
,
-
2
)
)
def
invert_rot_mat
(
rot_mat
:
torch
.
Tensor
):
return
rot_mat
.
transpose
(
-
1
,
-
2
)
class
T
:
def
invert_quat
(
quat
:
torch
.
Tensor
):
quat_prime
=
quat
.
clone
()
quat_prime
[...,
1
:]
*=
-
1
inv
=
quat_prime
/
torch
.
sum
(
quat
**
2
,
dim
=-
1
,
keepdim
=
True
)
return
inv
class
Rotation
:
"""
"""
A class representing an affine transformation. Essentially a wrapper
A 3D rotation. Depending on how the object is initialized, the
around two torch tensors: a [*, 3, 3] rotation and a [*, 3]
rotation is represented by either a rotation matrix or a
translation. Designed to behave approximately like a single torch
quaternion, though both formats are made available by helper functions.
tensor with the shape of the shared dimensions of its component parts.
To simplify gradient computation, the underlying format of the
rotation cannot be changed in-place. Like Rigid, the class is designed
to mimic the behavior of a torch Tensor, almost as if each Rotation
object were a tensor of rotations, in one format or another.
"""
def
__init__
(
self
,
rot_mats
:
Optional
[
torch
.
Tensor
]
=
None
,
quats
:
Optional
[
torch
.
Tensor
]
=
None
,
normalize_quats
:
bool
=
True
,
):
"""
Args:
rot_mats:
A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
quats
quats:
A [*, 4] quaternion. Mutually exclusive with rot_mats. If
normalize_quats is not True, must be a unit quaternion
normalize_quats:
If quats is specified, whether to normalize quats
"""
if
((
rot_mats
is
None
and
quats
is
None
)
or
(
rot_mats
is
not
None
and
quats
is
not
None
)):
raise
ValueError
(
"Exactly one input argument must be specified"
)
if
((
rot_mats
is
not
None
and
rot_mats
.
shape
[
-
2
:]
!=
(
3
,
3
))
or
(
quats
is
not
None
and
quats
.
shape
[
-
1
]
!=
4
)):
raise
ValueError
(
"Incorrectly shaped rotation matrix or quaternion"
)
if
(
quats
is
not
None
and
normalize_quats
):
quats
=
quats
/
torch
.
linalg
.
norm
(
quats
,
dim
=-
1
,
keepdim
=
True
)
self
.
_rot_mats
=
rot_mats
self
.
_quats
=
quats
@
staticmethod
def
identity
(
shape
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
True
,
fmt
:
str
=
"quat"
,
)
->
Rotation
:
"""
Returns an identity Rotation.
Args:
shape:
The "shape" of the resulting Rotation object. See documentation
for the shape property
dtype:
The torch dtype for the rotation
device:
The torch device for the new rotation
requires_grad:
Whether the underlying tensors in the new rotation object
should require gradient computation
fmt:
One of "quat" or "rot_mat". Determines the underlying format
of the new object's rotation
Returns:
A new identity rotation
"""
if
(
fmt
==
"rot_mat"
):
rot_mats
=
identity_rot_mats
(
shape
,
dtype
,
device
,
requires_grad
,
)
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
elif
(
fmt
==
"quat"
):
quats
=
identity_quats
(
shape
,
dtype
,
device
,
requires_grad
)
return
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
False
)
else
:
raise
ValueError
(
f
"Invalid format: f
{
fmt
}
"
)
# Magic methods
def
__getitem__
(
self
,
index
:
Any
)
->
Rotation
:
"""
Allows torch-style indexing over the virtual shape of the rotation
object. See documentation for the shape property.
Args:
index:
A torch index. E.g. (1, 3, 2), or (slice(None,))
Returns:
The indexed rotation
"""
if
type
(
index
)
!=
tuple
:
index
=
(
index
,)
if
(
self
.
_rot_mats
is
not
None
):
rot_mats
=
self
.
_rot_mats
[
index
+
(
slice
(
None
),
slice
(
None
))]
return
Rotation
(
rot_mats
=
rot_mats
)
elif
(
self
.
_quats
is
not
None
):
quats
=
self
.
_quats
[
index
+
(
slice
(
None
),)]
return
Rotation
(
quats
=
quats
,
normalize_quats
=
False
)
else
:
raise
ValueError
(
"Both rotations are None"
)
def
__mul__
(
self
,
right
:
torch
.
Tensor
,
)
->
Rotation
:
"""
Pointwise left multiplication of the rotation with a tensor. Can be
used to e.g. mask the Rotation.
Args:
right:
The tensor multiplicand
Returns:
The product
"""
if
not
(
isinstance
(
right
,
torch
.
Tensor
)):
raise
TypeError
(
"The other multiplicand must be a Tensor"
)
if
(
self
.
_rot_mats
is
not
None
):
rot_mats
=
self
.
_rot_mats
*
right
[...,
None
,
None
]
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
quats
=
self
.
_quats
*
right
[...,
None
]
return
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
False
)
else
:
raise
ValueError
(
"Both rotations are None"
)
def
__rmul__
(
self
,
left
:
torch
.
Tensor
,
)
->
Rotation
:
"""
Reverse pointwise multiplication of the rotation with a tensor.
Args:
left:
The left multiplicand
Returns:
The product
"""
return
self
.
__mul__
(
left
)
# Properties
@
property
def
shape
(
self
)
->
torch
.
Size
:
"""
Returns the virtual shape of the rotation object. This shape is
defined as the batch dimensions of the underlying rotation matrix
or quaternion. If the Rotation was initialized with a [10, 3, 3]
rotation matrix tensor, for example, the resulting shape would be
[10].
Returns:
The virtual shape of the rotation object
"""
s
=
None
if
(
self
.
_quats
is
not
None
):
s
=
self
.
_quats
.
shape
[:
-
1
]
else
:
s
=
self
.
_rot_mats
.
shape
[:
-
2
]
return
s
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
"""
Returns the dtype of the underlying rotation.
Returns:
The dtype of the underlying rotation
"""
if
(
self
.
_rot_mats
is
not
None
):
return
self
.
_rot_mats
.
dtype
elif
(
self
.
_quats
is
not
None
):
return
self
.
_quats
.
dtype
else
:
raise
ValueError
(
"Both rotations are None"
)
@
property
def
device
(
self
)
->
torch
.
device
:
"""
The device of the underlying rotation
Returns:
The device of the underlying rotation
"""
if
(
self
.
_rot_mats
is
not
None
):
return
self
.
_rot_mats
.
device
elif
(
self
.
_quats
is
not
None
):
return
self
.
_quats
.
device
else
:
raise
ValueError
(
"Both rotations are None"
)
@
property
def
requires_grad
(
self
)
->
bool
:
"""
Returns the requires_grad property of the underlying rotation
Returns:
The requires_grad property of the underlying tensor
"""
if
(
self
.
_rot_mats
is
not
None
):
return
self
.
_rot_mats
.
requires_grad
elif
(
self
.
_quats
is
not
None
):
return
self
.
_quats
.
requires_grad
else
:
raise
ValueError
(
"Both rotations are None"
)
def
get_rot_mats
(
self
)
->
torch
.
Tensor
:
"""
Returns the underlying rotation as a rotation matrix tensor.
Returns:
The rotation as a rotation matrix tensor
"""
rot_mats
=
self
.
_rot_mats
if
(
rot_mats
is
None
):
if
(
self
.
_quats
is
None
):
raise
ValueError
(
"Both rotations are None"
)
else
:
rot_mats
=
quat_to_rot
(
self
.
_quats
)
return
rot_mats
def
get_quats
(
self
)
->
torch
.
Tensor
:
"""
Returns the underlying rotation as a quaternion tensor.
Depending on whether the Rotation was initialized with a
quaternion, this function may call torch.linalg.eigh.
Returns:
The rotation as a quaternion tensor.
"""
quats
=
self
.
_quats
if
(
quats
is
None
):
if
(
self
.
_rot_mats
is
None
):
raise
ValueError
(
"Both rotations are None"
)
else
:
quats
=
rot_to_quat
(
self
.
_rot_mats
)
return
quats
def
get_cur_rot
(
self
)
->
torch
.
Tensor
:
"""
Return the underlying rotation in its current form
Returns:
The stored rotation
"""
if
(
self
.
_rot_mats
is
not
None
):
return
self
.
_rot_mats
elif
(
self
.
_quats
is
not
None
):
return
self
.
_quats
else
:
raise
ValueError
(
"Both rotations are None"
)
# Rotation functions
def
compose_q_update_vec
(
self
,
q_update_vec
:
torch
.
Tensor
,
normalize_quats
:
bool
=
True
)
->
Rotation
:
"""
Returns a new quaternion Rotation after updating the current
object's underlying rotation with a quaternion update, formatted
as a [*, 3] tensor whose final three columns represent x, y, z such
that (1, x, y, z) is the desired (not necessarily unit) quaternion
update.
Args:
q_update_vec:
A [*, 3] quaternion update tensor
normalize_quats:
Whether to normalize the output quaternion
Returns:
An updated Rotation
"""
quats
=
self
.
get_quats
()
new_quats
=
quats
+
quat_multiply_by_vec
(
quats
,
q_update_vec
)
return
Rotation
(
rot_mats
=
None
,
quats
=
new_quats
,
normalize_quats
=
normalize_quats
,
)
def
compose_r
(
self
,
r
:
Rotation
)
->
Rotation
:
"""
Compose the rotation matrices of the current Rotation object with
those of another.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
r1
=
self
.
get_rot_mats
()
r2
=
r
.
get_rot_mats
()
new_rot_mats
=
rot_matmul
(
r1
,
r2
)
return
Rotation
(
rot_mats
=
new_rot_mats
,
quats
=
None
)
def
compose_q
(
self
,
r
:
Rotation
,
normalize_quats
:
bool
=
True
)
->
Rotation
:
"""
Compose the quaternions of the current Rotation object with those
of another.
Depending on whether either Rotation was initialized with
quaternions, this function may call torch.linalg.eigh.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
q1
=
self
.
get_quats
()
q2
=
r
.
get_quats
()
new_quats
=
quat_multiply
(
q1
,
q2
)
return
Rotation
(
rot_mats
=
None
,
quats
=
new_quats
,
normalize_quats
=
normalize_quats
)
def
apply
(
self
,
pts
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Apply the current Rotation as a rotation matrix to a set of 3D
coordinates.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] rotated points
"""
rot_mats
=
self
.
get_rot_mats
()
return
rot_vec_mul
(
rot_mats
,
pts
)
def
invert_apply
(
self
,
pts
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
The inverse of the apply() method.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] inverse-rotated points
"""
rot_mats
=
self
.
get_rot_mats
()
inv_rot_mats
=
invert_rot_mat
(
rot_mats
)
return
rot_vec_mul
(
inv_rot_mats
,
pts
)
def
invert
(
self
)
->
Rotation
:
"""
Returns the inverse of the current Rotation.
Returns:
The inverse of the current Rotation
"""
if
(
self
.
_rot_mats
is
not
None
):
return
Rotation
(
rot_mats
=
invert_rot_mat
(
self
.
_rot_mats
),
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
return
Rotation
(
rot_mats
=
None
,
quats
=
invert_quat
(
self
.
_quats
),
normalize_quats
=
False
,
)
else
:
raise
ValueError
(
"Both rotations are None"
)
# "Tensor" stuff
def
unsqueeze
(
self
,
dim
:
int
,
)
->
Rigid
:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shape of the Rotation object.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed Rotation.
"""
if
dim
>=
len
(
self
.
shape
):
raise
ValueError
(
"Invalid dimension"
)
if
(
self
.
_rot_mats
is
not
None
):
rot_mats
=
self
.
_rot_mats
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
2
)
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
quats
=
self
.
_quats
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
1
)
return
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
False
)
else
:
raise
ValueError
(
"Both rotations are None"
)
@
staticmethod
def
cat
(
rs
:
Sequence
[
Rotation
],
dim
:
int
,
)
->
Rigid
:
"""
Concatenates rotations along one of the batch dimensions. Analogous
to torch.cat().
Note that the output of this operation is always a rotation matrix,
regardless of the format of input rotations.
Args:
rs:
A list of rotation objects
dim:
The dimension along which the rotations should be
concatenated
Returns:
A concatenated Rotation object in rotation matrix format
"""
rot_mats
=
[
r
.
get_rot_mats
()
for
r
in
rs
]
rot_mats
=
torch
.
cat
(
rot_mats
,
dim
=
dim
if
dim
>=
0
else
dim
-
2
)
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
def
map_tensor_fn
(
self
,
fn
:
Callable
[
tensor
.
Tensor
,
tensor
.
Tensor
]
)
->
Rotation
:
"""
Apply a Tensor -> Tensor function to underlying rotation tensors,
mapping over the rotation dimension(s). Can be used e.g. to sum out
a one-hot batch dimension.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rotation
Returns:
The transformed Rotation object
"""
if
(
self
.
_rot_mats
is
not
None
):
rot_mats
=
self
.
_rot_mats
.
view
(
self
.
_rot_mats
.
shape
[:
-
2
]
+
(
9
,))
rot_mats
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
rot_mats
,
dim
=-
1
))),
dim
=-
1
)
rot_mats
=
rot_mats
.
view
(
rot_mats
.
shape
[:
-
1
]
+
(
3
,
3
))
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
quats
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
self
.
_quats
,
dim
=-
1
))),
dim
=-
1
)
return
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
False
)
else
:
raise
ValueError
(
"Both rotations are None"
)
def
cuda
(
self
)
->
Rotation
:
"""
Analogous to the cuda() method of torch Tensors
Returns:
A copy of the Rotation in CUDA memory
"""
if
(
self
.
_rot_mats
is
not
None
):
return
Rotation
(
rot_mats
=
self
.
_rot_mats
.
cuda
(),
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
return
Rotation
(
rot_mats
=
None
,
quats
=
self
.
_quats
.
cuda
(),
normalize_quats
=
False
)
else
:
raise
ValueError
(
"Both rotations are None"
)
def
to
(
self
,
device
:
Optional
[
torch
.
device
],
dtype
:
Optional
[
torch
.
dtype
]
)
->
Rotation
:
"""
Analogous to the to() method of torch Tensors
Args:
device:
A torch device
dtype:
A torch dtype
Returns:
A copy of the Rotation using the new device and dtype
"""
if
(
self
.
_rot_mats
is
not
None
):
return
Rotation
(
rot_mats
=
self
.
_rot_mats
.
to
(
device
=
device
,
dtype
=
dtype
),
quats
=
None
,
)
elif
(
self
.
_quats
is
not
None
):
return
Rotation
(
rot_mats
=
None
,
quats
=
self
.
_quats
.
to
(
device
=
device
,
dtype
=
dtype
),
normalize_quats
=
False
,
)
else
:
raise
ValueError
(
"Both rotations are None"
)
def
detach
(
self
)
->
Rotation
:
"""
Returns a copy of the Rotation whose underlying Tensor has been
detached from its torch graph.
Returns:
A copy of the Rotation whose underlying Tensor has been detached
from its torch graph
"""
if
(
self
.
_rot_mats
is
not
None
):
return
Rotation
(
rot_mats
=
self
.
_rot_mats
.
detach
(),
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
return
Rotation
(
rot_mats
=
None
,
quats
=
self
.
_quats
.
detach
(),
normalize_quats
=
False
,
)
else
:
raise
ValueError
(
"Both rotations are None"
)
class
Rigid
:
"""
A class representing a rigid transformation. Little more than a wrapper
around two objects: a Rotation object and a [*, 3] translation
Designed to behave approximately like a single torch tensor with the
shape of the shared batch dimensions of its component parts.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
rots
:
torch
.
Tensor
,
rots
:
Optional
[
Rotation
],
trans
:
torch
.
Tensor
trans
:
Optional
[
torch
.
Tensor
],
):
):
"""
"""
Args:
Args:
rots: A [*, 3, 3] rotation tensor
rots: A [*, 3, 3] rotation tensor
trans: A corresponding [*, 3] translation tensor
trans: A corresponding [*, 3] translation tensor
"""
"""
self
.
rots
=
rots
# (we need device, dtype, etc. from at least one input)
self
.
trans
=
trans
batch_dims
,
dtype
,
device
,
requires_grad
=
None
,
None
,
None
,
None
if
self
.
rots
is
None
and
self
.
trans
is
None
:
if
(
trans
is
not
None
):
raise
ValueError
(
"Only one of rots and trans can be None"
)
batch_dims
=
trans
.
shape
[:
-
1
]
elif
self
.
rots
is
None
:
dtype
=
trans
.
dtype
self
.
rots
=
T
.
_identity_rot
(
device
=
trans
.
device
self
.
trans
.
shape
[:
-
1
],
requires_grad
=
trans
.
requires_grad
self
.
trans
.
dtype
,
elif
(
rots
is
not
None
):
self
.
trans
.
device
,
batch_dims
=
rots
.
shape
self
.
trans
.
requires_grad
,
dtype
=
rots
.
dtype
device
=
rots
.
device
requires_grad
=
rots
.
requires_grad
else
:
raise
ValueError
(
"At least one input argument must be specified"
)
if
(
rots
is
None
):
rots
=
Rotation
.
identity
(
batch_dims
,
dtype
,
device
,
requires_grad
,
)
)
elif
self
.
trans
is
None
:
elif
(
trans
is
None
):
self
.
trans
=
T
.
_identity_trans
(
trans
=
identity_trans
(
self
.
rots
.
shape
[:
-
2
],
batch_dims
,
dtype
,
device
,
requires_grad
,
self
.
rots
.
dtype
,
self
.
rots
.
device
,
self
.
rots
.
requires_grad
,
)
)
if
(
if
((
rots
.
shape
!=
trans
.
shape
[:
-
1
])
or
self
.
rots
.
shape
[
-
2
:]
!=
(
3
,
3
)
(
rots
.
device
!=
trans
.
device
)):
or
self
.
trans
.
shape
[
-
1
]
!=
3
raise
ValueError
(
"Rots and trans incompatible"
)
or
self
.
rots
.
shape
[:
-
2
]
!=
self
.
trans
.
shape
[:
-
1
]
):
self
.
_rots
=
rots
raise
ValueError
(
"Incorrectly shaped input"
)
self
.
_trans
=
trans
@
staticmethod
def
identity
(
shape
:
Tuple
[
int
],
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
True
,
fmt
:
str
=
"quat"
,
)
->
Rigid
:
"""
Constructs an identity transformation.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns:
The identity transformation
"""
return
Rigid
(
Rotation
.
identity
(
shape
,
dtype
,
device
,
requires_grad
,
fmt
=
fmt
),
identity_trans
(
shape
,
dtype
,
device
,
requires_grad
),
)
def
__getitem__
(
self
,
def
__getitem__
(
self
,
index
:
Any
,
index
:
Any
,
)
->
T
:
)
->
Rigid
:
"""
"""
Indexes the affine transformation with PyTorch-style indices.
Indexes the affine transformation with PyTorch-style indices.
The index is applied to the shared dimensions of both the rotation
The index is applied to the shared dimensions of both the rotation
...
@@ -160,11 +898,12 @@ class T:
...
@@ -160,11 +898,12 @@ class T:
E.g.::
E.g.::
t = T(torch.rand(10, 10, 3, 3), torch.rand(10, 10, 3))
r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
t = Rigid(r, torch.rand(10, 10, 3))
indexed = t[3, 4:6]
indexed = t[3, 4:6]
assert(indexed.shape == (2,))
assert(indexed.shape == (2,))
assert(indexed.rots.shape == (2,
3, 3
))
assert(indexed.
get_
rots
()
.shape == (2,))
assert(indexed.trans.shape == (2, 3))
assert(indexed.
get_
trans
()
.shape == (2, 3))
Args:
Args:
index: A standard torch tensor index. E.g. 8, (10, None, 3),
index: A standard torch tensor index. E.g. 8, (10, None, 3),
...
@@ -174,54 +913,45 @@ class T:
...
@@ -174,54 +913,45 @@ class T:
"""
"""
if
type
(
index
)
!=
tuple
:
if
type
(
index
)
!=
tuple
:
index
=
(
index
,)
index
=
(
index
,)
return
T
(
self
.
rots
[
index
+
(
slice
(
None
),
slice
(
None
))],
return
Rigid
(
self
.
trans
[
index
+
(
slice
(
None
),)],
self
.
_rots
[
index
],
)
self
.
_trans
[
index
+
(
slice
(
None
),)],
def
__eq__
(
self
,
obj
:
T
,
)
->
bool
:
"""
Compares two affine transformations. Returns true iff the
transformations are pointwise identical. Does not account for
floating point imprecision.
"""
return
bool
(
torch
.
all
(
self
.
rots
==
obj
.
rots
)
and
torch
.
all
(
self
.
trans
==
obj
.
trans
)
)
)
def
__mul__
(
self
,
def
__mul__
(
self
,
right
:
torch
.
Tensor
,
right
:
torch
.
Tensor
,
)
->
T
:
)
->
Rigid
:
"""
"""
Pointwise right multiplication of the affine transformation with a
Pointwise left multiplication of the transformation with a tensor.
tensor. Multiplication is broadcast over the rotation/translation
Can be used to e.g. mask the Rigid.
dimensions.
Args:
Args:
right: The right multiplicand
right:
The tensor multiplicand
Returns:
Returns:
The product
transformation
The product
"""
"""
rots
=
self
.
rots
*
right
[...,
None
,
None
]
if
not
(
isinstance
(
right
,
torch
.
Tensor
)):
trans
=
self
.
trans
*
right
[...,
None
]
raise
TypeError
(
"The other multiplicand must be a Tensor"
)
return
T
(
rots
,
trans
)
new_rots
=
self
.
_rots
*
right
new_trans
=
self
.
_trans
*
right
[...,
None
]
def
__rmul__
(
self
,
return
Rigid
(
new_rots
,
new_trans
)
def
__rmul__
(
self
,
left
:
torch
.
Tensor
,
left
:
torch
.
Tensor
,
)
->
T
:
)
->
Rigid
:
"""
"""
Pointwise left multiplication of the affine transformation with a
Reverse pointwise multiplication of the transformation with a
tensor. Multiplication is broadcast over the rotation/translation
tensor.
dimensions.
Args:
Args:
left: The left multiplicand
left:
The left multiplicand
Returns:
Returns:
The product
transformation
The product
"""
"""
return
self
.
__mul__
(
left
)
return
self
.
__mul__
(
left
)
...
@@ -234,45 +964,74 @@ class T:
...
@@ -234,45 +964,74 @@ class T:
Returns:
Returns:
The shape of the transformation
The shape of the transformation
"""
"""
s
=
self
.
rots
.
shape
[:
-
2
]
s
=
self
.
_trans
.
shape
[:
-
1
]
return
s
if
len
(
s
)
>
0
else
torch
.
Size
([
1
])
return
s
@
property
def
device
(
self
)
->
torch
.
device
:
"""
Returns the device on which the Rigid's tensors are located.
def
get_rots
(
self
):
Returns:
The device on which the Rigid's tensors are located
"""
return
self
.
_trans
.
device
def
get_rots
(
self
)
->
Rotation
:
"""
"""
Getter for the rotation.
Getter for the rotation.
Returns:
Returns:
The
stored
rotation
.
The rotation
object
"""
"""
return
self
.
rots
return
self
.
_
rots
def
get_trans
(
self
)
->
torch
.
Tensor
:
def
get_trans
(
self
)
->
torch
.
Tensor
:
"""
"""
Getter for the translation.
Getter for the translation.
Returns:
Returns:
The stored translation
.
The stored translation
"""
"""
return
self
.
trans
return
self
.
_
trans
def
compose
(
self
,
def
compose
_q_update_vec
(
self
,
t
:
T
,
q_update_vec
:
torch
.
Tensor
,
)
->
T
:
)
->
Rigid
:
"""
"""
Composes the transformation with another.
Composes the transformation with a quaternion update vector of
shape [*, 6], where the final 6 columns represent the x, y, and
z values of a quaternion of form (1, x, y, z) followed by a 3D
translation.
Args:
Args:
t
: The
inner transformation
.
q_vec
: The
quaternion update vector
.
Returns:
Returns:
The composed transformation.
The composed transformation.
"""
"""
rot_1
,
trn_1
=
self
.
rots
,
self
.
trans
q_vec
,
t_vec
=
q_update_vec
[...,
:
3
],
q_update_vec
[...,
3
:]
rot_2
,
trn_2
=
t
.
rots
,
t
.
trans
new_rots
=
self
.
_rots
.
compose_q_update_vec
(
q_vec
)
trans_update
=
self
.
_rots
.
apply
(
t_vec
)
new_translation
=
self
.
_trans
+
trans_update
rot
=
rot_matmul
(
rot_1
,
rot_2
)
return
Rigid
(
new_rots
,
new_translation
)
trn
=
rot_vec_mul
(
rot_1
,
trn_2
)
+
trn_1
return
T
(
rot
,
trn
)
def
compose
(
self
,
r
:
Rigid
,
)
->
Rigid
:
"""
Composes the current rigid object with another.
Args:
r:
Another Rigid object
Returns:
The composition of the two transformations
"""
new_rot
=
self
.
_rots
.
compose_r
(
r
.
_rots
)
new_trans
=
self
.
_rots
.
apply
(
r
.
_trans
)
+
self
.
_trans
return
Rigid
(
new_rot
,
new_trans
)
def
apply
(
self
,
def
apply
(
self
,
pts
:
torch
.
Tensor
,
pts
:
torch
.
Tensor
,
...
@@ -285,9 +1044,8 @@ class T:
...
@@ -285,9 +1044,8 @@ class T:
Returns:
Returns:
The transformed points.
The transformed points.
"""
"""
r
,
t
=
self
.
rots
,
self
.
trans
rotated
=
self
.
_rots
.
apply
(
pts
)
rotated
=
rot_vec_mul
(
r
,
pts
)
return
rotated
+
self
.
_trans
return
rotated
+
t
def
invert_apply
(
self
,
def
invert_apply
(
self
,
pts
:
torch
.
Tensor
pts
:
torch
.
Tensor
...
@@ -300,99 +1058,60 @@ class T:
...
@@ -300,99 +1058,60 @@ class T:
Returns:
Returns:
The transformed points.
The transformed points.
"""
"""
r
,
t
=
self
.
rots
,
self
.
trans
pts
=
pts
-
self
.
_trans
pts
=
pts
-
t
return
self
.
_rots
.
invert_apply
(
pts
)
return
rot_vec_mul
(
r
.
transpose
(
-
1
,
-
2
),
pts
)
def
invert
(
self
)
->
T
:
def
invert
(
self
)
->
Rigid
:
"""
"""
Inverts the transformation.
Inverts the transformation.
Returns:
Returns:
The inverse transformation.
The inverse transformation.
"""
"""
rot_inv
=
self
.
rots
.
transpose
(
-
1
,
-
2
)
rot_inv
=
self
.
_
rots
.
invert
()
trn_inv
=
rot_
vec_mul
(
rot_inv
,
self
.
trans
)
trn_inv
=
rot_
inv
.
apply
(
self
.
_
trans
)
return
T
(
rot_inv
,
-
1
*
trn_inv
)
return
Rigid
(
rot_inv
,
-
1
*
trn_inv
)
def
unsqueeze
(
self
,
def
map_tensor_fn
(
self
,
dim
:
int
,
fn
:
Callable
[
tensor
.
Tensor
,
tensor
.
Tensor
]
)
->
T
:
)
->
Rigid
:
"""
"""
Analogous to torch.unsqueeze. The dimension is relative to the
Apply a Tensor -> Tensor function to underlying translation and
shared dimensions of the rotation/translation.
rotation tensors, mapping over the translation/rotation dimensions
respectively.
Args:
Args:
dim: A positive or negative dimension index.
fn:
A Tensor -> Tensor function to be mapped over the Rigid
Returns:
Returns:
The unsqueezed transformation.
The transformed Rigid object
"""
"""
if
dim
>=
len
(
self
.
shape
):
new_rots
=
self
.
_rots
.
map_tensor_fn
(
fn
)
raise
ValueError
(
"Invalid dimension"
)
new_trans
=
torch
.
stack
(
rots
=
self
.
rots
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
2
)
list
(
map
(
fn
,
torch
.
unbind
(
self
.
_trans
,
dim
=-
1
))),
trans
=
self
.
trans
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
1
)
dim
=-
1
return
T
(
rots
,
trans
)
@
staticmethod
def
_identity_rot
(
shape
:
Tuple
[
int
],
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
requires_grad
:
bool
,
)
->
torch
.
Tensor
:
rots
=
torch
.
eye
(
3
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
)
rots
=
rots
.
view
(
*
((
1
,)
*
len
(
shape
)),
3
,
3
)
rots
=
rots
.
expand
(
*
shape
,
-
1
,
-
1
)
return
rots
@
staticmethod
return
Rigid
(
new_rots
,
new_trans
)
def
_identity_trans
(
shape
:
Tuple
[
int
],
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
requires_grad
:
bool
)
->
torch
.
Tensor
:
trans
=
torch
.
zeros
(
(
*
shape
,
3
),
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
return
trans
@
staticmethod
def
to_tensor_4x4
(
self
)
->
torch
.
Tensor
:
def
identity
(
shape
:
Tuple
[
int
],
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
requires_grad
:
bool
=
True
)
->
T
:
"""
"""
Con
structs an identity
transformation.
Con
verts a transformation to a homogenous
transformation
tensor
.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns:
Returns:
The identity
transformation
A [*, 4, 4] homogenous
transformation
tensor
"""
"""
return
T
(
tensor
=
self
.
_trans
.
new_zeros
((
*
self
.
shape
,
4
,
4
))
T
.
_identity_rot
(
shape
,
dtype
,
device
,
requires_grad
),
tensor
[...,
:
3
,
:
3
]
=
self
.
_rots
.
get_rot_mats
()
T
.
_identity_trans
(
shape
,
dtype
,
device
,
requires_grad
),
tensor
[...,
:
3
,
3
]
=
self
.
_trans
)
tensor
[...,
3
,
3
]
=
1
return
tensor
@
staticmethod
@
staticmethod
def
from_4x4
(
def
from_
tensor_
4x4
(
t
:
torch
.
Tensor
t
:
torch
.
Tensor
)
->
T
:
)
->
Rigid
:
"""
"""
Constructs a transformation from a homogenous transformation
Constructs a transformation from a homogenous transformation
tensor.
tensor.
...
@@ -402,35 +1121,45 @@ class T:
...
@@ -402,35 +1121,45 @@ class T:
Returns:
Returns:
T object with shape [*]
T object with shape [*]
"""
"""
rots
=
t
[...,
:
3
,
:
3
]
if
(
t
.
shape
[
-
2
:]
!=
(
4
,
4
)):
raise
ValueError
(
"Incorrectly shaped input tensor"
)
rots
=
Rotation
(
rot_mats
=
t
[...,
:
3
,
:
3
],
quats
=
None
)
trans
=
t
[...,
:
3
,
3
]
trans
=
t
[...,
:
3
,
3
]
return
T
(
rots
,
trans
)
return
Rigid
(
rots
,
trans
)
def
to_
4x4
(
self
)
->
torch
.
Tensor
:
def
to_
tensor_7
(
self
)
->
torch
.
Tensor
:
"""
"""
Converts a transformation to a homogenous transformation tensor.
Converts a transformation to a tensor with 7 final columns, four
for the quaternion followed by three for the translation.
Returns:
Returns:
A [*,
4, 4] homogenous
transformation
tensor
A [*,
7] tensor representation of the
transformation
"""
"""
tensor
=
self
.
rot
s
.
new_zeros
((
*
self
.
shape
,
4
,
4
))
tensor
=
self
.
_tran
s
.
new_zeros
((
*
self
.
shape
,
7
))
tensor
[...,
:
3
,
:
3
]
=
self
.
rots
tensor
[...,
:
4
]
=
self
.
_
rots
.
get_quats
()
tensor
[...,
:
3
,
3
]
=
self
.
trans
tensor
[...,
4
:]
=
self
.
_
trans
tensor
[...,
3
,
3
]
=
1
return
tensor
return
tensor
@
staticmethod
@
staticmethod
def
from_tensor
(
t
:
torch
.
Tensor
)
->
T
:
def
from_tensor_7
(
"""
t
:
torch
.
Tensor
,
Constructs a transformation from a homogenous transformation
normalize_quats
:
bool
=
False
,
tensor.
)
->
Rigid
:
if
(
t
.
shape
[
-
1
]
!=
7
):
raise
ValueError
(
"Incorrectly shaped input tensor"
)
quats
,
trans
=
t
[...,
:
4
],
t
[...,
4
:]
rots
=
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
normalize_quats
)
Args:
return
Rigid
(
rots
,
trans
)
t: A [*, 4, 4] homogenous transformation tensor
Returns:
A transformation object with shape [*]
"""
return
T
.
from_4x4
(
t
)
@
staticmethod
@
staticmethod
def
from_3_points
(
def
from_3_points
(
...
@@ -438,7 +1167,7 @@ class T:
...
@@ -438,7 +1167,7 @@ class T:
origin
:
torch
.
Tensor
,
origin
:
torch
.
Tensor
,
p_xy_plane
:
torch
.
Tensor
,
p_xy_plane
:
torch
.
Tensor
,
eps
:
float
=
1e-8
eps
:
float
=
1e-8
)
->
T
:
)
->
Rigid
:
"""
"""
Implements algorithm 21. Constructs transformations from sets of 3
Implements algorithm 21. Constructs transformations from sets of 3
points using the Gram-Schmidt algorithm.
points using the Gram-Schmidt algorithm.
...
@@ -473,13 +1202,34 @@ class T:
...
@@ -473,13 +1202,34 @@ class T:
rots
=
torch
.
stack
([
c
for
tup
in
zip
(
e0
,
e1
,
e2
)
for
c
in
tup
],
dim
=-
1
)
rots
=
torch
.
stack
([
c
for
tup
in
zip
(
e0
,
e1
,
e2
)
for
c
in
tup
],
dim
=-
1
)
rots
=
rots
.
reshape
(
rots
.
shape
[:
-
1
]
+
(
3
,
3
))
rots
=
rots
.
reshape
(
rots
.
shape
[:
-
1
]
+
(
3
,
3
))
return
T
(
rots
,
torch
.
stack
(
origin
,
dim
=-
1
))
rot_obj
=
Rotation
(
rot_mats
=
rots
,
quats
=
None
)
return
Rigid
(
rot_obj
,
torch
.
stack
(
origin
,
dim
=-
1
))
def
unsqueeze
(
self
,
dim
:
int
,
)
->
Rigid
:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shared dimensions of the rotation/translation.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed transformation.
"""
if
dim
>=
len
(
self
.
shape
):
raise
ValueError
(
"Invalid dimension"
)
rots
=
self
.
_rots
.
unsqueeze
(
dim
)
trans
=
self
.
_trans
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
1
)
return
Rigid
(
rots
,
trans
)
@
staticmethod
@
staticmethod
def
con
cat
(
def
cat
(
ts
:
Sequence
[
T
],
ts
:
Sequence
[
Rigid
],
dim
:
int
,
dim
:
int
,
)
->
T
:
)
->
Rigid
:
"""
"""
Concatenates transformations along a new dimension.
Concatenates transformations along a new dimension.
...
@@ -492,57 +1242,60 @@ class T:
...
@@ -492,57 +1242,60 @@ class T:
Returns:
Returns:
A concatenated transformation object
A concatenated transformation object
"""
"""
rots
=
torch
.
cat
([
t
.
rots
for
t
in
ts
],
dim
=
dim
if
dim
>=
0
else
dim
-
2
)
rots
=
Rotation
.
cat
([
t
.
_
rots
for
t
in
ts
],
dim
)
trans
=
torch
.
cat
(
trans
=
torch
.
cat
(
[
t
.
trans
for
t
in
ts
],
dim
=
dim
if
dim
>=
0
else
dim
-
1
[
t
.
_
trans
for
t
in
ts
],
dim
=
dim
if
dim
>=
0
else
dim
-
1
)
)
return
T
(
rots
,
trans
)
return
Rigid
(
rots
,
trans
)
def
m
ap
_tensor
_fn
(
self
,
fn
:
Callable
[
[
torch
.
Tensor
],
torch
.
Tensor
])
->
T
:
def
ap
ply_rot
_fn
(
self
,
fn
:
Callable
[
Rotation
,
Rotation
])
->
Rigid
:
"""
"""
Apply a function that takes a tensor as its only argument to the
Applies a Rotation -> Rotation function to the stored rotation
rotations and translations, treating the final two/one
object.
dimension(s), respectively, as batch dimensions.
E.g.: Given t, an instance of T of shape [N, M], this function can
be used to sum out the second dimension thereof as follows::
t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
The resulting object has rotations of shape [N, 3, 3] and
translations of shape [N, 3]
Args:
Args:
fn: A function
that takes only a tensor as its argument
fn: A function
of type Rotation -> Rotation
Returns:
Returns:
The
transform
ed transform
ation object.
A
transformation object
with a transformed rotation
.
"""
"""
rots
=
self
.
rots
.
view
(
*
self
.
rots
.
shape
[:
-
2
],
9
)
return
Rigid
(
fn
(
self
.
_rots
),
self
.
_trans
)
rots
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
rots
,
-
1
))),
dim
=-
1
)
rots
=
rots
.
view
(
*
rots
.
shape
[:
-
1
],
3
,
3
)
trans
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
self
.
trans
,
-
1
))),
dim
=-
1
)
def
apply_trans_fn
(
self
,
fn
:
Callable
[
torch
.
Tensor
,
torch
.
Tensor
])
->
Rigid
:
"""
Applies a Tensor -> Tensor function to the stored translation.
return
T
(
rots
,
trans
)
Args:
fn:
A function of type Tensor -> Tensor to be applied to the
translation
Returns:
A transformation object with a transformed translation.
"""
return
Rigid
(
self
.
_rots
,
fn
(
self
.
_trans
))
def
s
top_rot_gradient
(
self
)
->
T
:
def
s
cale_translation
(
self
,
trans_scale_factor
:
float
)
->
Rigid
:
"""
"""
Detach
es the
contained rotation tens
or.
Scal
es the
translation by a constant fact
or.
Args:
trans_scale_factor:
The constant factor
Returns:
Returns:
A
version of the transformation with detached rot
ation
s
A
transformation object with a scaled transl
ation
.
"""
"""
return
T
(
self
.
rots
.
detach
(),
self
.
trans
)
fn
=
lambda
t
:
t
*
trans_scale_factor
return
self
.
apply_trans_fn
(
fn
)
def
s
cale_translation
(
self
,
factor
:
int
)
->
T
:
def
s
top_rot_gradient
(
self
)
->
Rigid
:
"""
"""
Scal
es the
contained translation tensor by a constant factor.
Detach
es the
underlying rotation object
Returns:
Returns:
A
version of the transformation with scaled transl
ations
A
transformation object with detached rot
ations
"""
"""
return
T
(
self
.
rots
,
self
.
trans
*
factor
)
fn
=
lambda
r
:
r
.
detach
()
return
self
.
apply_rot_fn
(
fn
)
@
staticmethod
@
staticmethod
def
make_transform_from_reference
(
n_xyz
,
ca_xyz
,
c_xyz
,
eps
=
1e-20
):
def
make_transform_from_reference
(
n_xyz
,
ca_xyz
,
c_xyz
,
eps
=
1e-20
):
...
@@ -613,87 +1366,15 @@ class T:
...
@@ -613,87 +1366,15 @@ class T:
rots
=
rots
.
transpose
(
-
1
,
-
2
)
rots
=
rots
.
transpose
(
-
1
,
-
2
)
translation
=
-
1
*
translation
translation
=
-
1
*
translation
return
T
(
rots
,
translation
)
rot_obj
=
Rotation
(
rot_mats
=
rots
,
quats
=
None
)
return
Rigid
(
rot_obj
,
translation
)
def
cuda
(
self
)
->
T
:
def
cuda
(
self
)
->
Rigid
:
"""
"""
Moves the transformation object to GPU memory
Moves the transformation object to GPU memory
Returns:
Returns:
A version of the transformation on GPU
A version of the transformation on GPU
"""
"""
return
T
(
self
.
rots
.
cuda
(),
self
.
trans
.
cuda
())
return
Rigid
(
self
.
_rots
.
cuda
(),
self
.
_trans
.
cuda
())
_quat_elements
=
[
"a"
,
"b"
,
"c"
,
"d"
]
_qtr_keys
=
[
l1
+
l2
for
l1
in
_quat_elements
for
l2
in
_quat_elements
]
_qtr_ind_dict
=
{
key
:
ind
for
ind
,
key
in
enumerate
(
_qtr_keys
)}
def
_to_mat
(
pairs
):
mat
=
np
.
zeros
((
4
,
4
))
for
pair
in
pairs
:
key
,
value
=
pair
ind
=
_qtr_ind_dict
[
key
]
mat
[
ind
//
4
][
ind
%
4
]
=
value
return
mat
_qtr_mat
=
np
.
zeros
((
4
,
4
,
3
,
3
))
_qtr_mat
[...,
0
,
0
]
=
_to_mat
([(
"aa"
,
1
),
(
"bb"
,
1
),
(
"cc"
,
-
1
),
(
"dd"
,
-
1
)])
_qtr_mat
[...,
0
,
1
]
=
_to_mat
([(
"bc"
,
2
),
(
"ad"
,
-
2
)])
_qtr_mat
[...,
0
,
2
]
=
_to_mat
([(
"bd"
,
2
),
(
"ac"
,
2
)])
_qtr_mat
[...,
1
,
0
]
=
_to_mat
([(
"bc"
,
2
),
(
"ad"
,
2
)])
_qtr_mat
[...,
1
,
1
]
=
_to_mat
([(
"aa"
,
1
),
(
"bb"
,
-
1
),
(
"cc"
,
1
),
(
"dd"
,
-
1
)])
_qtr_mat
[...,
1
,
2
]
=
_to_mat
([(
"cd"
,
2
),
(
"ab"
,
-
2
)])
_qtr_mat
[...,
2
,
0
]
=
_to_mat
([(
"bd"
,
2
),
(
"ac"
,
-
2
)])
_qtr_mat
[...,
2
,
1
]
=
_to_mat
([(
"cd"
,
2
),
(
"ab"
,
2
)])
_qtr_mat
[...,
2
,
2
]
=
_to_mat
([(
"aa"
,
1
),
(
"bb"
,
-
1
),
(
"cc"
,
-
1
),
(
"dd"
,
1
)])
def
quat_to_rot
(
quat
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Converts a quaternion to a rotation matrix.
Args:
quat: [*, 4] quaternions
Returns:
[*, 3, 3] rotation matrices
"""
# [*, 4, 4]
quat
=
quat
[...,
None
]
*
quat
[...,
None
,
:]
# [4, 4, 3, 3]
mat
=
quat
.
new_tensor
(
_qtr_mat
)
# [*, 4, 4, 3, 3]
shaped_qtr_mat
=
mat
.
view
((
1
,)
*
len
(
quat
.
shape
[:
-
2
])
+
mat
.
shape
)
quat
=
quat
[...,
None
,
None
]
*
shaped_qtr_mat
# [*, 3, 3]
return
torch
.
sum
(
quat
,
dim
=
(
-
3
,
-
4
))
def
affine_vector_to_4x4
(
vector
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Transforms a tensor whose final dimension has the form:
[*quaternion, *translation]
into a homogenous transformation tensor.
Args:
vector: [*, 7] input tensor
Returns:
[*, 4, 4] homogenous transformation tensor
"""
quats
=
vector
[...,
:
4
]
trans
=
vector
[...,
4
:]
four_by_four
=
vector
.
new_zeros
((
*
vector
.
shape
[:
-
1
],
4
,
4
))
four_by_four
[...,
:
3
,
:
3
]
=
quat_to_rot
(
quats
)
four_by_four
[...,
:
3
,
3
]
=
trans
four_by_four
[...,
3
,
3
]
=
1
return
four_by_four
tests/test_data/alphafold/common/stereo_chemical_props.txt
0 → 120000
View file @
e3daf724
../../../../openfold/resources/stereo_chemical_props.txt
\ No newline at end of file
tests/test_feats.py
View file @
e3daf724
...
@@ -23,8 +23,8 @@ from openfold.np.residue_constants import (
...
@@ -23,8 +23,8 @@ from openfold.np.residue_constants import (
restype_atom14_mask
,
restype_atom14_mask
,
restype_atom14_rigid_group_positions
,
restype_atom14_rigid_group_positions
,
)
)
from
openfold.utils.affine_utils
import
T
import
openfold.utils.feats
as
feats
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
tree_map
,
tree_map
,
tensor_tree_map
,
tensor_tree_map
,
...
@@ -187,7 +187,7 @@ class TestFeats(unittest.TestCase):
...
@@ -187,7 +187,7 @@ class TestFeats(unittest.TestCase):
n
=
5
n
=
5
rots
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))
rots
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n
,
3
))
ts
=
T
(
rots
,
trans
)
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
)
,
trans
)
angles
=
torch
.
rand
((
batch_size
,
n
,
7
,
2
))
angles
=
torch
.
rand
((
batch_size
,
n
,
7
,
2
))
...
@@ -222,7 +222,9 @@ class TestFeats(unittest.TestCase):
...
@@ -222,7 +222,9 @@ class TestFeats(unittest.TestCase):
affines
=
random_affines_4x4
((
n_res
,))
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
transformations
=
T
.
from_4x4
(
torch
.
as_tensor
(
affines
).
float
())
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
torsion_angles_sin_cos
=
np
.
random
.
rand
(
n_res
,
7
,
2
)
torsion_angles_sin_cos
=
np
.
random
.
rand
(
n_res
,
7
,
2
)
...
@@ -250,7 +252,7 @@ class TestFeats(unittest.TestCase):
...
@@ -250,7 +252,7 @@ class TestFeats(unittest.TestCase):
bottom_row
[...,
3
]
=
1
bottom_row
[...,
3
]
=
1
transforms_gt
=
torch
.
cat
([
transforms_gt
,
bottom_row
],
dim
=-
2
)
transforms_gt
=
torch
.
cat
([
transforms_gt
,
bottom_row
],
dim
=-
2
)
transforms_repro
=
out
.
to_4x4
().
cpu
()
transforms_repro
=
out
.
to_
tensor_
4x4
().
cpu
()
self
.
assertTrue
(
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
transforms_gt
-
transforms_repro
)
<
consts
.
eps
)
torch
.
max
(
torch
.
abs
(
transforms_gt
-
transforms_repro
)
<
consts
.
eps
)
...
@@ -262,7 +264,7 @@ class TestFeats(unittest.TestCase):
...
@@ -262,7 +264,7 @@ class TestFeats(unittest.TestCase):
rots
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
,
3
))
rots
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_res
,
8
,
3
))
ts
=
T
(
rots
,
trans
)
ts
=
Rigid
(
Rotation
(
rot_mats
=
rots
)
,
trans
)
f
=
torch
.
randint
(
low
=
0
,
high
=
21
,
size
=
(
batch_size
,
n_res
)).
long
()
f
=
torch
.
randint
(
low
=
0
,
high
=
21
,
size
=
(
batch_size
,
n_res
)).
long
()
...
@@ -293,7 +295,9 @@ class TestFeats(unittest.TestCase):
...
@@ -293,7 +295,9 @@ class TestFeats(unittest.TestCase):
affines
=
random_affines_4x4
((
n_res
,
8
))
affines
=
random_affines_4x4
((
n_res
,
8
))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
transformations
=
T
.
from_4x4
(
torch
.
as_tensor
(
affines
).
float
())
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
)
out_gt
=
f
.
apply
({},
None
,
aatype
,
rigids
)
out_gt
=
f
.
apply
({},
None
,
aatype
,
rigids
)
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
out_gt
)
...
...
tests/test_loss.py
View file @
e3daf724
...
@@ -20,7 +20,10 @@ import unittest
...
@@ -20,7 +20,10 @@ import unittest
import
ml_collections
as
mlc
import
ml_collections
as
mlc
from
openfold.data
import
data_transforms
from
openfold.data
import
data_transforms
from
openfold.utils.affine_utils
import
T
,
affine_vector_to_4x4
from
openfold.utils.rigid_utils
import
(
Rotation
,
Rigid
,
)
import
openfold.utils.feats
as
feats
import
openfold.utils.feats
as
feats
from
openfold.utils.loss
import
(
from
openfold.utils.loss
import
(
torsion_angle_loss
,
torsion_angle_loss
,
...
@@ -55,6 +58,11 @@ if compare_utils.alphafold_is_installed():
...
@@ -55,6 +58,11 @@ if compare_utils.alphafold_is_installed():
import
haiku
as
hk
import
haiku
as
hk
def
affine_vector_to_4x4
(
affine
):
r
=
Rigid
.
from_tensor_7
(
affine
)
return
r
.
to_tensor_4x4
()
class
TestLoss
(
unittest
.
TestCase
):
class
TestLoss
(
unittest
.
TestCase
):
def
test_run_torsion_angle_loss
(
self
):
def
test_run_torsion_angle_loss
(
self
):
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
...
@@ -77,8 +85,8 @@ class TestLoss(unittest.TestCase):
...
@@ -77,8 +85,8 @@ class TestLoss(unittest.TestCase):
rots_gt
=
torch
.
rand
((
batch_size
,
n_frames
,
3
,
3
))
rots_gt
=
torch
.
rand
((
batch_size
,
n_frames
,
3
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_frames
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_frames
,
3
))
trans_gt
=
torch
.
rand
((
batch_size
,
n_frames
,
3
))
trans_gt
=
torch
.
rand
((
batch_size
,
n_frames
,
3
))
t
=
T
(
rots
,
trans
)
t
=
Rigid
(
Rotation
(
rot_mats
=
rots
)
,
trans
)
t_gt
=
T
(
rots_gt
,
trans_gt
)
t_gt
=
Rigid
(
Rotation
(
rot_mats
=
rots_gt
)
,
trans_gt
)
frames_mask
=
torch
.
randint
(
0
,
2
,
(
batch_size
,
n_frames
)).
float
()
frames_mask
=
torch
.
randint
(
0
,
2
,
(
batch_size
,
n_frames
)).
float
()
positions_mask
=
torch
.
randint
(
0
,
2
,
(
batch_size
,
n_atoms
)).
float
()
positions_mask
=
torch
.
randint
(
0
,
2
,
(
batch_size
,
n_atoms
)).
float
()
length_scale
=
10
length_scale
=
10
...
@@ -686,11 +694,11 @@ class TestLoss(unittest.TestCase):
...
@@ -686,11 +694,11 @@ class TestLoss(unittest.TestCase):
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
value
=
tree_map
(
to_tensor
,
value
,
np
.
ndarray
)
value
=
tree_map
(
to_tensor
,
value
,
np
.
ndarray
)
batch
[
"backbone_
affine
_tensor"
]
=
affine_vector_to_4x4
(
batch
[
"backbone_
rigid
_tensor"
]
=
affine_vector_to_4x4
(
batch
[
"backbone_affine_tensor"
]
batch
[
"backbone_affine_tensor"
]
)
)
value
[
"traj"
]
=
affine_vector_to_4x4
(
value
[
"traj
"
]
)
batch
[
"backbone_rigid_mask"
]
=
batch
[
"backbone_affine_mask
"
]
out_repro
=
backbone_loss
(
traj
=
value
[
"traj"
],
**
{
**
batch
,
**
c_sm
})
out_repro
=
backbone_loss
(
traj
=
value
[
"traj"
],
**
{
**
batch
,
**
c_sm
})
out_repro
=
out_repro
.
cpu
()
out_repro
=
out_repro
.
cpu
()
...
@@ -807,6 +815,8 @@ class TestLoss(unittest.TestCase):
...
@@ -807,6 +815,8 @@ class TestLoss(unittest.TestCase):
f
=
hk
.
transform
(
run_tm_loss
)
f
=
hk
.
transform
(
run_tm_loss
)
np
.
random
.
seed
(
42
)
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
representations
=
{
representations
=
{
...
@@ -839,12 +849,10 @@ class TestLoss(unittest.TestCase):
...
@@ -839,12 +849,10 @@ class TestLoss(unittest.TestCase):
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
value
=
tree_map
(
to_tensor
,
value
,
np
.
ndarray
)
value
=
tree_map
(
to_tensor
,
value
,
np
.
ndarray
)
batch
[
"backbone_
affine
_tensor"
]
=
affine_vector_to_4x4
(
batch
[
"backbone_
rigid
_tensor"
]
=
affine_vector_to_4x4
(
batch
[
"backbone_affine_tensor"
]
batch
[
"backbone_affine_tensor"
]
)
)
value
[
"structure_module"
][
"final_affines"
]
=
affine_vector_to_4x4
(
batch
[
"backbone_rigid_mask"
]
=
batch
[
"backbone_affine_mask"
]
value
[
"structure_module"
][
"final_affines"
]
)
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
logits
=
model
.
aux_heads
.
tm
(
representations
[
"pair"
])
logits
=
model
.
aux_heads
.
tm
(
representations
[
"pair"
])
...
...
tests/test_model.py
View file @
e3daf724
...
@@ -130,4 +130,5 @@ class TestModel(unittest.TestCase):
...
@@ -130,4 +130,5 @@ class TestModel(unittest.TestCase):
out_repro
=
out_repro
[
"sm"
][
"positions"
][
-
1
]
out_repro
=
out_repro
[
"sm"
][
"positions"
][
-
1
]
out_repro
=
out_repro
.
squeeze
(
0
)
out_repro
=
out_repro
.
squeeze
(
0
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
1e-3
))
print
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)))
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
1e-3
)
tests/test_structure_module.py
View file @
e3daf724
...
@@ -31,8 +31,8 @@ from openfold.model.structure_module import (
...
@@ -31,8 +31,8 @@ from openfold.model.structure_module import (
AngleResnet
,
AngleResnet
,
InvariantPointAttention
,
InvariantPointAttention
,
)
)
from
openfold.utils.affine_utils
import
T
import
openfold.utils.feats
as
feats
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
from
tests.data_utils
import
(
from
tests.data_utils
import
(
...
@@ -89,7 +89,7 @@ class TestStructureModule(unittest.TestCase):
...
@@ -89,7 +89,7 @@ class TestStructureModule(unittest.TestCase):
out
=
sm
(
s
,
z
,
f
)
out
=
sm
(
s
,
z
,
f
)
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
4
,
4
))
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
7
))
self
.
assertTrue
(
self
.
assertTrue
(
out
[
"angles"
].
shape
==
(
no_layers
,
batch_size
,
n
,
no_angles
,
2
)
out
[
"angles"
].
shape
==
(
no_layers
,
batch_size
,
n
,
no_angles
,
2
)
)
)
...
@@ -177,23 +177,6 @@ class TestStructureModule(unittest.TestCase):
...
@@ -177,23 +177,6 @@ class TestStructureModule(unittest.TestCase):
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
0.05
)
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
0.05
)
class
TestBackboneUpdate
(
unittest
.
TestCase
):
def
test_shape
(
self
):
batch_size
=
2
n_res
=
3
c_in
=
5
bu
=
BackboneUpdate
(
c_in
)
s
=
torch
.
rand
((
batch_size
,
n_res
,
c_in
))
t
=
bu
(
s
)
rot
,
tra
=
t
.
rots
,
t
.
trans
self
.
assertTrue
(
rot
.
shape
==
(
batch_size
,
n_res
,
3
,
3
))
self
.
assertTrue
(
tra
.
shape
==
(
batch_size
,
n_res
,
3
))
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
def
test_shape
(
self
):
def
test_shape
(
self
):
c_m
=
13
c_m
=
13
...
@@ -210,17 +193,18 @@ class TestInvariantPointAttention(unittest.TestCase):
...
@@ -210,17 +193,18 @@ class TestInvariantPointAttention(unittest.TestCase):
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
))
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
))
mask
=
torch
.
ones
((
batch_size
,
n_res
))
mask
=
torch
.
ones
((
batch_size
,
n_res
))
rots
=
torch
.
rand
((
batch_size
,
n_res
,
3
,
3
))
rot_mats
=
torch
.
rand
((
batch_size
,
n_res
,
3
,
3
))
rots
=
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
trans
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
trans
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
t
=
T
(
rots
,
trans
)
r
=
Rigid
(
rots
,
trans
)
ipa
=
InvariantPointAttention
(
ipa
=
InvariantPointAttention
(
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
)
)
shape_before
=
s
.
shape
shape_before
=
s
.
shape
s
=
ipa
(
s
,
z
,
t
,
mask
)
s
=
ipa
(
s
,
z
,
r
,
mask
)
self
.
assertTrue
(
s
.
shape
==
shape_before
)
self
.
assertTrue
(
s
.
shape
==
shape_before
)
...
@@ -253,7 +237,9 @@ class TestInvariantPointAttention(unittest.TestCase):
...
@@ -253,7 +237,9 @@ class TestInvariantPointAttention(unittest.TestCase):
affines
=
random_affines_4x4
((
n_res
,))
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
quats
=
alphafold
.
model
.
r3
.
rigids_to_quataffine
(
rigids
)
quats
=
alphafold
.
model
.
r3
.
rigids_to_quataffine
(
rigids
)
transformations
=
T
.
from_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
())
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
quats
sample_affine
=
quats
...
...
tests/test_utils.py
View file @
e3daf724
...
@@ -13,11 +13,24 @@
...
@@ -13,11 +13,24 @@
# limitations under the License.
# limitations under the License.
import
math
import
math
import
numpy
as
np
import
torch
import
torch
import
unittest
import
unittest
from
openfold.utils.affine_utils
import
T
,
quat_to_rot
from
openfold.utils.rigid_utils
import
(
Rotation
,
Rigid
,
quat_to_rot
,
rot_to_quat
,
)
from
openfold.utils.tensor_utils
import
chunk_layer
,
_chunk_slice
from
openfold.utils.tensor_utils
import
chunk_layer
,
_chunk_slice
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
if
compare_utils
.
alphafold_is_installed
():
alphafold
=
compare_utils
.
import_alphafold
()
import
jax
import
haiku
as
hk
X_90_ROT
=
torch
.
tensor
(
X_90_ROT
=
torch
.
tensor
(
...
@@ -38,7 +51,7 @@ X_NEG_90_ROT = torch.tensor(
...
@@ -38,7 +51,7 @@ X_NEG_90_ROT = torch.tensor(
class
TestUtils
(
unittest
.
TestCase
):
class
TestUtils
(
unittest
.
TestCase
):
def
test_
T
_from_3_points_shape
(
self
):
def
test_
rigid
_from_3_points_shape
(
self
):
batch_size
=
2
batch_size
=
2
n_res
=
5
n_res
=
5
...
@@ -46,14 +59,14 @@ class TestUtils(unittest.TestCase):
...
@@ -46,14 +59,14 @@ class TestUtils(unittest.TestCase):
x2
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
x2
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
x3
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
x3
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
t
=
T
.
from_3_points
(
x1
,
x2
,
x3
)
r
=
Rigid
.
from_3_points
(
x1
,
x2
,
x3
)
rot
,
tra
=
t
.
rots
,
t
.
trans
rot
,
tra
=
r
.
get_rots
().
get_rot_mats
(),
r
.
get_
trans
()
self
.
assertTrue
(
rot
.
shape
==
(
batch_size
,
n_res
,
3
,
3
))
self
.
assertTrue
(
rot
.
shape
==
(
batch_size
,
n_res
,
3
,
3
))
self
.
assertTrue
(
torch
.
all
(
tra
==
x2
))
self
.
assertTrue
(
torch
.
all
(
tra
==
x2
))
def
test_
T
_from_4x4
(
self
):
def
test_
rigid
_from_4x4
(
self
):
batch_size
=
2
batch_size
=
2
transf
=
[
transf
=
[
[
1
,
0
,
0
,
1
],
[
1
,
0
,
0
,
1
],
...
@@ -68,58 +81,79 @@ class TestUtils(unittest.TestCase):
...
@@ -68,58 +81,79 @@ class TestUtils(unittest.TestCase):
transf
=
torch
.
stack
([
transf
for
_
in
range
(
batch_size
)],
dim
=
0
)
transf
=
torch
.
stack
([
transf
for
_
in
range
(
batch_size
)],
dim
=
0
)
t
=
T
.
from
_4x4
(
transf
)
r
=
Rigid
.
from_tensor
_4x4
(
transf
)
rot
,
tra
=
t
.
rots
,
t
.
trans
rot
,
tra
=
r
.
get_rots
().
get_rot_mats
(),
r
.
get_
trans
()
self
.
assertTrue
(
torch
.
all
(
rot
==
true_rot
.
unsqueeze
(
0
)))
self
.
assertTrue
(
torch
.
all
(
rot
==
true_rot
.
unsqueeze
(
0
)))
self
.
assertTrue
(
torch
.
all
(
tra
==
true_trans
.
unsqueeze
(
0
)))
self
.
assertTrue
(
torch
.
all
(
tra
==
true_trans
.
unsqueeze
(
0
)))
def
test_
T
_shape
(
self
):
def
test_
rigid
_shape
(
self
):
batch_size
=
2
batch_size
=
2
n
=
5
n
=
5
transf
=
T
(
transf
=
Rigid
(
torch
.
rand
((
batch_size
,
n
,
3
,
3
)),
torch
.
rand
((
batch_size
,
n
,
3
))
Rotation
(
rot_mats
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))),
torch
.
rand
((
batch_size
,
n
,
3
))
)
)
self
.
assertTrue
(
transf
.
shape
==
(
batch_size
,
n
))
self
.
assertTrue
(
transf
.
shape
==
(
batch_size
,
n
))
def
test_
T_con
cat
(
self
):
def
test_
rigid_
cat
(
self
):
batch_size
=
2
batch_size
=
2
n
=
5
n
=
5
transf
=
T
(
transf
=
Rigid
(
torch
.
rand
((
batch_size
,
n
,
3
,
3
)),
torch
.
rand
((
batch_size
,
n
,
3
))
Rotation
(
rot_mats
=
torch
.
rand
((
batch_size
,
n
,
3
,
3
))),
torch
.
rand
((
batch_size
,
n
,
3
))
)
)
transf_
con
cat
=
T
.
con
cat
([
transf
,
transf
],
dim
=
0
)
transf_cat
=
Rigid
.
cat
([
transf
,
transf
],
dim
=
0
)
self
.
assertTrue
(
transf_concat
.
rots
.
shape
==
(
batch_size
*
2
,
n
,
3
,
3
))
transf_rots
=
transf
.
get_rots
().
get_rot_mats
()
transf_cat_rots
=
transf_cat
.
get_rots
().
get_rot_mats
()
transf_concat
=
T
.
concat
([
transf
,
transf
],
dim
=
1
)
self
.
assertTrue
(
transf_cat_rots
.
shape
==
(
batch_size
*
2
,
n
,
3
,
3
)
)
self
.
assertTrue
(
transf_concat
.
rots
.
shape
==
(
batch_size
,
n
*
2
,
3
,
3
))
transf_cat
=
Rigid
.
cat
([
transf
,
transf
],
dim
=
1
)
transf_cat_rots
=
transf_cat
.
get_rots
().
get_rot_mats
()
self
.
assertTrue
(
torch
.
all
(
transf_concat
.
rots
[:,
:
n
]
==
transf
.
rots
))
self
.
assertTrue
(
transf_cat_rots
.
shape
==
(
batch_size
,
n
*
2
,
3
,
3
))
self
.
assertTrue
(
torch
.
all
(
transf_concat
.
trans
[:,
:
n
]
==
transf
.
trans
))
def
test_T_compose
(
self
):
self
.
assertTrue
(
torch
.
all
(
transf_cat_rots
[:,
:
n
]
==
transf_rots
))
self
.
assertTrue
(
torch
.
all
(
transf_cat
.
get_trans
()[:,
:
n
]
==
transf
.
get_trans
())
)
def
test_rigid_compose
(
self
):
trans_1
=
[
0
,
1
,
0
]
trans_1
=
[
0
,
1
,
0
]
trans_2
=
[
0
,
0
,
1
]
trans_2
=
[
0
,
0
,
1
]
t1
=
T
(
X_90_ROT
,
torch
.
tensor
(
trans_1
))
r
=
Rotation
(
rot_mats
=
X_90_ROT
)
t2
=
T
(
X_NEG_90_ROT
,
torch
.
tensor
(
trans_2
))
t
=
torch
.
tensor
(
trans_1
)
t1
=
Rigid
(
Rotation
(
rot_mats
=
X_90_ROT
),
torch
.
tensor
(
trans_1
)
)
t2
=
Rigid
(
Rotation
(
rot_mats
=
X_NEG_90_ROT
),
torch
.
tensor
(
trans_2
)
)
t3
=
t1
.
compose
(
t2
)
t3
=
t1
.
compose
(
t2
)
self
.
assertTrue
(
torch
.
all
(
t3
.
rots
==
torch
.
eye
(
3
)))
self
.
assertTrue
(
self
.
assertTrue
(
torch
.
all
(
t3
.
trans
==
0
))
torch
.
all
(
t3
.
get_rots
().
get_rot_mats
()
==
torch
.
eye
(
3
))
)
self
.
assertTrue
(
torch
.
all
(
t3
.
get_trans
()
==
0
)
)
def
test_
T
_apply
(
self
):
def
test_
rigid
_apply
(
self
):
rots
=
torch
.
stack
([
X_90_ROT
,
X_NEG_90_ROT
],
dim
=
0
)
rots
=
torch
.
stack
([
X_90_ROT
,
X_NEG_90_ROT
],
dim
=
0
)
trans
=
torch
.
tensor
([
1
,
1
,
1
])
trans
=
torch
.
tensor
([
1
,
1
,
1
])
trans
=
torch
.
stack
([
trans
,
trans
],
dim
=
0
)
trans
=
torch
.
stack
([
trans
,
trans
],
dim
=
0
)
t
=
T
(
rots
,
trans
)
t
=
Rigid
(
Rotation
(
rot_mats
=
rots
)
,
trans
)
x
=
torch
.
arange
(
30
)
x
=
torch
.
arange
(
30
)
x
=
torch
.
stack
([
x
,
x
],
dim
=
0
)
x
=
torch
.
stack
([
x
,
x
],
dim
=
0
)
...
@@ -141,6 +175,12 @@ class TestUtils(unittest.TestCase):
...
@@ -141,6 +175,12 @@ class TestUtils(unittest.TestCase):
eps
=
1e-07
eps
=
1e-07
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
rot
-
X_90_ROT
)
<
eps
))
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
rot
-
X_90_ROT
)
<
eps
))
def
test_rot_to_quat
(
self
):
quat
=
rot_to_quat
(
X_90_ROT
)
eps
=
1e-07
ans
=
torch
.
tensor
([
math
.
sqrt
(
0.5
),
math
.
sqrt
(
0.5
),
0.
,
0.
])
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
quat
-
ans
)
<
eps
))
def
test_chunk_layer_tensor
(
self
):
def
test_chunk_layer_tensor
(
self
):
x
=
torch
.
rand
(
2
,
4
,
5
,
15
)
x
=
torch
.
rand
(
2
,
4
,
5
,
15
)
l
=
torch
.
nn
.
Linear
(
15
,
30
)
l
=
torch
.
nn
.
Linear
(
15
,
30
)
...
@@ -180,3 +220,33 @@ class TestUtils(unittest.TestCase):
...
@@ -180,3 +220,33 @@ class TestUtils(unittest.TestCase):
chunked_flattened
=
x_flat
[
i
:
j
]
chunked_flattened
=
x_flat
[
i
:
j
]
self
.
assertTrue
(
torch
.
all
(
chunked
==
chunked_flattened
))
self
.
assertTrue
(
torch
.
all
(
chunked
==
chunked_flattened
))
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_pre_compose_compare
(
self
):
quat
=
np
.
random
.
rand
(
20
,
4
)
trans
=
[
np
.
random
.
rand
(
20
)
for
_
in
range
(
3
)]
quat_affine
=
alphafold
.
model
.
quat_affine
.
QuatAffine
(
quat
,
translation
=
trans
)
update_vec
=
np
.
random
.
rand
(
20
,
6
)
new_gt
=
quat_affine
.
pre_compose
(
update_vec
)
quat_t
=
torch
.
tensor
(
quat
)
trans_t
=
torch
.
stack
([
torch
.
tensor
(
t
)
for
t
in
trans
],
dim
=-
1
)
rigid
=
Rigid
(
Rotation
(
quats
=
quat_t
),
trans_t
)
new_repro
=
rigid
.
compose_q_update_vec
(
torch
.
tensor
(
update_vec
))
new_gt_q
=
torch
.
tensor
(
np
.
array
(
new_gt
.
quaternion
))
new_gt_t
=
torch
.
stack
(
[
torch
.
tensor
(
np
.
array
(
t
))
for
t
in
new_gt
.
translation
],
dim
=-
1
)
new_repro_q
=
new_repro
.
get_rots
().
get_quats
()
new_repro_t
=
new_repro
.
get_trans
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
new_gt_q
-
new_repro_q
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
new_gt_t
-
new_repro_t
))
<
consts
.
eps
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment