Commit e3daf724 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Overhaul transformation code for better parity w/ AlphaFold

parent 1f709b0d
...@@ -89,8 +89,8 @@ config = mlc.ConfigDict( ...@@ -89,8 +89,8 @@ config = mlc.ConfigDict(
"atom14_gt_exists": [NUM_RES, None], "atom14_gt_exists": [NUM_RES, None],
"atom14_gt_positions": [NUM_RES, None, None], "atom14_gt_positions": [NUM_RES, None, None],
"atom37_atom_exists": [NUM_RES, None], "atom37_atom_exists": [NUM_RES, None],
"backbone_affine_mask": [NUM_RES], "backbone_rigid_mask": [NUM_RES],
"backbone_affine_tensor": [NUM_RES, None, None], "backbone_rigid_tensor": [NUM_RES, None, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES], "bert_mask": [NUM_MSA_SEQ, NUM_RES],
"chi_angles_sin_cos": [NUM_RES, None, None], "chi_angles_sin_cos": [NUM_RES, None, None],
"chi_mask": [NUM_RES, None], "chi_mask": [NUM_RES, None],
...@@ -126,8 +126,8 @@ config = mlc.ConfigDict( ...@@ -126,8 +126,8 @@ config = mlc.ConfigDict(
"template_alt_torsion_angles_sin_cos": [ "template_alt_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None, NUM_TEMPLATES, NUM_RES, None, None,
], ],
"template_backbone_affine_mask": [NUM_TEMPLATES, NUM_RES], "template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
"template_backbone_affine_tensor": [ "template_backbone_rigid_tensor": [
NUM_TEMPLATES, NUM_RES, None, None, NUM_TEMPLATES, NUM_RES, None, None,
], ],
"template_mask": [NUM_TEMPLATES], "template_mask": [NUM_TEMPLATES],
......
...@@ -22,7 +22,7 @@ import torch ...@@ -22,7 +22,7 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
from openfold.utils.affine_utils import T from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -752,7 +752,7 @@ def make_atom14_positions(protein): ...@@ -752,7 +752,7 @@ def make_atom14_positions(protein):
return protein return protein
def atom37_to_frames(protein): def atom37_to_frames(protein, eps=1e-8):
aatype = protein["aatype"] aatype = protein["aatype"]
all_atom_positions = protein["all_atom_positions"] all_atom_positions = protein["all_atom_positions"]
all_atom_mask = protein["all_atom_mask"] all_atom_mask = protein["all_atom_mask"]
...@@ -810,11 +810,11 @@ def atom37_to_frames(protein): ...@@ -810,11 +810,11 @@ def atom37_to_frames(protein):
no_batch_dims=len(all_atom_positions.shape[:-2]), no_batch_dims=len(all_atom_positions.shape[:-2]),
) )
gt_frames = T.from_3_points( gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :], p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :], origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :], p_xy_plane=base_atom_pos[..., 2, :],
eps=1e-8, eps=eps,
) )
group_exists = batched_gather( group_exists = batched_gather(
...@@ -836,8 +836,9 @@ def atom37_to_frames(protein): ...@@ -836,8 +836,9 @@ def atom37_to_frames(protein):
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1)) rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1 rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1 rots[..., 0, 2, 2] = -1
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(T(rots, None)) gt_frames = gt_frames.compose(Rigid(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros( restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8 *((1,) * batch_dims), 21, 8
...@@ -871,10 +872,15 @@ def atom37_to_frames(protein): ...@@ -871,10 +872,15 @@ def atom37_to_frames(protein):
no_batch_dims=batch_dims, no_batch_dims=batch_dims,
) )
alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None)) residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot
)
alt_gt_frames = gt_frames.compose(
Rigid(residx_rigidgroup_ambiguity_rot, None)
)
gt_frames_tensor = gt_frames.to_4x4() gt_frames_tensor = gt_frames.to_tensor_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_4x4() alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
protein["rigidgroups_gt_frames"] = gt_frames_tensor protein["rigidgroups_gt_frames"] = gt_frames_tensor
protein["rigidgroups_gt_exists"] = gt_exists protein["rigidgroups_gt_exists"] = gt_exists
...@@ -1028,7 +1034,7 @@ def atom37_to_torsion_angles( ...@@ -1028,7 +1034,7 @@ def atom37_to_torsion_angles(
dim=-1, dim=-1,
) )
torsion_frames = T.from_3_points( torsion_frames = Rigid.from_3_points(
torsions_atom_pos[..., 1, :], torsions_atom_pos[..., 1, :],
torsions_atom_pos[..., 2, :], torsions_atom_pos[..., 2, :],
torsions_atom_pos[..., 0, :], torsions_atom_pos[..., 0, :],
...@@ -1082,11 +1088,11 @@ def atom37_to_torsion_angles( ...@@ -1082,11 +1088,11 @@ def atom37_to_torsion_angles(
def get_backbone_frames(protein): def get_backbone_frames(protein):
# TODO: Verify that this is correct # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
protein["backbone_affine_tensor"] = protein["rigidgroups_gt_frames"][ protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
..., 0, :, : ..., 0, :, :
] ]
protein["backbone_affine_mask"] = protein["rigidgroups_gt_exists"][..., 0] protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]
return protein return protein
......
...@@ -430,7 +430,9 @@ def _is_set(data: str) -> bool: ...@@ -430,7 +430,9 @@ def _is_set(data: str) -> bool:
def get_atom_coords( def get_atom_coords(
mmcif_object: MmcifObject, chain_id: str, zero_center: bool = True mmcif_object: MmcifObject,
chain_id: str,
_zero_center_positions: bool = True
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain # Locate the right chain
chains = list(mmcif_object.structure.get_chains()) chains = list(mmcif_object.structure.get_chains())
...@@ -475,7 +477,7 @@ def get_atom_coords( ...@@ -475,7 +477,7 @@ def get_atom_coords(
all_atom_positions[res_index] = pos all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask all_atom_mask[res_index] = mask
if zero_center: if _zero_center_positions:
binary_mask = all_atom_mask.astype(bool) binary_mask = all_atom_mask.astype(bool)
translation_vec = all_atom_positions[binary_mask].mean(axis=0) translation_vec = all_atom_positions[binary_mask].mean(axis=0)
all_atom_positions[binary_mask] -= translation_vec all_atom_positions[binary_mask] -= translation_vec
......
...@@ -503,10 +503,13 @@ def _get_atom_positions( ...@@ -503,10 +503,13 @@ def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject, mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str, auth_chain_id: str,
max_ca_ca_distance: float, max_ca_ca_distance: float,
_zero_center_positions: bool = True,
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues.""" """Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords( coords_with_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=auth_chain_id mmcif_object=mmcif_object,
chain_id=auth_chain_id,
_zero_center_positions=_zero_center_positions,
) )
all_atom_positions, all_atom_mask = coords_with_mask all_atom_positions, all_atom_mask = coords_with_mask
_check_residue_distances( _check_residue_distances(
...@@ -523,6 +526,7 @@ def _extract_template_features( ...@@ -523,6 +526,7 @@ def _extract_template_features(
query_sequence: str, query_sequence: str,
template_chain_id: str, template_chain_id: str,
kalign_binary_path: str, kalign_binary_path: str,
_zero_center_positions: bool = True,
) -> Tuple[Dict[str, Any], Optional[str]]: ) -> Tuple[Dict[str, Any], Optional[str]]:
"""Parses atom positions in the target structure and aligns with the query. """Parses atom positions in the target structure and aligns with the query.
...@@ -607,7 +611,10 @@ def _extract_template_features( ...@@ -607,7 +611,10 @@ def _extract_template_features(
# Essentially set to infinity - we don't want to reject templates unless # Essentially set to infinity - we don't want to reject templates unless
# they're really really bad. # they're really really bad.
all_atom_positions, all_atom_mask = _get_atom_positions( all_atom_positions, all_atom_mask = _get_atom_positions(
mmcif_object, chain_id, max_ca_ca_distance=150.0 mmcif_object,
chain_id,
max_ca_ca_distance=150.0,
_zero_center_positions=_zero_center_positions,
) )
except (CaDistanceError, KeyError) as ex: except (CaDistanceError, KeyError) as ex:
raise NoAtomDataInTemplateError( raise NoAtomDataInTemplateError(
...@@ -795,6 +802,7 @@ def _process_single_hit( ...@@ -795,6 +802,7 @@ def _process_single_hit(
obsolete_pdbs: Mapping[str, str], obsolete_pdbs: Mapping[str, str],
kalign_binary_path: str, kalign_binary_path: str,
strict_error_check: bool = False, strict_error_check: bool = False,
_zero_center_positions: bool = True,
) -> SingleHitResult: ) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit.""" """Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit. # Fail hard if we can't get the PDB ID and chain name from the hit.
...@@ -856,6 +864,7 @@ def _process_single_hit( ...@@ -856,6 +864,7 @@ def _process_single_hit(
query_sequence=query_sequence, query_sequence=query_sequence,
template_chain_id=hit_chain_id, template_chain_id=hit_chain_id,
kalign_binary_path=kalign_binary_path, kalign_binary_path=kalign_binary_path,
_zero_center_positions=_zero_center_positions,
) )
features["template_sum_probs"] = [hit.sum_probs] features["template_sum_probs"] = [hit.sum_probs]
...@@ -913,7 +922,6 @@ class TemplateSearchResult: ...@@ -913,7 +922,6 @@ class TemplateSearchResult:
class TemplateHitFeaturizer: class TemplateHitFeaturizer:
"""A class for turning hhr hits to template features.""" """A class for turning hhr hits to template features."""
def __init__( def __init__(
self, self,
mmcif_dir: str, mmcif_dir: str,
...@@ -924,6 +932,7 @@ class TemplateHitFeaturizer: ...@@ -924,6 +932,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs_path: Optional[str] = None, obsolete_pdbs_path: Optional[str] = None,
strict_error_check: bool = False, strict_error_check: bool = False,
_shuffle_top_k_prefiltered: Optional[int] = None, _shuffle_top_k_prefiltered: Optional[int] = None,
_zero_center_positions: bool = True,
): ):
"""Initializes the Template Search. """Initializes the Template Search.
...@@ -982,6 +991,7 @@ class TemplateHitFeaturizer: ...@@ -982,6 +991,7 @@ class TemplateHitFeaturizer:
self._obsolete_pdbs = {} self._obsolete_pdbs = {}
self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
self._zero_center_positions = _zero_center_positions
def get_templates( def get_templates(
self, self,
...@@ -1057,6 +1067,7 @@ class TemplateHitFeaturizer: ...@@ -1057,6 +1067,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs=self._obsolete_pdbs, obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check, strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path, kalign_binary_path=self._kalign_binary_path,
_zero_center_positions=self._zero_center_positions,
) )
if result.error: if result.error:
......
...@@ -198,6 +198,7 @@ class RecyclingEmbedder(nn.Module): ...@@ -198,6 +198,7 @@ class RecyclingEmbedder(nn.Module):
self.no_bins, self.no_bins,
dtype=x.dtype, dtype=x.dtype,
device=x.device, device=x.device,
requires_grad=False,
) )
# [*, N, C_m] # [*, N, C_m]
......
...@@ -25,11 +25,11 @@ from openfold.np.residue_constants import ( ...@@ -25,11 +25,11 @@ from openfold.np.residue_constants import (
restype_atom14_mask, restype_atom14_mask,
restype_atom14_rigid_group_positions, restype_atom14_rigid_group_positions,
) )
from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.feats import ( from openfold.utils.feats import (
frames_and_literature_positions_to_atom14_pos, frames_and_literature_positions_to_atom14_pos,
torsion_angles_to_frames, torsion_angles_to_frames,
) )
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
dict_multimap, dict_multimap,
permute_final_dims, permute_final_dims,
...@@ -225,7 +225,7 @@ class InvariantPointAttention(nn.Module): ...@@ -225,7 +225,7 @@ class InvariantPointAttention(nn.Module):
self, self,
s: torch.Tensor, s: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
t: T, r: Rigid,
mask: torch.Tensor, mask: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -234,8 +234,8 @@ class InvariantPointAttention(nn.Module): ...@@ -234,8 +234,8 @@ class InvariantPointAttention(nn.Module):
[*, N_res, C_s] single representation [*, N_res, C_s] single representation
z: z:
[*, N_res, N_res, C_z] pair representation [*, N_res, N_res, C_z] pair representation
t: r:
[*, N_res] affine transformation object [*, N_res] transformation object
mask: mask:
[*, N_res] mask [*, N_res] mask
Returns: Returns:
...@@ -264,7 +264,7 @@ class InvariantPointAttention(nn.Module): ...@@ -264,7 +264,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_q, 3] # [*, N_res, H * P_q, 3]
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
q_pts = torch.stack(q_pts, dim=-1) q_pts = torch.stack(q_pts, dim=-1)
q_pts = t[..., None].apply(q_pts) q_pts = r[..., None].apply(q_pts)
# [*, N_res, H, P_q, 3] # [*, N_res, H, P_q, 3]
q_pts = q_pts.view( q_pts = q_pts.view(
...@@ -277,7 +277,7 @@ class InvariantPointAttention(nn.Module): ...@@ -277,7 +277,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * (P_q + P_v), 3] # [*, N_res, H * (P_q + P_v), 3]
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1) kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = t[..., None].apply(kv_pts) kv_pts = r[..., None].apply(kv_pts)
# [*, N_res, H, (P_q + P_v), 3] # [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3)) kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
...@@ -349,7 +349,7 @@ class InvariantPointAttention(nn.Module): ...@@ -349,7 +349,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, P_v, 3] # [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = t[..., None, None].invert_apply(o_pt) o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v] # [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims( o_pt_norm = flatten_final_dims(
...@@ -377,7 +377,7 @@ class InvariantPointAttention(nn.Module): ...@@ -377,7 +377,7 @@ class InvariantPointAttention(nn.Module):
class BackboneUpdate(nn.Module): class BackboneUpdate(nn.Module):
""" """
Implements Algorithm 23. Implements part of Algorithm 23.
""" """
def __init__(self, c_s): def __init__(self, c_s):
...@@ -392,36 +392,17 @@ class BackboneUpdate(nn.Module): ...@@ -392,36 +392,17 @@ class BackboneUpdate(nn.Module):
self.linear = Linear(self.c_s, 6, init="final") self.linear = Linear(self.c_s, 6, init="final")
def forward(self, s): def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
[*, N_res, C_s] single representation [*, N_res, C_s] single representation
Returns: Returns:
[*, N_res] affine transformation object [*, N_res, 6] update vector
""" """
# [*, 6] # [*, 6]
params = self.linear(s) update = self.linear(s)
# [*, 3]
quats, trans = params[..., :3], params[..., 3:]
# [*]
# norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
norm_denom = torch.sqrt(torch.sum(quats ** 2, dim=-1) + 1)
# [*, 3]
ones = s.new_ones((1,) * len(quats.shape)).expand(
quats.shape[:-1] + (1,)
)
# [*, 4]
quats = torch.cat([ones, quats], dim=-1)
quats = quats / norm_denom[..., None]
# [*, 3, 3] return update
rots = quat_to_rot(quats)
return T(rots, trans)
class StructureModuleTransitionLayer(nn.Module): class StructureModuleTransitionLayer(nn.Module):
...@@ -592,7 +573,7 @@ class StructureModule(nn.Module): ...@@ -592,7 +573,7 @@ class StructureModule(nn.Module):
self, self,
s, s,
z, z,
f, aatype,
mask=None, mask=None,
): ):
""" """
...@@ -601,7 +582,7 @@ class StructureModule(nn.Module): ...@@ -601,7 +582,7 @@ class StructureModule(nn.Module):
[*, N_res, C_s] single representation [*, N_res, C_s] single representation
z: z:
[*, N_res, N_res, C_z] pair representation [*, N_res, N_res, C_z] pair representation
f: aatype:
[*, N_res] amino acid indices [*, N_res] amino acid indices
mask: mask:
Optional [*, N_res] sequence mask Optional [*, N_res] sequence mask
...@@ -623,44 +604,67 @@ class StructureModule(nn.Module): ...@@ -623,44 +604,67 @@ class StructureModule(nn.Module):
s = self.linear_in(s) s = self.linear_in(s)
# [*, N] # [*, N]
t = T.identity(s.shape[:-1], s.dtype, s.device, self.training) rigids = Rigid.identity(
s.shape[:-1],
s.dtype,
s.device,
self.training,
fmt="quat",
)
outputs = [] outputs = []
for i in range(self.no_blocks): for i in range(self.no_blocks):
# [*, N, C_s] # [*, N, C_s]
s = s + self.ipa(s, z, t, mask) s = s + self.ipa(s, z, rigids, mask)
s = self.ipa_dropout(s) s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s) s = self.layer_norm_ipa(s)
s = self.transition(s) s = self.transition(s)
# [*, N] # [*, N]
t = t.compose(self.bb_update(s)) rigids = rigids.compose_q_update_vec(self.bb_update(s))
# To hew as closely as possible to AlphaFold, we convert our
# quaternion-based transformations to rotation-matrix ones
# here
backb_to_global = Rigid(
Rotation(
rot_mats=rigids.get_rots().get_rot_mats(),
quats=None
),
rigids.get_trans(),
)
backb_to_global = backb_to_global.scale_translation(
self.trans_scale_factor
)
# [*, N, 7, 2] # [*, N, 7, 2]
unnormalized_a, a = self.angle_resnet(s, s_initial) unnormalized_angles, angles = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames( all_frames_to_global = self.torsion_angles_to_frames(
t.scale_translation(self.trans_scale_factor), backb_to_global,
a, angles,
f, aatype,
) )
pred_xyz = self.frames_and_literature_positions_to_atom14_pos( pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
all_frames_to_global, all_frames_to_global,
f, aatype,
) )
scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
preds = { preds = {
"frames": t.scale_translation(self.trans_scale_factor).to_4x4(), "frames": scaled_rigids.to_tensor_7(),
"sidechain_frames": all_frames_to_global.to_4x4(), "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
"unnormalized_angles": unnormalized_a, "unnormalized_angles": unnormalized_angles,
"angles": a, "angles": angles,
"positions": pred_xyz, "positions": pred_xyz,
} }
outputs.append(preds) outputs.append(preds)
if i < (self.no_blocks - 1): if i < (self.no_blocks - 1):
t = t.stop_rot_gradient() rigids = rigids.stop_rot_gradient()
outputs = dict_multimap(torch.stack, outputs) outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s outputs["single"] = s
...@@ -673,38 +677,42 @@ class StructureModule(nn.Module): ...@@ -673,38 +677,42 @@ class StructureModule(nn.Module):
restype_rigid_group_default_frame, restype_rigid_group_default_frame,
dtype=float_dtype, dtype=float_dtype,
device=device, device=device,
requires_grad=False,
) )
if self.group_idx is None: if self.group_idx is None:
self.group_idx = torch.tensor( self.group_idx = torch.tensor(
restype_atom14_to_rigid_group, restype_atom14_to_rigid_group,
device=device, device=device,
requires_grad=False,
) )
if self.atom_mask is None: if self.atom_mask is None:
self.atom_mask = torch.tensor( self.atom_mask = torch.tensor(
restype_atom14_mask, restype_atom14_mask,
dtype=float_dtype, dtype=float_dtype,
device=device, device=device,
requires_grad=False,
) )
if self.lit_positions is None: if self.lit_positions is None:
self.lit_positions = torch.tensor( self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions, restype_atom14_rigid_group_positions,
dtype=float_dtype, dtype=float_dtype,
device=device, device=device,
requires_grad=False,
) )
def torsion_angles_to_frames(self, t, alpha, f): def torsion_angles_to_frames(self, r, alpha, f):
# Lazily initialize the residue constants on the correct device # Lazily initialize the residue constants on the correct device
self._init_residue_constants(alpha.dtype, alpha.device) self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying # Separated purely to make testing less annoying
return torsion_angles_to_frames(t, alpha, f, self.default_frames) return torsion_angles_to_frames(r, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos( def frames_and_literature_positions_to_atom14_pos(
self, t, f # [*, N, 8] # [*, N] self, r, f # [*, N, 8] # [*, N]
): ):
# Lazily initialize the residue constants on the correct device # Lazily initialize the residue constants on the correct device
self._init_residue_constants(t.rots.dtype, t.rots.device) self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
return frames_and_literature_positions_to_atom14_pos( return frames_and_literature_positions_to_atom14_pos(
t, r,
f, f,
self.default_frames, self.default_frames,
self.group_idx, self.group_idx,
......
...@@ -22,7 +22,7 @@ from typing import Dict ...@@ -22,7 +22,7 @@ from typing import Dict
from openfold.np import protein from openfold.np import protein
import openfold.np.residue_constants as rc import openfold.np.residue_constants as rc
from openfold.utils.affine_utils import T from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
batched_gather, batched_gather,
one_hot, one_hot,
...@@ -124,18 +124,16 @@ def build_template_pair_feat( ...@@ -124,18 +124,16 @@ def build_template_pair_feat(
) )
n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]] n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
# TODO: Consider running this in double precision rigids = Rigid.make_transform_from_reference(
affines = T.make_transform_from_reference(
n_xyz=batch["template_all_atom_positions"][..., n, :], n_xyz=batch["template_all_atom_positions"][..., n, :],
ca_xyz=batch["template_all_atom_positions"][..., ca, :], ca_xyz=batch["template_all_atom_positions"][..., ca, :],
c_xyz=batch["template_all_atom_positions"][..., c, :], c_xyz=batch["template_all_atom_positions"][..., c, :],
eps=eps, eps=eps,
) )
points = rigids.get_trans()[..., None, :, :]
rigid_vec = rigids[..., None].invert_apply(points)
points = affines.get_trans()[..., None, :, :] inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))
affine_vec = affines[..., None].invert_apply(points)
inv_distance_scalar = torch.rsqrt(eps + torch.sum(affine_vec ** 2, dim=-1))
t_aa_masks = batch["template_all_atom_mask"] t_aa_masks = batch["template_all_atom_mask"]
template_mask = ( template_mask = (
...@@ -144,7 +142,7 @@ def build_template_pair_feat( ...@@ -144,7 +142,7 @@ def build_template_pair_feat(
template_mask_2d = template_mask[..., None] * template_mask[..., None, :] template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
inv_distance_scalar = inv_distance_scalar * template_mask_2d inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = affine_vec * inv_distance_scalar[..., None] unit_vector = rigid_vec * inv_distance_scalar[..., None]
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1)) to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None]) to_concat.append(template_mask_2d[..., None])
...@@ -165,7 +163,7 @@ def build_extra_msa_feat(batch): ...@@ -165,7 +163,7 @@ def build_extra_msa_feat(batch):
def torsion_angles_to_frames( def torsion_angles_to_frames(
t: T, r: Rigid,
alpha: torch.Tensor, alpha: torch.Tensor,
aatype: torch.Tensor, aatype: torch.Tensor,
rrgdf: torch.Tensor, rrgdf: torch.Tensor,
...@@ -176,13 +174,15 @@ def torsion_angles_to_frames( ...@@ -176,13 +174,15 @@ def torsion_angles_to_frames(
# [*, N, 8] transformations, i.e. # [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and # One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix # One [*, N, 8, 3] translation matrix
default_t = T.from_4x4(default_4x4) default_r = r.from_tensor_4x4(default_4x4)
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1 bb_rot[..., 1] = 1
# [*, N, 8, 2] # [*, N, 8, 2]
alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2) alpha = torch.cat(
[bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
)
# [*, N, 8, 3, 3] # [*, N, 8, 3, 3]
# Produces rotation matrices of the form: # Produces rotation matrices of the form:
...@@ -194,15 +194,15 @@ def torsion_angles_to_frames( ...@@ -194,15 +194,15 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses # This follows the original code rather than the supplement, which uses
# different indices. # different indices.
all_rots = alpha.new_zeros(default_t.rots.shape) all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
all_rots[..., 0, 0] = 1 all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1] all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0] all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha all_rots[..., 2, 1:] = alpha
all_rots = T(all_rots, None) all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_t.compose(all_rots) all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5] chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6] chi3_frame_to_frame = all_frames[..., 6]
...@@ -213,7 +213,7 @@ def torsion_angles_to_frames( ...@@ -213,7 +213,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = T.concat( all_frames_to_bb = Rigid.cat(
[ [
all_frames[..., :5], all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1), chi2_frame_to_bb.unsqueeze(-1),
...@@ -223,13 +223,13 @@ def torsion_angles_to_frames( ...@@ -223,13 +223,13 @@ def torsion_angles_to_frames(
dim=-1, dim=-1,
) )
all_frames_to_global = t[..., None].compose(all_frames_to_bb) all_frames_to_global = r[..., None].compose(all_frames_to_bb)
return all_frames_to_global return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos( def frames_and_literature_positions_to_atom14_pos(
t: T, r: Rigid,
aatype: torch.Tensor, aatype: torch.Tensor,
default_frames, default_frames,
group_idx, group_idx,
...@@ -249,7 +249,7 @@ def frames_and_literature_positions_to_atom14_pos( ...@@ -249,7 +249,7 @@ def frames_and_literature_positions_to_atom14_pos(
) )
# [*, N, 14, 8] # [*, N, 14, 8]
t_atoms_to_global = t[..., None, :] * group_mask t_atoms_to_global = r[..., None, :] * group_mask
# [*, N, 14] # [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn( t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
......
...@@ -24,7 +24,7 @@ from typing import Dict, Optional, Tuple ...@@ -24,7 +24,7 @@ from typing import Dict, Optional, Tuple
from openfold.np import residue_constants from openfold.np import residue_constants
from openfold.utils import feats from openfold.utils import feats
from openfold.utils.affine_utils import T from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -74,8 +74,8 @@ def torsion_angle_loss( ...@@ -74,8 +74,8 @@ def torsion_angle_loss(
def compute_fape( def compute_fape(
pred_frames: T, pred_frames: Rigid,
target_frames: T, target_frames: Rigid,
frames_mask: torch.Tensor, frames_mask: torch.Tensor,
pred_positions: torch.Tensor, pred_positions: torch.Tensor,
target_positions: torch.Tensor, target_positions: torch.Tensor,
...@@ -111,7 +111,7 @@ def compute_fape( ...@@ -111,7 +111,7 @@ def compute_fape(
# ) # )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor) # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
# #
# ("roughly" because eps is necessarily duplicated in the latter # ("roughly" because eps is necessarily duplicated in the latter)
normed_error = torch.sum(normed_error, dim=-1) normed_error = torch.sum(normed_error, dim=-1)
normed_error = ( normed_error = (
normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None] normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
...@@ -123,8 +123,8 @@ def compute_fape( ...@@ -123,8 +123,8 @@ def compute_fape(
def backbone_loss( def backbone_loss(
backbone_affine_tensor: torch.Tensor, backbone_rigid_tensor: torch.Tensor,
backbone_affine_mask: torch.Tensor, backbone_rigid_mask: torch.Tensor,
traj: torch.Tensor, traj: torch.Tensor,
use_clamped_fape: Optional[torch.Tensor] = None, use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.0, clamp_distance: float = 10.0,
...@@ -132,16 +132,27 @@ def backbone_loss( ...@@ -132,16 +132,27 @@ def backbone_loss(
eps: float = 1e-4, eps: float = 1e-4,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
pred_aff = T.from_tensor(traj) pred_aff = Rigid.from_tensor_7(traj)
gt_aff = T.from_tensor(backbone_affine_tensor) pred_aff = Rigid(
Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
pred_aff.get_trans(),
)
# DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
# backbone tensor, normalizes it, and then turns it back to a rotation
# matrix. To avoid a potentially numerically unstable rotation matrix
# to quaternion conversion, we just use the original rotation matrix
# outright. This one hasn't been composed a bunch of times, though, so
# it might be fine.
gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)
fape_loss = compute_fape( fape_loss = compute_fape(
pred_aff, pred_aff,
gt_aff[None], gt_aff[None],
backbone_affine_mask[None], backbone_rigid_mask[None],
pred_aff.get_trans(), pred_aff.get_trans(),
gt_aff[None].get_trans(), gt_aff[None].get_trans(),
backbone_affine_mask[None], backbone_rigid_mask[None],
l1_clamp_distance=clamp_distance, l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance, length_scale=loss_unit_distance,
eps=eps, eps=eps,
...@@ -150,10 +161,10 @@ def backbone_loss( ...@@ -150,10 +161,10 @@ def backbone_loss(
unclamped_fape_loss = compute_fape( unclamped_fape_loss = compute_fape(
pred_aff, pred_aff,
gt_aff[None], gt_aff[None],
backbone_affine_mask[None], backbone_rigid_mask[None],
pred_aff.get_trans(), pred_aff.get_trans(),
gt_aff[None].get_trans(), gt_aff[None].get_trans(),
backbone_affine_mask[None], backbone_rigid_mask[None],
l1_clamp_distance=None, l1_clamp_distance=None,
length_scale=loss_unit_distance, length_scale=loss_unit_distance,
eps=eps, eps=eps,
...@@ -193,9 +204,9 @@ def sidechain_loss( ...@@ -193,9 +204,9 @@ def sidechain_loss(
sidechain_frames = sidechain_frames[-1] sidechain_frames = sidechain_frames[-1]
batch_dims = sidechain_frames.shape[:-4] batch_dims = sidechain_frames.shape[:-4]
sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4) sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
sidechain_frames = T.from_4x4(sidechain_frames) sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4) renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
renamed_gt_frames = T.from_4x4(renamed_gt_frames) renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1) rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
sidechain_atom_pos = sidechain_atom_pos[-1] sidechain_atom_pos = sidechain_atom_pos[-1]
sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3) sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
...@@ -550,8 +561,8 @@ def compute_tm( ...@@ -550,8 +561,8 @@ def compute_tm(
def tm_loss( def tm_loss(
logits, logits,
final_affine_tensor, final_affine_tensor,
backbone_affine_tensor, backbone_rigid_tensor,
backbone_affine_mask, backbone_rigid_mask,
resolution, resolution,
max_bin=31, max_bin=31,
no_bins=64, no_bins=64,
...@@ -560,16 +571,17 @@ def tm_loss( ...@@ -560,16 +571,17 @@ def tm_loss(
eps=1e-8, eps=1e-8,
**kwargs, **kwargs,
): ):
pred_affine = T.from_4x4(final_affine_tensor) pred_affine = Rigid.from_tensor_7(final_affine_tensor)
backbone_affine = T.from_4x4(backbone_affine_tensor) backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
def _points(affine): def _points(affine):
pts = affine.get_trans()[..., None, :, :] pts = affine.get_trans()[..., None, :, :]
return affine.invert()[..., None].apply(pts) return affine.invert()[..., None].apply(pts)
sq_diff = torch.sum( sq_diff = torch.sum(
(_points(pred_affine) - _points(backbone_affine)) ** 2, dim=-1 (_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
) )
sq_diff = sq_diff.detach() sq_diff = sq_diff.detach()
boundaries = torch.linspace( boundaries = torch.linspace(
...@@ -583,7 +595,7 @@ def tm_loss( ...@@ -583,7 +595,7 @@ def tm_loss(
) )
square_mask = ( square_mask = (
backbone_affine_mask[..., None] * backbone_affine_mask[..., None, :] backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
) )
loss = torch.sum(errors * square_mask, dim=-1) loss = torch.sum(errors * square_mask, dim=-1)
...@@ -1503,11 +1515,12 @@ class AlphaFoldLoss(nn.Module): ...@@ -1503,11 +1515,12 @@ class AlphaFoldLoss(nn.Module):
), ),
} }
cum_loss = 0 cum_loss = 0.
for loss_name, loss_fn in loss_fns.items(): for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight weight = self.config[loss_name].weight
if weight: if weight:
loss = loss_fn() loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)): if(torch.isnan(loss) or torch.isinf(loss)):
logging.warning(f"{loss_name} loss is NaN. Skipping...") logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True) loss = loss.new_tensor(0., requires_grad=True)
......
...@@ -107,52 +107,790 @@ def rot_vec_mul( ...@@ -107,52 +107,790 @@ def rot_vec_mul(
) )
def identity_rot_mats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Builds a [*batch_dims, 3, 3] stack of identity rotation matrices.

    A single 3x3 identity is allocated and broadcast over the batch
    dimensions with expand(), so no per-batch memory is materialized.
    """
    eye = torch.eye(
        3, dtype=dtype, device=device, requires_grad=requires_grad
    )
    lead = (1,) * len(batch_dims)
    return eye.view(*lead, 3, 3).expand(*batch_dims, -1, -1)
def identity_trans(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Builds a [*batch_dims, 3] tensor of identity (all-zero) translations.
    """
    return torch.zeros(
        (*batch_dims, 3),
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )
def identity_quats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Builds a [*batch_dims, 4] stack of identity quaternions (1, 0, 0, 0).
    """
    quats = torch.zeros(
        (*batch_dims, 4),
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )

    # Set the real component in-place without recording the write on the
    # autograd graph (the tensor may itself require grad).
    with torch.no_grad():
        quats[..., 0] = 1

    return quats
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [e1 + e2 for e1 in _quat_elements for e2 in _quat_elements]
_qtr_ind_dict = {key: i for i, key in enumerate(_qtr_keys)}


def _to_mat(pairs):
    # Scatter (key, coefficient) pairs into a 4x4 matrix. Each two-letter
    # key ("ab", "cd", ...) names one entry of the quaternion outer product.
    mat = np.zeros((4, 4))
    for key, value in pairs:
        ind = _qtr_ind_dict[key]
        mat[ind // 4][ind % 4] = value
    return mat


# _QTR_MAT[..., i, j] holds the coefficients of the quaternion outer-product
# terms that contribute to entry (i, j) of the rotation matrix.
_QTR_MAT = np.zeros((4, 4, 3, 3))
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])


def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
    Converts a quaternion to a rotation matrix.

    Args:
        quat: [*, 4] quaternions
    Returns:
        [*, 3, 3] rotation matrices
    """
    # Quaternion outer product: [*, 4, 4]
    outer = quat[..., None] * quat[..., None, :]

    # Coefficient table, broadcast over the batch dims: [*, 4, 4, 3, 3]
    mat = outer.new_tensor(_QTR_MAT, requires_grad=False)
    mat = mat.view((1,) * (outer.ndim - 2) + mat.shape)

    # Weight each 3x3 coefficient block by its outer-product term and
    # reduce over both quaternion axes: [*, 3, 3]
    return torch.sum(outer[..., None, None] * mat, dim=(-3, -4))
def rot_to_quat(
    rot: torch.Tensor,
):
    """
    Converts a rotation matrix to a quaternion via the eigendecomposition
    of a symmetric 4x4 matrix K built from the matrix entries. The output
    is the eigenvector of K associated with its largest eigenvalue.

    Args:
        rot: [*, 3, 3] rotation matrices
    Returns:
        [*, 4] quaternions
    """
    if(rot.shape[-2:] != (3, 3)):
        raise ValueError("Input rotation is incorrectly shaped")

    xx, xy, xz = rot[..., 0, 0], rot[..., 0, 1], rot[..., 0, 2]
    yx, yy, yz = rot[..., 1, 0], rot[..., 1, 1], rot[..., 1, 2]
    zx, zy, zz = rot[..., 2, 0], rot[..., 2, 1], rot[..., 2, 2]

    rows = [
        [xx + yy + zz, zy - yz, xz - zx, yx - xy],
        [zy - yz, xx - yy - zz, xy + yx, xz + zx],
        [xz - zx, xy + yx, yy - xx - zz, yz + zy],
        [yx - xy, xz + zx, yz + zy, zz - xx - yy],
    ]
    k = (1. / 3.) * torch.stack(
        [torch.stack(row, dim=-1) for row in rows], dim=-2
    )

    # eigh returns eigenvalues in ascending order, so the last eigenvector
    # corresponds to the largest eigenvalue.
    _, vectors = torch.linalg.eigh(k)
    return vectors[..., -1]
# Hamilton-product tensor: component k of q1 * q2 equals
# sum_{i,j} _QUAT_MULTIPLY[i, j, k] * q1[i] * q2[j].
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [
    [1, 0, 0, 0],
    [0, -1, 0, 0],
    [0, 0, -1, 0],
    [0, 0, 0, -1],
]
_QUAT_MULTIPLY[:, :, 1] = [
    [0, 1, 0, 0],
    [1, 0, 0, 0],
    [0, 0, 0, 1],
    [0, 0, -1, 0],
]
_QUAT_MULTIPLY[:, :, 2] = [
    [0, 0, 1, 0],
    [0, 0, 0, -1],
    [1, 0, 0, 0],
    [0, 1, 0, 0],
]
_QUAT_MULTIPLY[:, :, 3] = [
    [0, 0, 0, 1],
    [0, 0, 1, 0],
    [0, -1, 0, 0],
    [1, 0, 0, 0],
]

# Same table restricted to pure-vector (zero real part) right operands.
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]


def quat_multiply(quat1, quat2):
    """Multiply a quaternion by another quaternion."""
    mat = quat1.new_tensor(_QUAT_MULTIPLY)
    reshaped = mat.view((1,) * (quat1.ndim - 1) + mat.shape)
    terms = (
        reshaped
        * quat1[..., :, None, None]
        * quat2[..., None, :, None]
    )
    return torch.sum(terms, dim=(-3, -2))


def quat_multiply_by_vec(quat, vec):
    """Multiply a quaternion by a pure-vector quaternion."""
    mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
    reshaped = mat.view((1,) * (quat.ndim - 1) + mat.shape)
    terms = (
        reshaped
        * quat[..., :, None, None]
        * vec[..., None, :, None]
    )
    return torch.sum(terms, dim=(-3, -2))
def invert_rot_mat(rot_mat: torch.Tensor):
    """
    Inverts a [*, 3, 3] rotation matrix. For orthonormal rotations the
    inverse is simply the transpose of the last two dimensions.
    """
    return torch.transpose(rot_mat, -1, -2)
def invert_quat(quat: torch.Tensor):
    """
    Inverts a [*, 4] quaternion: the conjugate divided by the squared
    norm. For unit quaternions this reduces to the conjugate.
    """
    conj = quat.clone()
    conj[..., 1:] = conj[..., 1:] * -1
    return conj / torch.sum(quat ** 2, dim=-1, keepdim=True)
class Rotation:
    """
    A 3D rotation. Depending on how the object is initialized, the
    rotation is represented by either a rotation matrix or a
    quaternion, though both formats are made available by helper functions.
    To simplify gradient computation, the underlying format of the
    rotation cannot be changed in-place. Like Rigid, the class is designed
    to mimic the behavior of a torch Tensor, almost as if each Rotation
    object were a tensor of rotations, in one format or another.
    """
    def __init__(self,
        rot_mats: Optional[torch.Tensor] = None,
        quats: Optional[torch.Tensor] = None,
        normalize_quats: bool = True,
    ):
        """
        Args:
            rot_mats:
                A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
                quats
            quats:
                A [*, 4] quaternion. Mutually exclusive with rot_mats. If
                normalize_quats is not True, must be a unit quaternion
            normalize_quats:
                If quats is specified, whether to normalize quats
        """
        if((rot_mats is None and quats is None) or
            (rot_mats is not None and quats is not None)):
            raise ValueError("Exactly one input argument must be specified")

        if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
            (quats is not None and quats.shape[-1] != 4)):
            raise ValueError(
                "Incorrectly shaped rotation matrix or quaternion"
            )

        if(quats is not None and normalize_quats):
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)

        # Exactly one of the two representations is non-None.
        self._rot_mats = rot_mats
        self._quats = quats

    @staticmethod
    def identity(
        shape,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ) -> Rotation:
        """
        Returns an identity Rotation.

        Args:
            shape:
                The "shape" of the resulting Rotation object. See
                documentation for the shape property
            dtype:
                The torch dtype for the rotation
            device:
                The torch device for the new rotation
            requires_grad:
                Whether the underlying tensors in the new rotation object
                should require gradient computation
            fmt:
                One of "quat" or "rot_mat". Determines the underlying format
                of the new object's rotation
        Returns:
            A new identity rotation
        """
        if(fmt == "rot_mat"):
            rot_mats = identity_rot_mats(
                shape, dtype, device, requires_grad,
            )
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(fmt == "quat"):
            quats = identity_quats(shape, dtype, device, requires_grad)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            # BUGFIX: the message previously read "Invalid format: f{fmt}"
            # (stray "f" inside the f-string).
            raise ValueError(f"Invalid format: {fmt}")

    # Magic methods

    def __getitem__(self, index: Any) -> Rotation:
        """
        Allows torch-style indexing over the virtual shape of the rotation
        object. See documentation for the shape property.

        Args:
            index:
                A torch index. E.g. (1, 3, 2), or (slice(None,))
        Returns:
            The indexed rotation
        """
        if type(index) != tuple:
            index = (index,)

        # The trailing matrix/quaternion dims are never indexed over.
        if(self._rot_mats is not None):
            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
            return Rotation(rot_mats=rot_mats)
        elif(self._quats is not None):
            quats = self._quats[index + (slice(None),)]
            return Rotation(quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __mul__(self,
        right: torch.Tensor,
    ) -> Rotation:
        """
        Pointwise left multiplication of the rotation with a tensor. Can be
        used to e.g. mask the Rotation.

        Args:
            right:
                The tensor multiplicand
        Returns:
            The product
        """
        if not(isinstance(right, torch.Tensor)):
            raise TypeError("The other multiplicand must be a Tensor")

        if(self._rot_mats is not None):
            rot_mats = self._rot_mats * right[..., None, None]
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(self._quats is not None):
            quats = self._quats * right[..., None]
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __rmul__(self,
        left: torch.Tensor,
    ) -> Rotation:
        """
        Reverse pointwise multiplication of the rotation with a tensor.

        Args:
            left:
                The left multiplicand
        Returns:
            The product
        """
        return self.__mul__(left)

    # Properties

    @property
    def shape(self) -> torch.Size:
        """
        Returns the virtual shape of the rotation object. This shape is
        defined as the batch dimensions of the underlying rotation matrix
        or quaternion. If the Rotation was initialized with a [10, 3, 3]
        rotation matrix tensor, for example, the resulting shape would be
        [10].

        Returns:
            The virtual shape of the rotation object
        """
        s = None
        if(self._quats is not None):
            s = self._quats.shape[:-1]
        else:
            s = self._rot_mats.shape[:-2]

        return s

    @property
    def dtype(self) -> torch.dtype:
        """
        Returns the dtype of the underlying rotation.

        Returns:
            The dtype of the underlying rotation
        """
        if(self._rot_mats is not None):
            return self._rot_mats.dtype
        elif(self._quats is not None):
            return self._quats.dtype
        else:
            raise ValueError("Both rotations are None")

    @property
    def device(self) -> torch.device:
        """
        The device of the underlying rotation

        Returns:
            The device of the underlying rotation
        """
        if(self._rot_mats is not None):
            return self._rot_mats.device
        elif(self._quats is not None):
            return self._quats.device
        else:
            raise ValueError("Both rotations are None")

    @property
    def requires_grad(self) -> bool:
        """
        Returns the requires_grad property of the underlying rotation

        Returns:
            The requires_grad property of the underlying tensor
        """
        if(self._rot_mats is not None):
            return self._rot_mats.requires_grad
        elif(self._quats is not None):
            return self._quats.requires_grad
        else:
            raise ValueError("Both rotations are None")

    def get_rot_mats(self) -> torch.Tensor:
        """
        Returns the underlying rotation as a rotation matrix tensor.

        Returns:
            The rotation as a rotation matrix tensor
        """
        rot_mats = self._rot_mats
        if(rot_mats is None):
            if(self._quats is None):
                raise ValueError("Both rotations are None")
            else:
                rot_mats = quat_to_rot(self._quats)

        return rot_mats

    def get_quats(self) -> torch.Tensor:
        """
        Returns the underlying rotation as a quaternion tensor.

        Depending on whether the Rotation was initialized with a
        quaternion, this function may call torch.linalg.eigh.

        Returns:
            The rotation as a quaternion tensor.
        """
        quats = self._quats
        if(quats is None):
            if(self._rot_mats is None):
                raise ValueError("Both rotations are None")
            else:
                quats = rot_to_quat(self._rot_mats)

        return quats

    def get_cur_rot(self) -> torch.Tensor:
        """
        Return the underlying rotation in its current form

        Returns:
            The stored rotation
        """
        if(self._rot_mats is not None):
            return self._rot_mats
        elif(self._quats is not None):
            return self._quats
        else:
            raise ValueError("Both rotations are None")

    # Rotation functions

    def compose_q_update_vec(self,
        q_update_vec: torch.Tensor,
        normalize_quats: bool = True
    ) -> Rotation:
        """
        Returns a new quaternion Rotation after updating the current
        object's underlying rotation with a quaternion update, formatted
        as a [*, 3] tensor whose final three columns represent x, y, z such
        that (1, x, y, z) is the desired (not necessarily unit) quaternion
        update.

        Args:
            q_update_vec:
                A [*, 3] quaternion update tensor
            normalize_quats:
                Whether to normalize the output quaternion
        Returns:
            An updated Rotation
        """
        quats = self.get_quats()
        new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
        return Rotation(
            rot_mats=None,
            quats=new_quats,
            normalize_quats=normalize_quats,
        )

    def compose_r(self, r: Rotation) -> Rotation:
        """
        Compose the rotation matrices of the current Rotation object with
        those of another.

        Args:
            r:
                An update rotation object
        Returns:
            An updated rotation object
        """
        r1 = self.get_rot_mats()
        r2 = r.get_rot_mats()
        new_rot_mats = rot_matmul(r1, r2)
        return Rotation(rot_mats=new_rot_mats, quats=None)

    def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
        """
        Compose the quaternions of the current Rotation object with those
        of another.

        Depending on whether either Rotation was initialized with
        quaternions, this function may call torch.linalg.eigh.

        Args:
            r:
                An update rotation object
        Returns:
            An updated rotation object
        """
        q1 = self.get_quats()
        q2 = r.get_quats()
        new_quats = quat_multiply(q1, q2)
        return Rotation(
            rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
        )

    def apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Apply the current Rotation as a rotation matrix to a set of 3D
        coordinates.

        Args:
            pts:
                A [*, 3] set of points
        Returns:
            [*, 3] rotated points
        """
        rot_mats = self.get_rot_mats()
        return rot_vec_mul(rot_mats, pts)

    def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
        The inverse of the apply() method.

        Args:
            pts:
                A [*, 3] set of points
        Returns:
            [*, 3] inverse-rotated points
        """
        rot_mats = self.get_rot_mats()
        inv_rot_mats = invert_rot_mat(rot_mats)
        return rot_vec_mul(inv_rot_mats, pts)

    def invert(self) -> Rotation:
        """
        Returns the inverse of the current Rotation.

        Returns:
            The inverse of the current Rotation
        """
        if(self._rot_mats is not None):
            return Rotation(
                rot_mats=invert_rot_mat(self._rot_mats),
                quats=None
            )
        elif(self._quats is not None):
            return Rotation(
                rot_mats=None,
                quats=invert_quat(self._quats),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")

    # "Tensor" stuff

    def unsqueeze(self,
        dim: int,
    ) -> Rotation:
        # BUGFIX: return annotation previously said Rigid.
        """
        Analogous to torch.unsqueeze. The dimension is relative to the
        shape of the Rotation object.

        Args:
            dim: A positive or negative dimension index.
        Returns:
            The unsqueezed Rotation.
        """
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")

        # Negative indices must skip the trailing matrix/quaternion dims.
        if(self._rot_mats is not None):
            rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(self._quats is not None):
            quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    @staticmethod
    def cat(
        rs: Sequence[Rotation],
        dim: int,
    ) -> Rotation:
        # BUGFIX: return annotation previously said Rigid.
        """
        Concatenates rotations along one of the batch dimensions. Analogous
        to torch.cat().

        Note that the output of this operation is always a rotation matrix,
        regardless of the format of input rotations.

        Args:
            rs:
                A list of rotation objects
            dim:
                The dimension along which the rotations should be
                concatenated
        Returns:
            A concatenated Rotation object in rotation matrix format
        """
        rot_mats = [r.get_rot_mats() for r in rs]
        rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)

        return Rotation(rot_mats=rot_mats, quats=None)

    def map_tensor_fn(self,
        fn: Callable[[torch.Tensor], torch.Tensor]
    ) -> Rotation:
        # BUGFIX: the annotation previously read
        # Callable[tensor.Tensor, tensor.Tensor], referencing a nonexistent
        # "tensor" module and using malformed Callable syntax.
        """
        Apply a Tensor -> Tensor function to underlying rotation tensors,
        mapping over the rotation dimension(s). Can be used e.g. to sum out
        a one-hot batch dimension.

        Args:
            fn:
                A Tensor -> Tensor function to be mapped over the Rotation
        Returns:
            The transformed Rotation object
        """
        if(self._rot_mats is not None):
            rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
            rot_mats = torch.stack(
                list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
            )
            rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(self._quats is not None):
            quats = torch.stack(
                list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
            )
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def cuda(self) -> Rotation:
        """
        Analogous to the cuda() method of torch Tensors

        Returns:
            A copy of the Rotation in CUDA memory
        """
        if(self._rot_mats is not None):
            return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
        elif(self._quats is not None):
            return Rotation(
                rot_mats=None,
                quats=self._quats.cuda(),
                normalize_quats=False
            )
        else:
            raise ValueError("Both rotations are None")

    def to(self,
        device: Optional[torch.device],
        dtype: Optional[torch.dtype]
    ) -> Rotation:
        """
        Analogous to the to() method of torch Tensors

        Args:
            device:
                A torch device
            dtype:
                A torch dtype
        Returns:
            A copy of the Rotation using the new device and dtype
        """
        if(self._rot_mats is not None):
            return Rotation(
                rot_mats=self._rot_mats.to(device=device, dtype=dtype),
                quats=None,
            )
        elif(self._quats is not None):
            return Rotation(
                rot_mats=None,
                quats=self._quats.to(device=device, dtype=dtype),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")

    def detach(self) -> Rotation:
        """
        Returns a copy of the Rotation whose underlying Tensor has been
        detached from its torch graph.

        Returns:
            A copy of the Rotation whose underlying Tensor has been detached
            from its torch graph
        """
        if(self._rot_mats is not None):
            return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
        elif(self._quats is not None):
            return Rotation(
                rot_mats=None,
                quats=self._quats.detach(),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")
class Rigid:
"""
A class representing a rigid transformation. Little more than a wrapper
around two objects: a Rotation object and a [*, 3] translation
Designed to behave approximately like a single torch tensor with the
shape of the shared batch dimensions of its component parts.
""" """
def __init__(self, def __init__(self,
rots: torch.Tensor, rots: Optional[Rotation],
trans: torch.Tensor trans: Optional[torch.Tensor],
): ):
""" """
Args: Args:
rots: A [*, 3, 3] rotation tensor rots: A [*, 3, 3] rotation tensor
trans: A corresponding [*, 3] translation tensor trans: A corresponding [*, 3] translation tensor
""" """
self.rots = rots # (we need device, dtype, etc. from at least one input)
self.trans = trans
batch_dims, dtype, device, requires_grad = None, None, None, None
if self.rots is None and self.trans is None: if(trans is not None):
raise ValueError("Only one of rots and trans can be None") batch_dims = trans.shape[:-1]
elif self.rots is None: dtype = trans.dtype
self.rots = T._identity_rot( device = trans.device
self.trans.shape[:-1], requires_grad = trans.requires_grad
self.trans.dtype, elif(rots is not None):
self.trans.device, batch_dims = rots.shape
self.trans.requires_grad, dtype = rots.dtype
device = rots.device
requires_grad = rots.requires_grad
else:
raise ValueError("At least one input argument must be specified")
if(rots is None):
rots = Rotation.identity(
batch_dims, dtype, device, requires_grad,
) )
elif self.trans is None: elif(trans is None):
self.trans = T._identity_trans( trans = identity_trans(
self.rots.shape[:-2], batch_dims, dtype, device, requires_grad,
self.rots.dtype,
self.rots.device,
self.rots.requires_grad,
) )
if ( if((rots.shape != trans.shape[:-1]) or
self.rots.shape[-2:] != (3, 3) (rots.device != trans.device)):
or self.trans.shape[-1] != 3 raise ValueError("Rots and trans incompatible")
or self.rots.shape[:-2] != self.trans.shape[:-1]
): self._rots = rots
raise ValueError("Incorrectly shaped input") self._trans = trans
@staticmethod
def identity(
shape: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
fmt: str = "quat",
) -> Rigid:
"""
Constructs an identity transformation.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns:
The identity transformation
"""
return Rigid(
Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
identity_trans(shape, dtype, device, requires_grad),
)
def __getitem__(self, def __getitem__(self,
index: Any, index: Any,
) -> T: ) -> Rigid:
""" """
Indexes the affine transformation with PyTorch-style indices. Indexes the affine transformation with PyTorch-style indices.
The index is applied to the shared dimensions of both the rotation The index is applied to the shared dimensions of both the rotation
...@@ -160,11 +898,12 @@ class T: ...@@ -160,11 +898,12 @@ class T:
E.g.:: E.g.::
t = T(torch.rand(10, 10, 3, 3), torch.rand(10, 10, 3)) r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
t = Rigid(r, torch.rand(10, 10, 3))
indexed = t[3, 4:6] indexed = t[3, 4:6]
assert(indexed.shape == (2,)) assert(indexed.shape == (2,))
assert(indexed.rots.shape == (2, 3, 3)) assert(indexed.get_rots().shape == (2,))
assert(indexed.trans.shape == (2, 3)) assert(indexed.get_trans().shape == (2, 3))
Args: Args:
index: A standard torch tensor index. E.g. 8, (10, None, 3), index: A standard torch tensor index. E.g. 8, (10, None, 3),
...@@ -174,54 +913,45 @@ class T: ...@@ -174,54 +913,45 @@ class T:
""" """
if type(index) != tuple: if type(index) != tuple:
index = (index,) index = (index,)
return T(
self.rots[index + (slice(None), slice(None))],
self.trans[index + (slice(None),)],
)
def __eq__(self, return Rigid(
obj: T, self._rots[index],
) -> bool: self._trans[index + (slice(None),)],
"""
Compares two affine transformations. Returns true iff the
transformations are pointwise identical. Does not account for
floating point imprecision.
"""
return bool(
torch.all(self.rots == obj.rots) and
torch.all(self.trans == obj.trans)
) )
def __mul__(self, def __mul__(self,
right: torch.Tensor, right: torch.Tensor,
) -> T: ) -> Rigid:
""" """
Pointwise right multiplication of the affine transformation with a Pointwise left multiplication of the transformation with a tensor.
tensor. Multiplication is broadcast over the rotation/translation Can be used to e.g. mask the Rigid.
dimensions.
Args: Args:
right: The right multiplicand right:
The tensor multiplicand
Returns: Returns:
The product transformation The product
""" """
rots = self.rots * right[..., None, None] if not(isinstance(right, torch.Tensor)):
trans = self.trans * right[..., None] raise TypeError("The other multiplicand must be a Tensor")
new_rots = self._rots * right
new_trans = self._trans * right[..., None]
return T(rots, trans) return Rigid(new_rots, new_trans)
def __rmul__(self, def __rmul__(self,
left: torch.Tensor, left: torch.Tensor,
) -> T: ) -> Rigid:
""" """
Pointwise left multiplication of the affine transformation with a Reverse pointwise multiplication of the transformation with a
tensor. Multiplication is broadcast over the rotation/translation tensor.
dimensions.
Args: Args:
left: The left multiplicand left:
The left multiplicand
Returns: Returns:
The product transformation The product
""" """
return self.__mul__(left) return self.__mul__(left)
...@@ -234,45 +964,74 @@ class T: ...@@ -234,45 +964,74 @@ class T:
Returns: Returns:
The shape of the transformation The shape of the transformation
""" """
s = self.rots.shape[:-2] s = self._trans.shape[:-1]
return s if len(s) > 0 else torch.Size([1]) return s
@property
def device(self) -> torch.device:
"""
Returns the device on which the Rigid's tensors are located.
Returns:
The device on which the Rigid's tensors are located
"""
return self._trans.device
def get_rots(self): def get_rots(self) -> Rotation:
""" """
Getter for the rotation. Getter for the rotation.
Returns: Returns:
The stored rotation. The rotation object
""" """
return self.rots return self._rots
def get_trans(self) -> torch.Tensor: def get_trans(self) -> torch.Tensor:
""" """
Getter for the translation. Getter for the translation.
Returns: Returns:
The stored translation. The stored translation
""" """
return self.trans return self._trans
def compose(self, def compose_q_update_vec(self,
t: T, q_update_vec: torch.Tensor,
) -> T: ) -> Rigid:
""" """
Composes the transformation with another. Composes the transformation with a quaternion update vector of
shape [*, 6], where the final 6 columns represent the x, y, and
z values of a quaternion of form (1, x, y, z) followed by a 3D
translation.
Args: Args:
t: The inner transformation. q_vec: The quaternion update vector.
Returns: Returns:
The composed transformation. The composed transformation.
""" """
rot_1, trn_1 = self.rots, self.trans q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
rot_2, trn_2 = t.rots, t.trans new_rots = self._rots.compose_q_update_vec(q_vec)
trans_update = self._rots.apply(t_vec)
new_translation = self._trans + trans_update
rot = rot_matmul(rot_1, rot_2) return Rigid(new_rots, new_translation)
trn = rot_vec_mul(rot_1, trn_2) + trn_1
return T(rot, trn) def compose(self,
r: Rigid,
) -> Rigid:
"""
Composes the current rigid object with another.
Args:
r:
Another Rigid object
Returns:
The composition of the two transformations
"""
new_rot = self._rots.compose_r(r._rots)
new_trans = self._rots.apply(r._trans) + self._trans
return Rigid(new_rot, new_trans)
def apply(self, def apply(self,
pts: torch.Tensor, pts: torch.Tensor,
...@@ -285,9 +1044,8 @@ class T: ...@@ -285,9 +1044,8 @@ class T:
Returns: Returns:
The transformed points. The transformed points.
""" """
r, t = self.rots, self.trans rotated = self._rots.apply(pts)
rotated = rot_vec_mul(r, pts) return rotated + self._trans
return rotated + t
def invert_apply(self, def invert_apply(self,
pts: torch.Tensor pts: torch.Tensor
...@@ -300,99 +1058,60 @@ class T: ...@@ -300,99 +1058,60 @@ class T:
Returns: Returns:
The transformed points. The transformed points.
""" """
r, t = self.rots, self.trans pts = pts - self._trans
pts = pts - t return self._rots.invert_apply(pts)
return rot_vec_mul(r.transpose(-1, -2), pts)
def invert(self) -> T: def invert(self) -> Rigid:
""" """
Inverts the transformation. Inverts the transformation.
Returns: Returns:
The inverse transformation. The inverse transformation.
""" """
rot_inv = self.rots.transpose(-1, -2) rot_inv = self._rots.invert()
trn_inv = rot_vec_mul(rot_inv, self.trans) trn_inv = rot_inv.apply(self._trans)
return T(rot_inv, -1 * trn_inv) return Rigid(rot_inv, -1 * trn_inv)
def unsqueeze(self, def map_tensor_fn(self,
dim: int, fn: Callable[tensor.Tensor, tensor.Tensor]
) -> T: ) -> Rigid:
""" """
Analogous to torch.unsqueeze. The dimension is relative to the Apply a Tensor -> Tensor function to underlying translation and
shared dimensions of the rotation/translation. rotation tensors, mapping over the translation/rotation dimensions
respectively.
Args: Args:
dim: A positive or negative dimension index. fn:
A Tensor -> Tensor function to be mapped over the Rigid
Returns: Returns:
The unsqueezed transformation. The transformed Rigid object
""" """
if dim >= len(self.shape): new_rots = self._rots.map_tensor_fn(fn)
raise ValueError("Invalid dimension") new_trans = torch.stack(
rots = self.rots.unsqueeze(dim if dim >= 0 else dim - 2) list(map(fn, torch.unbind(self._trans, dim=-1))),
trans = self.trans.unsqueeze(dim if dim >= 0 else dim - 1) dim=-1
return T(rots, trans)
@staticmethod
def _identity_rot(
shape: Tuple[int],
dtype: torch.dtype,
device: torch.device,
requires_grad: bool,
) -> torch.Tensor:
rots = torch.eye(
3, dtype=dtype, device=device, requires_grad=requires_grad
) )
rots = rots.view(*((1,) * len(shape)), 3, 3)
rots = rots.expand(*shape, -1, -1)
return rots
@staticmethod return Rigid(new_rots, new_trans)
def _identity_trans(
shape: Tuple[int],
dtype: torch.dtype,
device: torch.device,
requires_grad: bool
) -> torch.Tensor:
trans = torch.zeros(
(*shape, 3), dtype=dtype, device=device, requires_grad=requires_grad
)
return trans
@staticmethod def to_tensor_4x4(self) -> torch.Tensor:
def identity(
shape: Tuple[int],
dtype: torch.dtype,
device: torch.device,
requires_grad: bool = True
) -> T:
""" """
Constructs an identity transformation. Converts a transformation to a homogenous transformation tensor.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns: Returns:
The identity transformation A [*, 4, 4] homogenous transformation tensor
""" """
return T( tensor = self._trans.new_zeros((*self.shape, 4, 4))
T._identity_rot(shape, dtype, device, requires_grad), tensor[..., :3, :3] = self._rots.get_rot_mats()
T._identity_trans(shape, dtype, device, requires_grad), tensor[..., :3, 3] = self._trans
) tensor[..., 3, 3] = 1
return tensor
@staticmethod @staticmethod
def from_4x4( def from_tensor_4x4(
t: torch.Tensor t: torch.Tensor
) -> T: ) -> Rigid:
""" """
Constructs a transformation from a homogenous transformation Constructs a transformation from a homogenous transformation
tensor. tensor.
...@@ -402,35 +1121,45 @@ class T: ...@@ -402,35 +1121,45 @@ class T:
Returns: Returns:
T object with shape [*] T object with shape [*]
""" """
rots = t[..., :3, :3] if(t.shape[-2:] != (4, 4)):
raise ValueError("Incorrectly shaped input tensor")
rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
trans = t[..., :3, 3] trans = t[..., :3, 3]
return T(rots, trans)
def to_4x4(self) -> torch.Tensor: return Rigid(rots, trans)
def to_tensor_7(self) -> torch.Tensor:
""" """
Converts a transformation to a homogenous transformation tensor. Converts a transformation to a tensor with 7 final columns, four
for the quaternion followed by three for the translation.
Returns: Returns:
A [*, 4, 4] homogenous transformation tensor A [*, 7] tensor representation of the transformation
""" """
tensor = self.rots.new_zeros((*self.shape, 4, 4)) tensor = self._trans.new_zeros((*self.shape, 7))
tensor[..., :3, :3] = self.rots tensor[..., :4] = self._rots.get_quats()
tensor[..., :3, 3] = self.trans tensor[..., 4:] = self._trans
tensor[..., 3, 3] = 1
return tensor return tensor
@staticmethod @staticmethod
def from_tensor(t: torch.Tensor) -> T: def from_tensor_7(
""" t: torch.Tensor,
Constructs a transformation from a homogenous transformation normalize_quats: bool = False,
tensor. ) -> Rigid:
if(t.shape[-1] != 7):
raise ValueError("Incorrectly shaped input tensor")
quats, trans = t[..., :4], t[..., 4:]
rots = Rotation(
rot_mats=None,
quats=quats,
normalize_quats=normalize_quats
)
Args: return Rigid(rots, trans)
t: A [*, 4, 4] homogenous transformation tensor
Returns:
A transformation object with shape [*]
"""
return T.from_4x4(t)
@staticmethod @staticmethod
def from_3_points( def from_3_points(
...@@ -438,7 +1167,7 @@ class T: ...@@ -438,7 +1167,7 @@ class T:
origin: torch.Tensor, origin: torch.Tensor,
p_xy_plane: torch.Tensor, p_xy_plane: torch.Tensor,
eps: float = 1e-8 eps: float = 1e-8
) -> T: ) -> Rigid:
""" """
Implements algorithm 21. Constructs transformations from sets of 3 Implements algorithm 21. Constructs transformations from sets of 3
points using the Gram-Schmidt algorithm. points using the Gram-Schmidt algorithm.
...@@ -473,13 +1202,34 @@ class T: ...@@ -473,13 +1202,34 @@ class T:
rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
rots = rots.reshape(rots.shape[:-1] + (3, 3)) rots = rots.reshape(rots.shape[:-1] + (3, 3))
return T(rots, torch.stack(origin, dim=-1)) rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, torch.stack(origin, dim=-1))
def unsqueeze(self,
dim: int,
) -> Rigid:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shared dimensions of the rotation/translation.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed transformation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
rots = self._rots.unsqueeze(dim)
trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
return Rigid(rots, trans)
@staticmethod @staticmethod
def concat( def cat(
ts: Sequence[T], ts: Sequence[Rigid],
dim: int, dim: int,
) -> T: ) -> Rigid:
""" """
Concatenates transformations along a new dimension. Concatenates transformations along a new dimension.
...@@ -492,57 +1242,60 @@ class T: ...@@ -492,57 +1242,60 @@ class T:
Returns: Returns:
A concatenated transformation object A concatenated transformation object
""" """
rots = torch.cat([t.rots for t in ts], dim=dim if dim >= 0 else dim - 2) rots = Rotation.cat([t._rots for t in ts], dim)
trans = torch.cat( trans = torch.cat(
[t.trans for t in ts], dim=dim if dim >= 0 else dim - 1 [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
) )
return T(rots, trans) return Rigid(rots, trans)
def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> T: def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid:
""" """
Apply a function that takes a tensor as its only argument to the Applies a Rotation -> Rotation function to the stored rotation
rotations and translations, treating the final two/one object.
dimension(s), respectively, as batch dimensions.
E.g.: Given t, an instance of T of shape [N, M], this function can
be used to sum out the second dimension thereof as follows::
t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
The resulting object has rotations of shape [N, 3, 3] and
translations of shape [N, 3]
Args: Args:
fn: A function that takes only a tensor as its argument fn: A function of type Rotation -> Rotation
Returns: Returns:
The transformed transformation object. A transformation object with a transformed rotation.
""" """
rots = self.rots.view(*self.rots.shape[:-2], 9) return Rigid(fn(self._rots), self._trans)
rots = torch.stack(list(map(fn, torch.unbind(rots, -1))), dim=-1)
rots = rots.view(*rots.shape[:-1], 3, 3)
trans = torch.stack(list(map(fn, torch.unbind(self.trans, -1))), dim=-1) def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid:
"""
Applies a Tensor -> Tensor function to the stored translation.
return T(rots, trans) Args:
fn:
A function of type Tensor -> Tensor to be applied to the
translation
Returns:
A transformation object with a transformed translation.
"""
return Rigid(self._rots, fn(self._trans))
def stop_rot_gradient(self) -> T: def scale_translation(self, trans_scale_factor: float) -> Rigid:
""" """
Detaches the contained rotation tensor. Scales the translation by a constant factor.
Args:
trans_scale_factor:
The constant factor
Returns: Returns:
A version of the transformation with detached rotations A transformation object with a scaled translation.
""" """
return T(self.rots.detach(), self.trans) fn = lambda t: t * trans_scale_factor
return self.apply_trans_fn(fn)
def scale_translation(self, factor: int) -> T: def stop_rot_gradient(self) -> Rigid:
""" """
Scales the contained translation tensor by a constant factor. Detaches the underlying rotation object
Returns: Returns:
A version of the transformation with scaled translations A transformation object with detached rotations
""" """
return T(self.rots, self.trans * factor) fn = lambda r: r.detach()
return self.apply_rot_fn(fn)
@staticmethod @staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
...@@ -613,87 +1366,15 @@ class T: ...@@ -613,87 +1366,15 @@ class T:
rots = rots.transpose(-1, -2) rots = rots.transpose(-1, -2)
translation = -1 * translation translation = -1 * translation
return T(rots, translation) rot_obj = Rotation(rot_mats=rots, quats=None)
def cuda(self) -> T: return Rigid(rot_obj, translation)
def cuda(self) -> Rigid:
""" """
Moves the transformation object to GPU memory Moves the transformation object to GPU memory
Returns: Returns:
A version of the transformation on GPU A version of the transformation on GPU
""" """
return T(self.rots.cuda(), self.trans.cuda()) return Rigid(self._rots.cuda(), self._trans.cuda())
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
def _to_mat(pairs):
mat = np.zeros((4, 4))
for pair in pairs:
key, value = pair
ind = _qtr_ind_dict[key]
mat[ind // 4][ind % 4] = value
return mat
_qtr_mat = np.zeros((4, 4, 3, 3))
_qtr_mat[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_qtr_mat[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_qtr_mat[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_qtr_mat[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_qtr_mat[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_qtr_mat[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_qtr_mat[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_qtr_mat[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_qtr_mat[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
"""
Converts a quaternion to a rotation matrix.
Args:
quat: [*, 4] quaternions
Returns:
[*, 3, 3] rotation matrices
"""
# [*, 4, 4]
quat = quat[..., None] * quat[..., None, :]
# [4, 4, 3, 3]
mat = quat.new_tensor(_qtr_mat)
# [*, 4, 4, 3, 3]
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
quat = quat[..., None, None] * shaped_qtr_mat
# [*, 3, 3]
return torch.sum(quat, dim=(-3, -4))
def affine_vector_to_4x4(vector: torch.Tensor) -> torch.Tensor:
"""
Transforms a tensor whose final dimension has the form:
[*quaternion, *translation]
into a homogenous transformation tensor.
Args:
vector: [*, 7] input tensor
Returns:
[*, 4, 4] homogenous transformation tensor
"""
quats = vector[..., :4]
trans = vector[..., 4:]
four_by_four = vector.new_zeros((*vector.shape[:-1], 4, 4))
four_by_four[..., :3, :3] = quat_to_rot(quats)
four_by_four[..., :3, 3] = trans
four_by_four[..., 3, 3] = 1
return four_by_four
../../../../openfold/resources/stereo_chemical_props.txt
\ No newline at end of file
...@@ -23,8 +23,8 @@ from openfold.np.residue_constants import ( ...@@ -23,8 +23,8 @@ from openfold.np.residue_constants import (
restype_atom14_mask, restype_atom14_mask,
restype_atom14_rigid_group_positions, restype_atom14_rigid_group_positions,
) )
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -187,7 +187,7 @@ class TestFeats(unittest.TestCase): ...@@ -187,7 +187,7 @@ class TestFeats(unittest.TestCase):
n = 5 n = 5
rots = torch.rand((batch_size, n, 3, 3)) rots = torch.rand((batch_size, n, 3, 3))
trans = torch.rand((batch_size, n, 3)) trans = torch.rand((batch_size, n, 3))
ts = T(rots, trans) ts = Rigid(Rotation(rot_mats=rots), trans)
angles = torch.rand((batch_size, n, 7, 2)) angles = torch.rand((batch_size, n, 7, 2))
...@@ -222,7 +222,9 @@ class TestFeats(unittest.TestCase): ...@@ -222,7 +222,9 @@ class TestFeats(unittest.TestCase):
affines = random_affines_4x4((n_res,)) affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines) rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float()) transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
torsion_angles_sin_cos = np.random.rand(n_res, 7, 2) torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)
...@@ -250,7 +252,7 @@ class TestFeats(unittest.TestCase): ...@@ -250,7 +252,7 @@ class TestFeats(unittest.TestCase):
bottom_row[..., 3] = 1 bottom_row[..., 3] = 1
transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2) transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)
transforms_repro = out.to_4x4().cpu() transforms_repro = out.to_tensor_4x4().cpu()
self.assertTrue( self.assertTrue(
torch.max(torch.abs(transforms_gt - transforms_repro) < consts.eps) torch.max(torch.abs(transforms_gt - transforms_repro) < consts.eps)
...@@ -262,7 +264,7 @@ class TestFeats(unittest.TestCase): ...@@ -262,7 +264,7 @@ class TestFeats(unittest.TestCase):
rots = torch.rand((batch_size, n_res, 8, 3, 3)) rots = torch.rand((batch_size, n_res, 8, 3, 3))
trans = torch.rand((batch_size, n_res, 8, 3)) trans = torch.rand((batch_size, n_res, 8, 3))
ts = T(rots, trans) ts = Rigid(Rotation(rot_mats=rots), trans)
f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long() f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()
...@@ -293,7 +295,9 @@ class TestFeats(unittest.TestCase): ...@@ -293,7 +295,9 @@ class TestFeats(unittest.TestCase):
affines = random_affines_4x4((n_res, 8)) affines = random_affines_4x4((n_res, 8))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines) rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float()) transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
out_gt = f.apply({}, None, aatype, rigids) out_gt = f.apply({}, None, aatype, rigids)
jax.tree_map(lambda x: x.block_until_ready(), out_gt) jax.tree_map(lambda x: x.block_until_ready(), out_gt)
......
...@@ -20,7 +20,10 @@ import unittest ...@@ -20,7 +20,10 @@ import unittest
import ml_collections as mlc import ml_collections as mlc
from openfold.data import data_transforms from openfold.data import data_transforms
from openfold.utils.affine_utils import T, affine_vector_to_4x4 from openfold.utils.rigid_utils import (
Rotation,
Rigid,
)
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.loss import ( from openfold.utils.loss import (
torsion_angle_loss, torsion_angle_loss,
...@@ -55,6 +58,11 @@ if compare_utils.alphafold_is_installed(): ...@@ -55,6 +58,11 @@ if compare_utils.alphafold_is_installed():
import haiku as hk import haiku as hk
def affine_vector_to_4x4(affine):
r = Rigid.from_tensor_7(affine)
return r.to_tensor_4x4()
class TestLoss(unittest.TestCase): class TestLoss(unittest.TestCase):
def test_run_torsion_angle_loss(self): def test_run_torsion_angle_loss(self):
batch_size = consts.batch_size batch_size = consts.batch_size
...@@ -77,8 +85,8 @@ class TestLoss(unittest.TestCase): ...@@ -77,8 +85,8 @@ class TestLoss(unittest.TestCase):
rots_gt = torch.rand((batch_size, n_frames, 3, 3)) rots_gt = torch.rand((batch_size, n_frames, 3, 3))
trans = torch.rand((batch_size, n_frames, 3)) trans = torch.rand((batch_size, n_frames, 3))
trans_gt = torch.rand((batch_size, n_frames, 3)) trans_gt = torch.rand((batch_size, n_frames, 3))
t = T(rots, trans) t = Rigid(Rotation(rot_mats=rots), trans)
t_gt = T(rots_gt, trans_gt) t_gt = Rigid(Rotation(rot_mats=rots_gt), trans_gt)
frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float() frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float()
positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float() positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float()
length_scale = 10 length_scale = 10
...@@ -686,10 +694,10 @@ class TestLoss(unittest.TestCase): ...@@ -686,10 +694,10 @@ class TestLoss(unittest.TestCase):
batch = tree_map(to_tensor, batch, np.ndarray) batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray) value = tree_map(to_tensor, value, np.ndarray)
batch["backbone_affine_tensor"] = affine_vector_to_4x4( batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
batch["backbone_affine_tensor"] batch["backbone_affine_tensor"]
) )
value["traj"] = affine_vector_to_4x4(value["traj"]) batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm}) out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
out_repro = out_repro.cpu() out_repro = out_repro.cpu()
...@@ -807,6 +815,8 @@ class TestLoss(unittest.TestCase): ...@@ -807,6 +815,8 @@ class TestLoss(unittest.TestCase):
f = hk.transform(run_tm_loss) f = hk.transform(run_tm_loss)
np.random.seed(42)
n_res = consts.n_res n_res = consts.n_res
representations = { representations = {
...@@ -839,12 +849,10 @@ class TestLoss(unittest.TestCase): ...@@ -839,12 +849,10 @@ class TestLoss(unittest.TestCase):
batch = tree_map(to_tensor, batch, np.ndarray) batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray) value = tree_map(to_tensor, value, np.ndarray)
batch["backbone_affine_tensor"] = affine_vector_to_4x4( batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
batch["backbone_affine_tensor"] batch["backbone_affine_tensor"]
) )
value["structure_module"]["final_affines"] = affine_vector_to_4x4( batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
value["structure_module"]["final_affines"]
)
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
logits = model.aux_heads.tm(representations["pair"]) logits = model.aux_heads.tm(representations["pair"])
......
...@@ -130,4 +130,5 @@ class TestModel(unittest.TestCase): ...@@ -130,4 +130,5 @@ class TestModel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1] out_repro = out_repro["sm"]["positions"][-1]
out_repro = out_repro.squeeze(0) out_repro = out_repro.squeeze(0)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 1e-3)) print(torch.max(torch.abs(out_gt - out_repro)))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
...@@ -31,8 +31,8 @@ from openfold.model.structure_module import ( ...@@ -31,8 +31,8 @@ from openfold.model.structure_module import (
AngleResnet, AngleResnet,
InvariantPointAttention, InvariantPointAttention,
) )
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
import tests.compare_utils as compare_utils import tests.compare_utils as compare_utils
from tests.config import consts from tests.config import consts
from tests.data_utils import ( from tests.data_utils import (
...@@ -89,7 +89,7 @@ class TestStructureModule(unittest.TestCase): ...@@ -89,7 +89,7 @@ class TestStructureModule(unittest.TestCase):
out = sm(s, z, f) out = sm(s, z, f)
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4)) self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
self.assertTrue( self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2) out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
) )
...@@ -177,23 +177,6 @@ class TestStructureModule(unittest.TestCase): ...@@ -177,23 +177,6 @@ class TestStructureModule(unittest.TestCase):
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05) self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
class TestBackboneUpdate(unittest.TestCase):
def test_shape(self):
batch_size = 2
n_res = 3
c_in = 5
bu = BackboneUpdate(c_in)
s = torch.rand((batch_size, n_res, c_in))
t = bu(s)
rot, tra = t.rots, t.trans
self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
self.assertTrue(tra.shape == (batch_size, n_res, 3))
class TestInvariantPointAttention(unittest.TestCase): class TestInvariantPointAttention(unittest.TestCase):
def test_shape(self): def test_shape(self):
c_m = 13 c_m = 13
...@@ -210,17 +193,18 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -210,17 +193,18 @@ class TestInvariantPointAttention(unittest.TestCase):
z = torch.rand((batch_size, n_res, n_res, c_z)) z = torch.rand((batch_size, n_res, n_res, c_z))
mask = torch.ones((batch_size, n_res)) mask = torch.ones((batch_size, n_res))
rots = torch.rand((batch_size, n_res, 3, 3)) rot_mats = torch.rand((batch_size, n_res, 3, 3))
rots = Rotation(rot_mats=rot_mats, quats=None)
trans = torch.rand((batch_size, n_res, 3)) trans = torch.rand((batch_size, n_res, 3))
t = T(rots, trans) r = Rigid(rots, trans)
ipa = InvariantPointAttention( ipa = InvariantPointAttention(
c_m, c_z, c_hidden, no_heads, no_qp, no_vp c_m, c_z, c_hidden, no_heads, no_qp, no_vp
) )
shape_before = s.shape shape_before = s.shape
s = ipa(s, z, t, mask) s = ipa(s, z, r, mask)
self.assertTrue(s.shape == shape_before) self.assertTrue(s.shape == shape_before)
...@@ -253,7 +237,9 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -253,7 +237,9 @@ class TestInvariantPointAttention(unittest.TestCase):
affines = random_affines_4x4((n_res,)) affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines) rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids) quats = alphafold.model.r3.rigids_to_quataffine(rigids)
transformations = T.from_4x4(torch.as_tensor(affines).float().cuda()) transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats sample_affine = quats
......
...@@ -13,11 +13,24 @@ ...@@ -13,11 +13,24 @@
# limitations under the License. # limitations under the License.
import math import math
import numpy as np
import torch import torch
import unittest import unittest
from openfold.utils.affine_utils import T, quat_to_rot from openfold.utils.rigid_utils import (
Rotation,
Rigid,
quat_to_rot,
rot_to_quat,
)
from openfold.utils.tensor_utils import chunk_layer, _chunk_slice from openfold.utils.tensor_utils import chunk_layer, _chunk_slice
import tests.compare_utils as compare_utils
from tests.config import consts
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
X_90_ROT = torch.tensor( X_90_ROT = torch.tensor(
...@@ -38,7 +51,7 @@ X_NEG_90_ROT = torch.tensor( ...@@ -38,7 +51,7 @@ X_NEG_90_ROT = torch.tensor(
class TestUtils(unittest.TestCase): class TestUtils(unittest.TestCase):
def test_T_from_3_points_shape(self): def test_rigid_from_3_points_shape(self):
batch_size = 2 batch_size = 2
n_res = 5 n_res = 5
...@@ -46,14 +59,14 @@ class TestUtils(unittest.TestCase): ...@@ -46,14 +59,14 @@ class TestUtils(unittest.TestCase):
x2 = torch.rand((batch_size, n_res, 3)) x2 = torch.rand((batch_size, n_res, 3))
x3 = torch.rand((batch_size, n_res, 3)) x3 = torch.rand((batch_size, n_res, 3))
t = T.from_3_points(x1, x2, x3) r = Rigid.from_3_points(x1, x2, x3)
rot, tra = t.rots, t.trans rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
self.assertTrue(rot.shape == (batch_size, n_res, 3, 3)) self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
self.assertTrue(torch.all(tra == x2)) self.assertTrue(torch.all(tra == x2))
def test_T_from_4x4(self): def test_rigid_from_4x4(self):
batch_size = 2 batch_size = 2
transf = [ transf = [
[1, 0, 0, 1], [1, 0, 0, 1],
...@@ -68,58 +81,79 @@ class TestUtils(unittest.TestCase): ...@@ -68,58 +81,79 @@ class TestUtils(unittest.TestCase):
transf = torch.stack([transf for _ in range(batch_size)], dim=0) transf = torch.stack([transf for _ in range(batch_size)], dim=0)
t = T.from_4x4(transf) r = Rigid.from_tensor_4x4(transf)
rot, tra = t.rots, t.trans rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
self.assertTrue(torch.all(rot == true_rot.unsqueeze(0))) self.assertTrue(torch.all(rot == true_rot.unsqueeze(0)))
self.assertTrue(torch.all(tra == true_trans.unsqueeze(0))) self.assertTrue(torch.all(tra == true_trans.unsqueeze(0)))
def test_T_shape(self): def test_rigid_shape(self):
batch_size = 2 batch_size = 2
n = 5 n = 5
transf = T( transf = Rigid(
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3)) Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
torch.rand((batch_size, n, 3))
) )
self.assertTrue(transf.shape == (batch_size, n)) self.assertTrue(transf.shape == (batch_size, n))
def test_T_concat(self): def test_rigid_cat(self):
batch_size = 2 batch_size = 2
n = 5 n = 5
transf = T( transf = Rigid(
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3)) Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
torch.rand((batch_size, n, 3))
) )
transf_concat = T.concat([transf, transf], dim=0) transf_cat = Rigid.cat([transf, transf], dim=0)
self.assertTrue(transf_concat.rots.shape == (batch_size * 2, n, 3, 3)) transf_rots = transf.get_rots().get_rot_mats()
transf_cat_rots = transf_cat.get_rots().get_rot_mats()
transf_concat = T.concat([transf, transf], dim=1) self.assertTrue(transf_cat_rots.shape == (batch_size * 2, n, 3, 3))
self.assertTrue(transf_concat.rots.shape == (batch_size, n * 2, 3, 3)) transf_cat = Rigid.cat([transf, transf], dim=1)
transf_cat_rots = transf_cat.get_rots().get_rot_mats()
self.assertTrue(torch.all(transf_concat.rots[:, :n] == transf.rots)) self.assertTrue(transf_cat_rots.shape == (batch_size, n * 2, 3, 3))
self.assertTrue(torch.all(transf_concat.trans[:, :n] == transf.trans))
def test_T_compose(self): self.assertTrue(torch.all(transf_cat_rots[:, :n] == transf_rots))
self.assertTrue(
torch.all(transf_cat.get_trans()[:, :n] == transf.get_trans())
)
def test_rigid_compose(self):
trans_1 = [0, 1, 0] trans_1 = [0, 1, 0]
trans_2 = [0, 0, 1] trans_2 = [0, 0, 1]
t1 = T(X_90_ROT, torch.tensor(trans_1)) r = Rotation(rot_mats=X_90_ROT)
t2 = T(X_NEG_90_ROT, torch.tensor(trans_2)) t = torch.tensor(trans_1)
t1 = Rigid(
Rotation(rot_mats=X_90_ROT),
torch.tensor(trans_1)
)
t2 = Rigid(
Rotation(rot_mats=X_NEG_90_ROT),
torch.tensor(trans_2)
)
t3 = t1.compose(t2) t3 = t1.compose(t2)
self.assertTrue(torch.all(t3.rots == torch.eye(3))) self.assertTrue(
self.assertTrue(torch.all(t3.trans == 0)) torch.all(t3.get_rots().get_rot_mats() == torch.eye(3))
)
self.assertTrue(
torch.all(t3.get_trans() == 0)
)
def test_T_apply(self): def test_rigid_apply(self):
rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0) rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0)
trans = torch.tensor([1, 1, 1]) trans = torch.tensor([1, 1, 1])
trans = torch.stack([trans, trans], dim=0) trans = torch.stack([trans, trans], dim=0)
t = T(rots, trans) t = Rigid(Rotation(rot_mats=rots), trans)
x = torch.arange(30) x = torch.arange(30)
x = torch.stack([x, x], dim=0) x = torch.stack([x, x], dim=0)
...@@ -141,6 +175,12 @@ class TestUtils(unittest.TestCase): ...@@ -141,6 +175,12 @@ class TestUtils(unittest.TestCase):
eps = 1e-07 eps = 1e-07
self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps)) self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps))
def test_rot_to_quat(self):
quat = rot_to_quat(X_90_ROT)
eps = 1e-07
ans = torch.tensor([math.sqrt(0.5), math.sqrt(0.5), 0., 0.])
self.assertTrue(torch.all(torch.abs(quat - ans) < eps))
def test_chunk_layer_tensor(self): def test_chunk_layer_tensor(self):
x = torch.rand(2, 4, 5, 15) x = torch.rand(2, 4, 5, 15)
l = torch.nn.Linear(15, 30) l = torch.nn.Linear(15, 30)
...@@ -180,3 +220,33 @@ class TestUtils(unittest.TestCase): ...@@ -180,3 +220,33 @@ class TestUtils(unittest.TestCase):
chunked_flattened = x_flat[i:j] chunked_flattened = x_flat[i:j]
self.assertTrue(torch.all(chunked == chunked_flattened)) self.assertTrue(torch.all(chunked == chunked_flattened))
@compare_utils.skip_unless_alphafold_installed()
def test_pre_compose_compare(self):
quat = np.random.rand(20, 4)
trans = [np.random.rand(20) for _ in range(3)]
quat_affine = alphafold.model.quat_affine.QuatAffine(
quat, translation=trans
)
update_vec = np.random.rand(20, 6)
new_gt = quat_affine.pre_compose(update_vec)
quat_t = torch.tensor(quat)
trans_t = torch.stack([torch.tensor(t) for t in trans], dim=-1)
rigid = Rigid(Rotation(quats=quat_t), trans_t)
new_repro = rigid.compose_q_update_vec(torch.tensor(update_vec))
new_gt_q = torch.tensor(np.array(new_gt.quaternion))
new_gt_t = torch.stack(
[torch.tensor(np.array(t)) for t in new_gt.translation], dim=-1
)
new_repro_q = new_repro.get_rots().get_quats()
new_repro_t = new_repro.get_trans()
self.assertTrue(
torch.max(torch.abs(new_gt_q - new_repro_q)) < consts.eps
)
self.assertTrue(
torch.max(torch.abs(new_gt_t - new_repro_t)) < consts.eps
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment