Commit e3daf724 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Overhaul transformation code for better parity w/ AlphaFold

parent 1f709b0d
...@@ -89,8 +89,8 @@ config = mlc.ConfigDict( ...@@ -89,8 +89,8 @@ config = mlc.ConfigDict(
"atom14_gt_exists": [NUM_RES, None], "atom14_gt_exists": [NUM_RES, None],
"atom14_gt_positions": [NUM_RES, None, None], "atom14_gt_positions": [NUM_RES, None, None],
"atom37_atom_exists": [NUM_RES, None], "atom37_atom_exists": [NUM_RES, None],
"backbone_affine_mask": [NUM_RES], "backbone_rigid_mask": [NUM_RES],
"backbone_affine_tensor": [NUM_RES, None, None], "backbone_rigid_tensor": [NUM_RES, None, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES], "bert_mask": [NUM_MSA_SEQ, NUM_RES],
"chi_angles_sin_cos": [NUM_RES, None, None], "chi_angles_sin_cos": [NUM_RES, None, None],
"chi_mask": [NUM_RES, None], "chi_mask": [NUM_RES, None],
...@@ -126,8 +126,8 @@ config = mlc.ConfigDict( ...@@ -126,8 +126,8 @@ config = mlc.ConfigDict(
"template_alt_torsion_angles_sin_cos": [ "template_alt_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None, NUM_TEMPLATES, NUM_RES, None, None,
], ],
"template_backbone_affine_mask": [NUM_TEMPLATES, NUM_RES], "template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
"template_backbone_affine_tensor": [ "template_backbone_rigid_tensor": [
NUM_TEMPLATES, NUM_RES, None, None, NUM_TEMPLATES, NUM_RES, None, None,
], ],
"template_mask": [NUM_TEMPLATES], "template_mask": [NUM_TEMPLATES],
......
...@@ -22,7 +22,7 @@ import torch ...@@ -22,7 +22,7 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
from openfold.utils.affine_utils import T from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -752,7 +752,7 @@ def make_atom14_positions(protein): ...@@ -752,7 +752,7 @@ def make_atom14_positions(protein):
return protein return protein
def atom37_to_frames(protein): def atom37_to_frames(protein, eps=1e-8):
aatype = protein["aatype"] aatype = protein["aatype"]
all_atom_positions = protein["all_atom_positions"] all_atom_positions = protein["all_atom_positions"]
all_atom_mask = protein["all_atom_mask"] all_atom_mask = protein["all_atom_mask"]
...@@ -810,11 +810,11 @@ def atom37_to_frames(protein): ...@@ -810,11 +810,11 @@ def atom37_to_frames(protein):
no_batch_dims=len(all_atom_positions.shape[:-2]), no_batch_dims=len(all_atom_positions.shape[:-2]),
) )
gt_frames = T.from_3_points( gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :], p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :], origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :], p_xy_plane=base_atom_pos[..., 2, :],
eps=1e-8, eps=eps,
) )
group_exists = batched_gather( group_exists = batched_gather(
...@@ -836,8 +836,9 @@ def atom37_to_frames(protein): ...@@ -836,8 +836,9 @@ def atom37_to_frames(protein):
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1)) rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1 rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1 rots[..., 0, 2, 2] = -1
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(T(rots, None)) gt_frames = gt_frames.compose(Rigid(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros( restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8 *((1,) * batch_dims), 21, 8
...@@ -871,10 +872,15 @@ def atom37_to_frames(protein): ...@@ -871,10 +872,15 @@ def atom37_to_frames(protein):
no_batch_dims=batch_dims, no_batch_dims=batch_dims,
) )
alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None)) residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot
)
alt_gt_frames = gt_frames.compose(
Rigid(residx_rigidgroup_ambiguity_rot, None)
)
gt_frames_tensor = gt_frames.to_4x4() gt_frames_tensor = gt_frames.to_tensor_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_4x4() alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
protein["rigidgroups_gt_frames"] = gt_frames_tensor protein["rigidgroups_gt_frames"] = gt_frames_tensor
protein["rigidgroups_gt_exists"] = gt_exists protein["rigidgroups_gt_exists"] = gt_exists
...@@ -1028,7 +1034,7 @@ def atom37_to_torsion_angles( ...@@ -1028,7 +1034,7 @@ def atom37_to_torsion_angles(
dim=-1, dim=-1,
) )
torsion_frames = T.from_3_points( torsion_frames = Rigid.from_3_points(
torsions_atom_pos[..., 1, :], torsions_atom_pos[..., 1, :],
torsions_atom_pos[..., 2, :], torsions_atom_pos[..., 2, :],
torsions_atom_pos[..., 0, :], torsions_atom_pos[..., 0, :],
...@@ -1082,11 +1088,11 @@ def atom37_to_torsion_angles( ...@@ -1082,11 +1088,11 @@ def atom37_to_torsion_angles(
def get_backbone_frames(protein): def get_backbone_frames(protein):
# TODO: Verify that this is correct # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
protein["backbone_affine_tensor"] = protein["rigidgroups_gt_frames"][ protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
..., 0, :, : ..., 0, :, :
] ]
protein["backbone_affine_mask"] = protein["rigidgroups_gt_exists"][..., 0] protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]
return protein return protein
......
...@@ -430,7 +430,9 @@ def _is_set(data: str) -> bool: ...@@ -430,7 +430,9 @@ def _is_set(data: str) -> bool:
def get_atom_coords( def get_atom_coords(
mmcif_object: MmcifObject, chain_id: str, zero_center: bool = True mmcif_object: MmcifObject,
chain_id: str,
_zero_center_positions: bool = True
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain # Locate the right chain
chains = list(mmcif_object.structure.get_chains()) chains = list(mmcif_object.structure.get_chains())
...@@ -475,7 +477,7 @@ def get_atom_coords( ...@@ -475,7 +477,7 @@ def get_atom_coords(
all_atom_positions[res_index] = pos all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask all_atom_mask[res_index] = mask
if zero_center: if _zero_center_positions:
binary_mask = all_atom_mask.astype(bool) binary_mask = all_atom_mask.astype(bool)
translation_vec = all_atom_positions[binary_mask].mean(axis=0) translation_vec = all_atom_positions[binary_mask].mean(axis=0)
all_atom_positions[binary_mask] -= translation_vec all_atom_positions[binary_mask] -= translation_vec
......
...@@ -503,10 +503,13 @@ def _get_atom_positions( ...@@ -503,10 +503,13 @@ def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject, mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str, auth_chain_id: str,
max_ca_ca_distance: float, max_ca_ca_distance: float,
_zero_center_positions: bool = True,
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues.""" """Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords( coords_with_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=auth_chain_id mmcif_object=mmcif_object,
chain_id=auth_chain_id,
_zero_center_positions=_zero_center_positions,
) )
all_atom_positions, all_atom_mask = coords_with_mask all_atom_positions, all_atom_mask = coords_with_mask
_check_residue_distances( _check_residue_distances(
...@@ -523,6 +526,7 @@ def _extract_template_features( ...@@ -523,6 +526,7 @@ def _extract_template_features(
query_sequence: str, query_sequence: str,
template_chain_id: str, template_chain_id: str,
kalign_binary_path: str, kalign_binary_path: str,
_zero_center_positions: bool = True,
) -> Tuple[Dict[str, Any], Optional[str]]: ) -> Tuple[Dict[str, Any], Optional[str]]:
"""Parses atom positions in the target structure and aligns with the query. """Parses atom positions in the target structure and aligns with the query.
...@@ -607,7 +611,10 @@ def _extract_template_features( ...@@ -607,7 +611,10 @@ def _extract_template_features(
# Essentially set to infinity - we don't want to reject templates unless # Essentially set to infinity - we don't want to reject templates unless
# they're really really bad. # they're really really bad.
all_atom_positions, all_atom_mask = _get_atom_positions( all_atom_positions, all_atom_mask = _get_atom_positions(
mmcif_object, chain_id, max_ca_ca_distance=150.0 mmcif_object,
chain_id,
max_ca_ca_distance=150.0,
_zero_center_positions=_zero_center_positions,
) )
except (CaDistanceError, KeyError) as ex: except (CaDistanceError, KeyError) as ex:
raise NoAtomDataInTemplateError( raise NoAtomDataInTemplateError(
...@@ -795,6 +802,7 @@ def _process_single_hit( ...@@ -795,6 +802,7 @@ def _process_single_hit(
obsolete_pdbs: Mapping[str, str], obsolete_pdbs: Mapping[str, str],
kalign_binary_path: str, kalign_binary_path: str,
strict_error_check: bool = False, strict_error_check: bool = False,
_zero_center_positions: bool = True,
) -> SingleHitResult: ) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit.""" """Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit. # Fail hard if we can't get the PDB ID and chain name from the hit.
...@@ -856,6 +864,7 @@ def _process_single_hit( ...@@ -856,6 +864,7 @@ def _process_single_hit(
query_sequence=query_sequence, query_sequence=query_sequence,
template_chain_id=hit_chain_id, template_chain_id=hit_chain_id,
kalign_binary_path=kalign_binary_path, kalign_binary_path=kalign_binary_path,
_zero_center_positions=_zero_center_positions,
) )
features["template_sum_probs"] = [hit.sum_probs] features["template_sum_probs"] = [hit.sum_probs]
...@@ -913,7 +922,6 @@ class TemplateSearchResult: ...@@ -913,7 +922,6 @@ class TemplateSearchResult:
class TemplateHitFeaturizer: class TemplateHitFeaturizer:
"""A class for turning hhr hits to template features.""" """A class for turning hhr hits to template features."""
def __init__( def __init__(
self, self,
mmcif_dir: str, mmcif_dir: str,
...@@ -924,6 +932,7 @@ class TemplateHitFeaturizer: ...@@ -924,6 +932,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs_path: Optional[str] = None, obsolete_pdbs_path: Optional[str] = None,
strict_error_check: bool = False, strict_error_check: bool = False,
_shuffle_top_k_prefiltered: Optional[int] = None, _shuffle_top_k_prefiltered: Optional[int] = None,
_zero_center_positions: bool = True,
): ):
"""Initializes the Template Search. """Initializes the Template Search.
...@@ -982,6 +991,7 @@ class TemplateHitFeaturizer: ...@@ -982,6 +991,7 @@ class TemplateHitFeaturizer:
self._obsolete_pdbs = {} self._obsolete_pdbs = {}
self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
self._zero_center_positions = _zero_center_positions
def get_templates( def get_templates(
self, self,
...@@ -1057,6 +1067,7 @@ class TemplateHitFeaturizer: ...@@ -1057,6 +1067,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs=self._obsolete_pdbs, obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check, strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path, kalign_binary_path=self._kalign_binary_path,
_zero_center_positions=self._zero_center_positions,
) )
if result.error: if result.error:
......
...@@ -198,6 +198,7 @@ class RecyclingEmbedder(nn.Module): ...@@ -198,6 +198,7 @@ class RecyclingEmbedder(nn.Module):
self.no_bins, self.no_bins,
dtype=x.dtype, dtype=x.dtype,
device=x.device, device=x.device,
requires_grad=False,
) )
# [*, N, C_m] # [*, N, C_m]
......
...@@ -25,11 +25,11 @@ from openfold.np.residue_constants import ( ...@@ -25,11 +25,11 @@ from openfold.np.residue_constants import (
restype_atom14_mask, restype_atom14_mask,
restype_atom14_rigid_group_positions, restype_atom14_rigid_group_positions,
) )
from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.feats import ( from openfold.utils.feats import (
frames_and_literature_positions_to_atom14_pos, frames_and_literature_positions_to_atom14_pos,
torsion_angles_to_frames, torsion_angles_to_frames,
) )
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
dict_multimap, dict_multimap,
permute_final_dims, permute_final_dims,
...@@ -225,7 +225,7 @@ class InvariantPointAttention(nn.Module): ...@@ -225,7 +225,7 @@ class InvariantPointAttention(nn.Module):
self, self,
s: torch.Tensor, s: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
t: T, r: Rigid,
mask: torch.Tensor, mask: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -234,8 +234,8 @@ class InvariantPointAttention(nn.Module): ...@@ -234,8 +234,8 @@ class InvariantPointAttention(nn.Module):
[*, N_res, C_s] single representation [*, N_res, C_s] single representation
z: z:
[*, N_res, N_res, C_z] pair representation [*, N_res, N_res, C_z] pair representation
t: r:
[*, N_res] affine transformation object [*, N_res] transformation object
mask: mask:
[*, N_res] mask [*, N_res] mask
Returns: Returns:
...@@ -264,7 +264,7 @@ class InvariantPointAttention(nn.Module): ...@@ -264,7 +264,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_q, 3] # [*, N_res, H * P_q, 3]
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
q_pts = torch.stack(q_pts, dim=-1) q_pts = torch.stack(q_pts, dim=-1)
q_pts = t[..., None].apply(q_pts) q_pts = r[..., None].apply(q_pts)
# [*, N_res, H, P_q, 3] # [*, N_res, H, P_q, 3]
q_pts = q_pts.view( q_pts = q_pts.view(
...@@ -277,7 +277,7 @@ class InvariantPointAttention(nn.Module): ...@@ -277,7 +277,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * (P_q + P_v), 3] # [*, N_res, H * (P_q + P_v), 3]
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1) kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = t[..., None].apply(kv_pts) kv_pts = r[..., None].apply(kv_pts)
# [*, N_res, H, (P_q + P_v), 3] # [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3)) kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
...@@ -349,7 +349,7 @@ class InvariantPointAttention(nn.Module): ...@@ -349,7 +349,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, P_v, 3] # [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = t[..., None, None].invert_apply(o_pt) o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v] # [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims( o_pt_norm = flatten_final_dims(
...@@ -377,7 +377,7 @@ class InvariantPointAttention(nn.Module): ...@@ -377,7 +377,7 @@ class InvariantPointAttention(nn.Module):
class BackboneUpdate(nn.Module): class BackboneUpdate(nn.Module):
""" """
Implements Algorithm 23. Implements part of Algorithm 23.
""" """
def __init__(self, c_s): def __init__(self, c_s):
...@@ -392,36 +392,17 @@ class BackboneUpdate(nn.Module): ...@@ -392,36 +392,17 @@ class BackboneUpdate(nn.Module):
self.linear = Linear(self.c_s, 6, init="final") self.linear = Linear(self.c_s, 6, init="final")
def forward(self, s): def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
[*, N_res, C_s] single representation [*, N_res, C_s] single representation
Returns: Returns:
[*, N_res] affine transformation object [*, N_res, 6] update vector
""" """
# [*, 6] # [*, 6]
params = self.linear(s) update = self.linear(s)
# [*, 3]
quats, trans = params[..., :3], params[..., 3:]
# [*]
# norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
norm_denom = torch.sqrt(torch.sum(quats ** 2, dim=-1) + 1)
# [*, 3]
ones = s.new_ones((1,) * len(quats.shape)).expand(
quats.shape[:-1] + (1,)
)
# [*, 4]
quats = torch.cat([ones, quats], dim=-1)
quats = quats / norm_denom[..., None]
# [*, 3, 3] return update
rots = quat_to_rot(quats)
return T(rots, trans)
class StructureModuleTransitionLayer(nn.Module): class StructureModuleTransitionLayer(nn.Module):
...@@ -592,7 +573,7 @@ class StructureModule(nn.Module): ...@@ -592,7 +573,7 @@ class StructureModule(nn.Module):
self, self,
s, s,
z, z,
f, aatype,
mask=None, mask=None,
): ):
""" """
...@@ -601,7 +582,7 @@ class StructureModule(nn.Module): ...@@ -601,7 +582,7 @@ class StructureModule(nn.Module):
[*, N_res, C_s] single representation [*, N_res, C_s] single representation
z: z:
[*, N_res, N_res, C_z] pair representation [*, N_res, N_res, C_z] pair representation
f: aatype:
[*, N_res] amino acid indices [*, N_res] amino acid indices
mask: mask:
Optional [*, N_res] sequence mask Optional [*, N_res] sequence mask
...@@ -623,44 +604,67 @@ class StructureModule(nn.Module): ...@@ -623,44 +604,67 @@ class StructureModule(nn.Module):
s = self.linear_in(s) s = self.linear_in(s)
# [*, N] # [*, N]
t = T.identity(s.shape[:-1], s.dtype, s.device, self.training) rigids = Rigid.identity(
s.shape[:-1],
s.dtype,
s.device,
self.training,
fmt="quat",
)
outputs = [] outputs = []
for i in range(self.no_blocks): for i in range(self.no_blocks):
# [*, N, C_s] # [*, N, C_s]
s = s + self.ipa(s, z, t, mask) s = s + self.ipa(s, z, rigids, mask)
s = self.ipa_dropout(s) s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s) s = self.layer_norm_ipa(s)
s = self.transition(s) s = self.transition(s)
# [*, N] # [*, N]
t = t.compose(self.bb_update(s)) rigids = rigids.compose_q_update_vec(self.bb_update(s))
# To hew as closely as possible to AlphaFold, we convert our
# quaternion-based transformations to rotation-matrix ones
# here
backb_to_global = Rigid(
Rotation(
rot_mats=rigids.get_rots().get_rot_mats(),
quats=None
),
rigids.get_trans(),
)
backb_to_global = backb_to_global.scale_translation(
self.trans_scale_factor
)
# [*, N, 7, 2] # [*, N, 7, 2]
unnormalized_a, a = self.angle_resnet(s, s_initial) unnormalized_angles, angles = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames( all_frames_to_global = self.torsion_angles_to_frames(
t.scale_translation(self.trans_scale_factor), backb_to_global,
a, angles,
f, aatype,
) )
pred_xyz = self.frames_and_literature_positions_to_atom14_pos( pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
all_frames_to_global, all_frames_to_global,
f, aatype,
) )
scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
preds = { preds = {
"frames": t.scale_translation(self.trans_scale_factor).to_4x4(), "frames": scaled_rigids.to_tensor_7(),
"sidechain_frames": all_frames_to_global.to_4x4(), "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
"unnormalized_angles": unnormalized_a, "unnormalized_angles": unnormalized_angles,
"angles": a, "angles": angles,
"positions": pred_xyz, "positions": pred_xyz,
} }
outputs.append(preds) outputs.append(preds)
if i < (self.no_blocks - 1): if i < (self.no_blocks - 1):
t = t.stop_rot_gradient() rigids = rigids.stop_rot_gradient()
outputs = dict_multimap(torch.stack, outputs) outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s outputs["single"] = s
...@@ -673,38 +677,42 @@ class StructureModule(nn.Module): ...@@ -673,38 +677,42 @@ class StructureModule(nn.Module):
restype_rigid_group_default_frame, restype_rigid_group_default_frame,
dtype=float_dtype, dtype=float_dtype,
device=device, device=device,
requires_grad=False,
) )
if self.group_idx is None: if self.group_idx is None:
self.group_idx = torch.tensor( self.group_idx = torch.tensor(
restype_atom14_to_rigid_group, restype_atom14_to_rigid_group,
device=device, device=device,
requires_grad=False,
) )
if self.atom_mask is None: if self.atom_mask is None:
self.atom_mask = torch.tensor( self.atom_mask = torch.tensor(
restype_atom14_mask, restype_atom14_mask,
dtype=float_dtype, dtype=float_dtype,
device=device, device=device,
requires_grad=False,
) )
if self.lit_positions is None: if self.lit_positions is None:
self.lit_positions = torch.tensor( self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions, restype_atom14_rigid_group_positions,
dtype=float_dtype, dtype=float_dtype,
device=device, device=device,
requires_grad=False,
) )
def torsion_angles_to_frames(self, t, alpha, f): def torsion_angles_to_frames(self, r, alpha, f):
# Lazily initialize the residue constants on the correct device # Lazily initialize the residue constants on the correct device
self._init_residue_constants(alpha.dtype, alpha.device) self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying # Separated purely to make testing less annoying
return torsion_angles_to_frames(t, alpha, f, self.default_frames) return torsion_angles_to_frames(r, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos( def frames_and_literature_positions_to_atom14_pos(
self, t, f # [*, N, 8] # [*, N] self, r, f # [*, N, 8] # [*, N]
): ):
# Lazily initialize the residue constants on the correct device # Lazily initialize the residue constants on the correct device
self._init_residue_constants(t.rots.dtype, t.rots.device) self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
return frames_and_literature_positions_to_atom14_pos( return frames_and_literature_positions_to_atom14_pos(
t, r,
f, f,
self.default_frames, self.default_frames,
self.group_idx, self.group_idx,
......
...@@ -22,7 +22,7 @@ from typing import Dict ...@@ -22,7 +22,7 @@ from typing import Dict
from openfold.np import protein from openfold.np import protein
import openfold.np.residue_constants as rc import openfold.np.residue_constants as rc
from openfold.utils.affine_utils import T from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
batched_gather, batched_gather,
one_hot, one_hot,
...@@ -124,18 +124,16 @@ def build_template_pair_feat( ...@@ -124,18 +124,16 @@ def build_template_pair_feat(
) )
n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]] n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
# TODO: Consider running this in double precision rigids = Rigid.make_transform_from_reference(
affines = T.make_transform_from_reference(
n_xyz=batch["template_all_atom_positions"][..., n, :], n_xyz=batch["template_all_atom_positions"][..., n, :],
ca_xyz=batch["template_all_atom_positions"][..., ca, :], ca_xyz=batch["template_all_atom_positions"][..., ca, :],
c_xyz=batch["template_all_atom_positions"][..., c, :], c_xyz=batch["template_all_atom_positions"][..., c, :],
eps=eps, eps=eps,
) )
points = rigids.get_trans()[..., None, :, :]
rigid_vec = rigids[..., None].invert_apply(points)
points = affines.get_trans()[..., None, :, :] inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))
affine_vec = affines[..., None].invert_apply(points)
inv_distance_scalar = torch.rsqrt(eps + torch.sum(affine_vec ** 2, dim=-1))
t_aa_masks = batch["template_all_atom_mask"] t_aa_masks = batch["template_all_atom_mask"]
template_mask = ( template_mask = (
...@@ -144,7 +142,7 @@ def build_template_pair_feat( ...@@ -144,7 +142,7 @@ def build_template_pair_feat(
template_mask_2d = template_mask[..., None] * template_mask[..., None, :] template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
inv_distance_scalar = inv_distance_scalar * template_mask_2d inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = affine_vec * inv_distance_scalar[..., None] unit_vector = rigid_vec * inv_distance_scalar[..., None]
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1)) to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None]) to_concat.append(template_mask_2d[..., None])
...@@ -165,7 +163,7 @@ def build_extra_msa_feat(batch): ...@@ -165,7 +163,7 @@ def build_extra_msa_feat(batch):
def torsion_angles_to_frames( def torsion_angles_to_frames(
t: T, r: Rigid,
alpha: torch.Tensor, alpha: torch.Tensor,
aatype: torch.Tensor, aatype: torch.Tensor,
rrgdf: torch.Tensor, rrgdf: torch.Tensor,
...@@ -176,13 +174,15 @@ def torsion_angles_to_frames( ...@@ -176,13 +174,15 @@ def torsion_angles_to_frames(
# [*, N, 8] transformations, i.e. # [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and # One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix # One [*, N, 8, 3] translation matrix
default_t = T.from_4x4(default_4x4) default_r = r.from_tensor_4x4(default_4x4)
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1 bb_rot[..., 1] = 1
# [*, N, 8, 2] # [*, N, 8, 2]
alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2) alpha = torch.cat(
[bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
)
# [*, N, 8, 3, 3] # [*, N, 8, 3, 3]
# Produces rotation matrices of the form: # Produces rotation matrices of the form:
...@@ -194,15 +194,15 @@ def torsion_angles_to_frames( ...@@ -194,15 +194,15 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses # This follows the original code rather than the supplement, which uses
# different indices. # different indices.
all_rots = alpha.new_zeros(default_t.rots.shape) all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
all_rots[..., 0, 0] = 1 all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1] all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0] all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha all_rots[..., 2, 1:] = alpha
all_rots = T(all_rots, None) all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_t.compose(all_rots) all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5] chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6] chi3_frame_to_frame = all_frames[..., 6]
...@@ -213,7 +213,7 @@ def torsion_angles_to_frames( ...@@ -213,7 +213,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = T.concat( all_frames_to_bb = Rigid.cat(
[ [
all_frames[..., :5], all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1), chi2_frame_to_bb.unsqueeze(-1),
...@@ -223,13 +223,13 @@ def torsion_angles_to_frames( ...@@ -223,13 +223,13 @@ def torsion_angles_to_frames(
dim=-1, dim=-1,
) )
all_frames_to_global = t[..., None].compose(all_frames_to_bb) all_frames_to_global = r[..., None].compose(all_frames_to_bb)
return all_frames_to_global return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos( def frames_and_literature_positions_to_atom14_pos(
t: T, r: Rigid,
aatype: torch.Tensor, aatype: torch.Tensor,
default_frames, default_frames,
group_idx, group_idx,
...@@ -249,7 +249,7 @@ def frames_and_literature_positions_to_atom14_pos( ...@@ -249,7 +249,7 @@ def frames_and_literature_positions_to_atom14_pos(
) )
# [*, N, 14, 8] # [*, N, 14, 8]
t_atoms_to_global = t[..., None, :] * group_mask t_atoms_to_global = r[..., None, :] * group_mask
# [*, N, 14] # [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn( t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
......
...@@ -24,7 +24,7 @@ from typing import Dict, Optional, Tuple ...@@ -24,7 +24,7 @@ from typing import Dict, Optional, Tuple
from openfold.np import residue_constants from openfold.np import residue_constants
from openfold.utils import feats from openfold.utils import feats
from openfold.utils.affine_utils import T from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -74,8 +74,8 @@ def torsion_angle_loss( ...@@ -74,8 +74,8 @@ def torsion_angle_loss(
def compute_fape( def compute_fape(
pred_frames: T, pred_frames: Rigid,
target_frames: T, target_frames: Rigid,
frames_mask: torch.Tensor, frames_mask: torch.Tensor,
pred_positions: torch.Tensor, pred_positions: torch.Tensor,
target_positions: torch.Tensor, target_positions: torch.Tensor,
...@@ -111,7 +111,7 @@ def compute_fape( ...@@ -111,7 +111,7 @@ def compute_fape(
# ) # )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor) # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
# #
# ("roughly" because eps is necessarily duplicated in the latter # ("roughly" because eps is necessarily duplicated in the latter)
normed_error = torch.sum(normed_error, dim=-1) normed_error = torch.sum(normed_error, dim=-1)
normed_error = ( normed_error = (
normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None] normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
...@@ -123,8 +123,8 @@ def compute_fape( ...@@ -123,8 +123,8 @@ def compute_fape(
def backbone_loss( def backbone_loss(
backbone_affine_tensor: torch.Tensor, backbone_rigid_tensor: torch.Tensor,
backbone_affine_mask: torch.Tensor, backbone_rigid_mask: torch.Tensor,
traj: torch.Tensor, traj: torch.Tensor,
use_clamped_fape: Optional[torch.Tensor] = None, use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.0, clamp_distance: float = 10.0,
...@@ -132,16 +132,27 @@ def backbone_loss( ...@@ -132,16 +132,27 @@ def backbone_loss(
eps: float = 1e-4, eps: float = 1e-4,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
pred_aff = T.from_tensor(traj) pred_aff = Rigid.from_tensor_7(traj)
gt_aff = T.from_tensor(backbone_affine_tensor) pred_aff = Rigid(
Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
pred_aff.get_trans(),
)
# DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
# backbone tensor, normalizes it, and then turns it back to a rotation
# matrix. To avoid a potentially numerically unstable rotation matrix
# to quaternion conversion, we just use the original rotation matrix
# outright. This one hasn't been composed a bunch of times, though, so
# it might be fine.
gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)
fape_loss = compute_fape( fape_loss = compute_fape(
pred_aff, pred_aff,
gt_aff[None], gt_aff[None],
backbone_affine_mask[None], backbone_rigid_mask[None],
pred_aff.get_trans(), pred_aff.get_trans(),
gt_aff[None].get_trans(), gt_aff[None].get_trans(),
backbone_affine_mask[None], backbone_rigid_mask[None],
l1_clamp_distance=clamp_distance, l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance, length_scale=loss_unit_distance,
eps=eps, eps=eps,
...@@ -150,10 +161,10 @@ def backbone_loss( ...@@ -150,10 +161,10 @@ def backbone_loss(
unclamped_fape_loss = compute_fape( unclamped_fape_loss = compute_fape(
pred_aff, pred_aff,
gt_aff[None], gt_aff[None],
backbone_affine_mask[None], backbone_rigid_mask[None],
pred_aff.get_trans(), pred_aff.get_trans(),
gt_aff[None].get_trans(), gt_aff[None].get_trans(),
backbone_affine_mask[None], backbone_rigid_mask[None],
l1_clamp_distance=None, l1_clamp_distance=None,
length_scale=loss_unit_distance, length_scale=loss_unit_distance,
eps=eps, eps=eps,
...@@ -193,9 +204,9 @@ def sidechain_loss( ...@@ -193,9 +204,9 @@ def sidechain_loss(
sidechain_frames = sidechain_frames[-1] sidechain_frames = sidechain_frames[-1]
batch_dims = sidechain_frames.shape[:-4] batch_dims = sidechain_frames.shape[:-4]
sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4) sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
sidechain_frames = T.from_4x4(sidechain_frames) sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4) renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
renamed_gt_frames = T.from_4x4(renamed_gt_frames) renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1) rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
sidechain_atom_pos = sidechain_atom_pos[-1] sidechain_atom_pos = sidechain_atom_pos[-1]
sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3) sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
...@@ -422,7 +433,7 @@ def distogram_loss( ...@@ -422,7 +433,7 @@ def distogram_loss(
device=logits.device, device=logits.device,
) )
boundaries = boundaries ** 2 boundaries = boundaries ** 2
dists = torch.sum( dists = torch.sum(
(pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2, (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
dim=-1, dim=-1,
...@@ -550,8 +561,8 @@ def compute_tm( ...@@ -550,8 +561,8 @@ def compute_tm(
def tm_loss( def tm_loss(
logits, logits,
final_affine_tensor, final_affine_tensor,
backbone_affine_tensor, backbone_rigid_tensor,
backbone_affine_mask, backbone_rigid_mask,
resolution, resolution,
max_bin=31, max_bin=31,
no_bins=64, no_bins=64,
...@@ -560,16 +571,17 @@ def tm_loss( ...@@ -560,16 +571,17 @@ def tm_loss(
eps=1e-8, eps=1e-8,
**kwargs, **kwargs,
): ):
pred_affine = T.from_4x4(final_affine_tensor) pred_affine = Rigid.from_tensor_7(final_affine_tensor)
backbone_affine = T.from_4x4(backbone_affine_tensor) backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
def _points(affine): def _points(affine):
pts = affine.get_trans()[..., None, :, :] pts = affine.get_trans()[..., None, :, :]
return affine.invert()[..., None].apply(pts) return affine.invert()[..., None].apply(pts)
sq_diff = torch.sum( sq_diff = torch.sum(
(_points(pred_affine) - _points(backbone_affine)) ** 2, dim=-1 (_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
) )
sq_diff = sq_diff.detach() sq_diff = sq_diff.detach()
boundaries = torch.linspace( boundaries = torch.linspace(
...@@ -583,7 +595,7 @@ def tm_loss( ...@@ -583,7 +595,7 @@ def tm_loss(
) )
square_mask = ( square_mask = (
backbone_affine_mask[..., None] * backbone_affine_mask[..., None, :] backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
) )
loss = torch.sum(errors * square_mask, dim=-1) loss = torch.sum(errors * square_mask, dim=-1)
...@@ -1503,11 +1515,12 @@ class AlphaFoldLoss(nn.Module): ...@@ -1503,11 +1515,12 @@ class AlphaFoldLoss(nn.Module):
), ),
} }
cum_loss = 0 cum_loss = 0.
for loss_name, loss_fn in loss_fns.items(): for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight weight = self.config[loss_name].weight
if weight: if weight:
loss = loss_fn() loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)): if(torch.isnan(loss) or torch.isinf(loss)):
logging.warning(f"{loss_name} loss is NaN. Skipping...") logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True) loss = loss.new_tensor(0., requires_grad=True)
......
...@@ -106,53 +106,791 @@ def rot_vec_mul( ...@@ -106,53 +106,791 @@ def rot_vec_mul(
dim=-1, dim=-1,
) )
def identity_rot_mats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Returns a [*batch_dims, 3, 3] tensor of identity rotation matrices.

    The batch dimensions are produced with expand(), so they share the
    storage of a single 3x3 identity.
    """
    eye = torch.eye(
        3, dtype=dtype, device=device, requires_grad=requires_grad
    )
    lead = (1,) * len(batch_dims)
    return eye.view(*lead, 3, 3).expand(*batch_dims, 3, 3)
def identity_trans(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Returns a [*batch_dims, 3] all-zero (identity) translation tensor."""
    return torch.zeros(
        (*batch_dims, 3),
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )
def identity_quats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Returns a [*batch_dims, 4] tensor of identity quaternions (1, 0, 0, 0).

    The scalar component is written in-place under no_grad so the result
    can still be a leaf tensor that requires grad.
    """
    with torch.no_grad():
        quats = torch.zeros((*batch_dims, 4), dtype=dtype, device=device)
        quats[..., 0] = 1

    quats.requires_grad_(requires_grad)

    return quats
# Keys of the 4x4 table of quaternion-component products: "aa", "ab", ..., "dd".
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}


def _to_mat(pairs):
    """Build a 4x4 coefficient matrix from (pair-key, coefficient) entries."""
    mat = np.zeros((4, 4))
    for key, coeff in pairs:
        row, col = divmod(_qtr_ind_dict[key], 4)
        mat[row][col] = coeff
    return mat


# _QTR_MAT[i, j, r, c] is the coefficient of quat[i] * quat[j] in entry
# (r, c) of the corresponding rotation matrix.
_QTR_MAT = np.zeros((4, 4, 3, 3))
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])


def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
    Converts a quaternion to a rotation matrix.

    Args:
        quat: [*, 4] quaternions
    Returns:
        [*, 3, 3] rotation matrices
    """
    # [*, 4, 4] pairwise products of the quaternion components
    outer = quat[..., None] * quat[..., None, :]

    # [4, 4, 3, 3] coefficient table, broadcast over the batch dims
    coeffs = outer.new_tensor(_QTR_MAT, requires_grad=False)
    coeffs = coeffs.view((1,) * (outer.ndim - 2) + coeffs.shape)

    # [*, 4, 4, 3, 3] summed over the two product dims -> [*, 3, 3]
    return torch.sum(outer[..., None, None] * coeffs, dim=(-3, -4))
def rot_to_quat(
    rot: torch.Tensor,
):
    """
    Converts [*, 3, 3] rotation matrices to [*, 4] quaternions.

    Uses the eigen-decomposition of the symmetric 4x4 matrix K whose
    dominant eigenvector is the quaternion; the sign of the result is
    therefore arbitrary.
    """
    if rot.shape[-2:] != (3, 3):
        raise ValueError("Input rotation is incorrectly shaped")

    xx, xy, xz = rot[..., 0, 0], rot[..., 0, 1], rot[..., 0, 2]
    yx, yy, yz = rot[..., 1, 0], rot[..., 1, 1], rot[..., 1, 2]
    zx, zy, zz = rot[..., 2, 0], rot[..., 2, 1], rot[..., 2, 2]

    rows = [
        [xx + yy + zz, zy - yz, xz - zx, yx - xy],
        [zy - yz, xx - yy - zz, xy + yx, xz + zx],
        [xz - zx, xy + yx, yy - xx - zz, yz + zy],
        [yx - xy, xz + zx, yz + zy, zz - xx - yy],
    ]
    k = torch.stack(
        [torch.stack(row, dim=-1) for row in rows], dim=-2
    ) * (1. / 3.)

    # eigh returns eigenvalues in ascending order; the eigenvector of the
    # largest eigenvalue is the desired quaternion
    _, vectors = torch.linalg.eigh(k)
    return vectors[..., -1]
# Hamilton-product coefficient table: component k of quat1 * quat2 is
# sum_{i,j} _QUAT_MULTIPLY[i, j, k] * quat1[i] * quat2[j].
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
                           [ 0,-1, 0, 0],
                           [ 0, 0,-1, 0],
                           [ 0, 0, 0,-1]]

_QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
                           [ 1, 0, 0, 0],
                           [ 0, 0, 0, 1],
                           [ 0, 0,-1, 0]]

_QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
                           [ 0, 0, 0,-1],
                           [ 1, 0, 0, 0],
                           [ 0, 1, 0, 0]]

_QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
                           [ 0, 0, 1, 0],
                           [ 0,-1, 0, 0],
                           [ 1, 0, 0, 0]]

# Same table restricted to pure-vector (zero-scalar) right operands
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]


def quat_multiply(quat1, quat2):
    """Multiply a quaternion by another quaternion."""
    table = quat1.new_tensor(_QUAT_MULTIPLY)
    table = table.view((1,) * len(quat1.shape[:-1]) + table.shape)
    prod = table * quat1[..., :, None, None] * quat2[..., None, :, None]
    return prod.sum(dim=(-3, -2))


def quat_multiply_by_vec(quat, vec):
    """Multiply a quaternion by a pure-vector quaternion (0, x, y, z)."""
    table = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
    table = table.view((1,) * len(quat.shape[:-1]) + table.shape)
    prod = table * quat[..., :, None, None] * vec[..., None, :, None]
    return prod.sum(dim=(-3, -2))
def invert_rot_mat(rot_mat: torch.Tensor):
    """Invert a [*, 3, 3] rotation matrix by transposing its last two dims."""
    return torch.transpose(rot_mat, -1, -2)
def invert_quat(quat: torch.Tensor):
    """
    Invert a [*, 4] quaternion: the conjugate divided by the squared norm.

    For unit quaternions this reduces to the conjugate.
    """
    # (The original def line was corrupted by diff residue; restored here.)
    quat_prime = quat.clone()
    quat_prime[..., 1:] *= -1
    inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True)
    return inv
class Rotation:
    """
    A 3D rotation. Depending on how the object is initialized, the
    rotation is represented by either a rotation matrix or a quaternion,
    though both formats are made available by helper functions. To
    simplify gradient computation, the underlying format of the rotation
    cannot be changed in-place. Like Rigid, the class is designed to
    mimic the behavior of a torch Tensor, almost as if each Rotation
    object were a tensor of rotations, in one format or another.
    """
    def __init__(self,
        rot_mats: Optional[torch.Tensor] = None,
        quats: Optional[torch.Tensor] = None,
        normalize_quats: bool = True,
    ):
        """
        Args:
            rot_mats:
                A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
                quats
            quats:
                A [*, 4] quaternion. Mutually exclusive with rot_mats. If
                normalize_quats is not True, must be a unit quaternion
            normalize_quats:
                If quats is specified, whether to normalize quats
        """
        if((rot_mats is None and quats is None) or
            (rot_mats is not None and quats is not None)):
            raise ValueError("Exactly one input argument must be specified")

        if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
            (quats is not None and quats.shape[-1] != 4)):
            raise ValueError(
                "Incorrectly shaped rotation matrix or quaternion"
            )

        # Normalization is opt-out: callers that already hold unit
        # quaternions can skip it by passing normalize_quats=False
        if(quats is not None and normalize_quats):
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)

        # Exactly one of these two is non-None at any time
        self._rot_mats = rot_mats
        self._quats = quats

    @staticmethod
    def identity(
        shape,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ) -> Rotation:
        """
        Returns an identity Rotation.

        Args:
            shape:
                The "shape" of the resulting Rotation object. See
                documentation for the shape property
            dtype:
                The torch dtype for the rotation
            device:
                The torch device for the new rotation
            requires_grad:
                Whether the underlying tensors in the new rotation object
                should require gradient computation
            fmt:
                One of "quat" or "rot_mat". Determines the underlying format
                of the new object's rotation
        Returns:
            A new identity rotation
        """
        if(fmt == "rot_mat"):
            rot_mats = identity_rot_mats(
                shape, dtype, device, requires_grad,
            )
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(fmt == "quat"):
            quats = identity_quats(shape, dtype, device, requires_grad)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            # NOTE(review): the stray "f" before {fmt} makes this message
            # read "Invalid format: f<fmt>" — probably a typo in the f-string
            raise ValueError(f"Invalid format: f{fmt}")

    # Magic methods

    def __getitem__(self, index: Any) -> Rotation:
        """
        Allows torch-style indexing over the virtual shape of the rotation
        object. See documentation for the shape property.

        Args:
            index:
                A torch index. E.g. (1, 3, 2), or (slice(None,))
        Returns:
            The indexed rotation
        """
        if type(index) != tuple:
            index = (index,)

        # Indexing applies only to the batch dims; the trailing rotation
        # dims (3, 3) or (4,) are preserved explicitly
        if(self._rot_mats is not None):
            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
            return Rotation(rot_mats=rot_mats)
        elif(self._quats is not None):
            quats = self._quats[index + (slice(None),)]
            return Rotation(quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __mul__(self,
        right: torch.Tensor,
    ) -> Rotation:
        """
        Pointwise left multiplication of the rotation with a tensor. Can be
        used to e.g. mask the Rotation.

        Args:
            right:
                The tensor multiplicand
        Returns:
            The product
        """
        if not(isinstance(right, torch.Tensor)):
            raise TypeError("The other multiplicand must be a Tensor")

        if(self._rot_mats is not None):
            rot_mats = self._rot_mats * right[..., None, None]
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(self._quats is not None):
            quats = self._quats * right[..., None]
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __rmul__(self,
        left: torch.Tensor,
    ) -> Rotation:
        """
        Reverse pointwise multiplication of the rotation with a tensor.

        Args:
            left:
                The left multiplicand
        Returns:
            The product
        """
        return self.__mul__(left)

    # Properties

    @property
    def shape(self) -> torch.Size:
        """
        Returns the virtual shape of the rotation object. This shape is
        defined as the batch dimensions of the underlying rotation matrix
        or quaternion. If the Rotation was initialized with a [10, 3, 3]
        rotation matrix tensor, for example, the resulting shape would be
        [10].

        Returns:
            The virtual shape of the rotation object
        """
        s = None
        if(self._quats is not None):
            s = self._quats.shape[:-1]
        else:
            s = self._rot_mats.shape[:-2]

        return s

    @property
    def dtype(self) -> torch.dtype:
        """
        Returns the dtype of the underlying rotation.

        Returns:
            The dtype of the underlying rotation
        """
        if(self._rot_mats is not None):
            return self._rot_mats.dtype
        elif(self._quats is not None):
            return self._quats.dtype
        else:
            raise ValueError("Both rotations are None")

    @property
    def device(self) -> torch.device:
        """
        The device of the underlying rotation

        Returns:
            The device of the underlying rotation
        """
        if(self._rot_mats is not None):
            return self._rot_mats.device
        elif(self._quats is not None):
            return self._quats.device
        else:
            raise ValueError("Both rotations are None")

    @property
    def requires_grad(self) -> bool:
        """
        Returns the requires_grad property of the underlying rotation

        Returns:
            The requires_grad property of the underlying tensor
        """
        if(self._rot_mats is not None):
            return self._rot_mats.requires_grad
        elif(self._quats is not None):
            return self._quats.requires_grad
        else:
            raise ValueError("Both rotations are None")

    def get_rot_mats(self) -> torch.Tensor:
        """
        Returns the underlying rotation as a rotation matrix tensor.

        Returns:
            The rotation as a rotation matrix tensor
        """
        rot_mats = self._rot_mats
        if(rot_mats is None):
            if(self._quats is None):
                raise ValueError("Both rotations are None")
            else:
                # Converted on demand; the stored format is unchanged
                rot_mats = quat_to_rot(self._quats)

        return rot_mats

    def get_quats(self) -> torch.Tensor:
        """
        Returns the underlying rotation as a quaternion tensor.

        Depending on whether the Rotation was initialized with a
        quaternion, this function may call torch.linalg.eigh.

        Returns:
            The rotation as a quaternion tensor.
        """
        quats = self._quats
        if(quats is None):
            if(self._rot_mats is None):
                raise ValueError("Both rotations are None")
            else:
                # rot_to_quat is eigh-based and comparatively expensive
                quats = rot_to_quat(self._rot_mats)

        return quats

    def get_cur_rot(self) -> torch.Tensor:
        """
        Return the underlying rotation in its current form

        Returns:
            The stored rotation
        """
        if(self._rot_mats is not None):
            return self._rot_mats
        elif(self._quats is not None):
            return self._quats
        else:
            raise ValueError("Both rotations are None")

    # Rotation functions

    def compose_q_update_vec(self,
        q_update_vec: torch.Tensor,
        normalize_quats: bool = True
    ) -> Rotation:
        """
        Returns a new quaternion Rotation after updating the current
        object's underlying rotation with a quaternion update, formatted
        as a [*, 3] tensor whose final three columns represent x, y, z such
        that (1, x, y, z) is the desired (not necessarily unit) quaternion
        update.

        Args:
            q_update_vec:
                A [*, 3] quaternion update tensor
            normalize_quats:
                Whether to normalize the output quaternion
        Returns:
            An updated Rotation
        """
        quats = self.get_quats()
        # q + q * (0, v) implements the (1, x, y, z) update product
        new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
        return Rotation(
            rot_mats=None,
            quats=new_quats,
            normalize_quats=normalize_quats,
        )

    def compose_r(self, r: Rotation) -> Rotation:
        """
        Compose the rotation matrices of the current Rotation object with
        those of another.

        Args:
            r:
                An update rotation object
        Returns:
            An updated rotation object
        """
        r1 = self.get_rot_mats()
        r2 = r.get_rot_mats()
        new_rot_mats = rot_matmul(r1, r2)
        return Rotation(rot_mats=new_rot_mats, quats=None)

    def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
        """
        Compose the quaternions of the current Rotation object with those
        of another.

        Depending on whether either Rotation was initialized with
        quaternions, this function may call torch.linalg.eigh.

        Args:
            r:
                An update rotation object
        Returns:
            An updated rotation object
        """
        q1 = self.get_quats()
        q2 = r.get_quats()
        new_quats = quat_multiply(q1, q2)
        return Rotation(
            rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
        )

    def apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Apply the current Rotation as a rotation matrix to a set of 3D
        coordinates.

        Args:
            pts:
                A [*, 3] set of points
        Returns:
            [*, 3] rotated points
        """
        rot_mats = self.get_rot_mats()
        return rot_vec_mul(rot_mats, pts)

    def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
        The inverse of the apply() method.

        Args:
            pts:
                A [*, 3] set of points
        Returns:
            [*, 3] inverse-rotated points
        """
        rot_mats = self.get_rot_mats()
        inv_rot_mats = invert_rot_mat(rot_mats)
        return rot_vec_mul(inv_rot_mats, pts)

    def invert(self) -> Rotation:
        """
        Returns the inverse of the current Rotation.

        Returns:
            The inverse of the current Rotation
        """
        # Inversion is done in the stored format: transpose for matrices,
        # conjugate-over-norm for quaternions (already unit, so no renorm)
        if(self._rot_mats is not None):
            return Rotation(
                rot_mats=invert_rot_mat(self._rot_mats),
                quats=None
            )
        elif(self._quats is not None):
            return Rotation(
                rot_mats=None,
                quats=invert_quat(self._quats),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")

    # "Tensor" stuff

    def unsqueeze(self,
        dim: int,
    ) -> Rotation:
        """
        Analogous to torch.unsqueeze. The dimension is relative to the
        shape of the Rotation object.

        Args:
            dim: A positive or negative dimension index.
        Returns:
            The unsqueezed Rotation.
        """
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")

        # Negative dims are shifted past the trailing rotation dims
        if(self._rot_mats is not None):
            rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(self._quats is not None):
            quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    @staticmethod
    def cat(
        rs: Sequence[Rotation],
        dim: int,
    ) -> Rotation:
        """
        Concatenates rotations along one of the batch dimensions. Analogous
        to torch.cat().

        Note that the output of this operation is always a rotation matrix,
        regardless of the format of input rotations.

        Args:
            rs:
                A list of rotation objects
            dim:
                The dimension along which the rotations should be
                concatenated
        Returns:
            A concatenated Rotation object in rotation matrix format
        """
        rot_mats = [r.get_rot_mats() for r in rs]
        rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)

        return Rotation(rot_mats=rot_mats, quats=None)

    def map_tensor_fn(self,
        fn: Callable[[torch.Tensor], torch.Tensor]
    ) -> Rotation:
        """
        Apply a Tensor -> Tensor function to underlying rotation tensors,
        mapping over the rotation dimension(s). Can be used e.g. to sum out
        a one-hot batch dimension.

        Args:
            fn:
                A Tensor -> Tensor function to be mapped over the Rotation
        Returns:
            The transformed Rotation object
        """
        # The rotation dims are flattened so that fn only ever sees the
        # batch dims, then the (3, 3) shape is restored afterwards
        if(self._rot_mats is not None):
            rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
            rot_mats = torch.stack(
                list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
            )
            rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(self._quats is not None):
            quats = torch.stack(
                list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
            )
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def cuda(self) -> Rotation:
        """
        Analogous to the cuda() method of torch Tensors

        Returns:
            A copy of the Rotation in CUDA memory
        """
        if(self._rot_mats is not None):
            return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
        elif(self._quats is not None):
            return Rotation(
                rot_mats=None,
                quats=self._quats.cuda(),
                normalize_quats=False
            )
        else:
            raise ValueError("Both rotations are None")

    def to(self,
        device: Optional[torch.device],
        dtype: Optional[torch.dtype]
    ) -> Rotation:
        """
        Analogous to the to() method of torch Tensors

        Args:
            device:
                A torch device
            dtype:
                A torch dtype
        Returns:
            A copy of the Rotation using the new device and dtype
        """
        if(self._rot_mats is not None):
            return Rotation(
                rot_mats=self._rot_mats.to(device=device, dtype=dtype),
                quats=None,
            )
        elif(self._quats is not None):
            return Rotation(
                rot_mats=None,
                quats=self._quats.to(device=device, dtype=dtype),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")

    def detach(self) -> Rotation:
        """
        Returns a copy of the Rotation whose underlying Tensor has been
        detached from its torch graph.

        Returns:
            A copy of the Rotation whose underlying Tensor has been detached
            from its torch graph
        """
        if(self._rot_mats is not None):
            return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
        elif(self._quats is not None):
            return Rotation(
                rot_mats=None,
                quats=self._quats.detach(),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")
class Rigid:
"""
A class representing a rigid transformation. Little more than a wrapper
around two objects: a Rotation object and a [*, 3] translation
Designed to behave approximately like a single torch tensor with the
shape of the shared batch dimensions of its component parts.
""" """
def __init__(self, def __init__(self,
rots: torch.Tensor, rots: Optional[Rotation],
trans: torch.Tensor trans: Optional[torch.Tensor],
): ):
""" """
Args: Args:
rots: A [*, 3, 3] rotation tensor rots: A [*, 3, 3] rotation tensor
trans: A corresponding [*, 3] translation tensor trans: A corresponding [*, 3] translation tensor
""" """
self.rots = rots # (we need device, dtype, etc. from at least one input)
self.trans = trans
batch_dims, dtype, device, requires_grad = None, None, None, None
if self.rots is None and self.trans is None: if(trans is not None):
raise ValueError("Only one of rots and trans can be None") batch_dims = trans.shape[:-1]
elif self.rots is None: dtype = trans.dtype
self.rots = T._identity_rot( device = trans.device
self.trans.shape[:-1], requires_grad = trans.requires_grad
self.trans.dtype, elif(rots is not None):
self.trans.device, batch_dims = rots.shape
self.trans.requires_grad, dtype = rots.dtype
device = rots.device
requires_grad = rots.requires_grad
else:
raise ValueError("At least one input argument must be specified")
if(rots is None):
rots = Rotation.identity(
batch_dims, dtype, device, requires_grad,
) )
elif self.trans is None: elif(trans is None):
self.trans = T._identity_trans( trans = identity_trans(
self.rots.shape[:-2], batch_dims, dtype, device, requires_grad,
self.rots.dtype,
self.rots.device,
self.rots.requires_grad,
) )
if ( if((rots.shape != trans.shape[:-1]) or
self.rots.shape[-2:] != (3, 3) (rots.device != trans.device)):
or self.trans.shape[-1] != 3 raise ValueError("Rots and trans incompatible")
or self.rots.shape[:-2] != self.trans.shape[:-1]
): self._rots = rots
raise ValueError("Incorrectly shaped input") self._trans = trans
@staticmethod
def identity(
shape: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
fmt: str = "quat",
) -> Rigid:
"""
Constructs an identity transformation.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns:
The identity transformation
"""
return Rigid(
Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
identity_trans(shape, dtype, device, requires_grad),
)
def __getitem__(self, def __getitem__(self,
index: Any, index: Any,
) -> T: ) -> Rigid:
""" """
Indexes the affine transformation with PyTorch-style indices. Indexes the affine transformation with PyTorch-style indices.
The index is applied to the shared dimensions of both the rotation The index is applied to the shared dimensions of both the rotation
...@@ -160,11 +898,12 @@ class T: ...@@ -160,11 +898,12 @@ class T:
E.g.:: E.g.::
t = T(torch.rand(10, 10, 3, 3), torch.rand(10, 10, 3)) r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
t = Rigid(r, torch.rand(10, 10, 3))
indexed = t[3, 4:6] indexed = t[3, 4:6]
assert(indexed.shape == (2,)) assert(indexed.shape == (2,))
assert(indexed.rots.shape == (2, 3, 3)) assert(indexed.get_rots().shape == (2,))
assert(indexed.trans.shape == (2, 3)) assert(indexed.get_trans().shape == (2, 3))
Args: Args:
index: A standard torch tensor index. E.g. 8, (10, None, 3), index: A standard torch tensor index. E.g. 8, (10, None, 3),
...@@ -174,54 +913,45 @@ class T: ...@@ -174,54 +913,45 @@ class T:
""" """
if type(index) != tuple: if type(index) != tuple:
index = (index,) index = (index,)
return T(
self.rots[index + (slice(None), slice(None))], return Rigid(
self.trans[index + (slice(None),)], self._rots[index],
) self._trans[index + (slice(None),)],
def __eq__(self,
obj: T,
) -> bool:
"""
Compares two affine transformations. Returns true iff the
transformations are pointwise identical. Does not account for
floating point imprecision.
"""
return bool(
torch.all(self.rots == obj.rots) and
torch.all(self.trans == obj.trans)
) )
def __mul__(self, def __mul__(self,
right: torch.Tensor, right: torch.Tensor,
) -> T: ) -> Rigid:
""" """
Pointwise right multiplication of the affine transformation with a Pointwise left multiplication of the transformation with a tensor.
tensor. Multiplication is broadcast over the rotation/translation Can be used to e.g. mask the Rigid.
dimensions.
Args: Args:
right: The right multiplicand right:
The tensor multiplicand
Returns: Returns:
The product transformation The product
""" """
rots = self.rots * right[..., None, None] if not(isinstance(right, torch.Tensor)):
trans = self.trans * right[..., None] raise TypeError("The other multiplicand must be a Tensor")
return T(rots, trans) new_rots = self._rots * right
new_trans = self._trans * right[..., None]
def __rmul__(self, return Rigid(new_rots, new_trans)
def __rmul__(self,
left: torch.Tensor, left: torch.Tensor,
) -> T: ) -> Rigid:
""" """
Pointwise left multiplication of the affine transformation with a Reverse pointwise multiplication of the transformation with a
tensor. Multiplication is broadcast over the rotation/translation tensor.
dimensions.
Args: Args:
left: The left multiplicand left:
The left multiplicand
Returns: Returns:
The product transformation The product
""" """
return self.__mul__(left) return self.__mul__(left)
...@@ -234,45 +964,74 @@ class T: ...@@ -234,45 +964,74 @@ class T:
Returns: Returns:
The shape of the transformation The shape of the transformation
""" """
s = self.rots.shape[:-2] s = self._trans.shape[:-1]
return s if len(s) > 0 else torch.Size([1]) return s
@property
def device(self) -> torch.device:
"""
Returns the device on which the Rigid's tensors are located.
def get_rots(self): Returns:
The device on which the Rigid's tensors are located
"""
return self._trans.device
def get_rots(self) -> Rotation:
""" """
Getter for the rotation. Getter for the rotation.
Returns: Returns:
The stored rotation. The rotation object
""" """
return self.rots return self._rots
def get_trans(self) -> torch.Tensor: def get_trans(self) -> torch.Tensor:
""" """
Getter for the translation. Getter for the translation.
Returns: Returns:
The stored translation. The stored translation
""" """
return self.trans return self._trans
def compose(self, def compose_q_update_vec(self,
t: T, q_update_vec: torch.Tensor,
) -> T: ) -> Rigid:
""" """
Composes the transformation with another. Composes the transformation with a quaternion update vector of
shape [*, 6], where the final 6 columns represent the x, y, and
z values of a quaternion of form (1, x, y, z) followed by a 3D
translation.
Args: Args:
t: The inner transformation. q_vec: The quaternion update vector.
Returns: Returns:
The composed transformation. The composed transformation.
""" """
rot_1, trn_1 = self.rots, self.trans q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
rot_2, trn_2 = t.rots, t.trans new_rots = self._rots.compose_q_update_vec(q_vec)
trans_update = self._rots.apply(t_vec)
new_translation = self._trans + trans_update
rot = rot_matmul(rot_1, rot_2) return Rigid(new_rots, new_translation)
trn = rot_vec_mul(rot_1, trn_2) + trn_1
return T(rot, trn) def compose(self,
r: Rigid,
) -> Rigid:
"""
Composes the current rigid object with another.
Args:
r:
Another Rigid object
Returns:
The composition of the two transformations
"""
new_rot = self._rots.compose_r(r._rots)
new_trans = self._rots.apply(r._trans) + self._trans
return Rigid(new_rot, new_trans)
def apply(self, def apply(self,
pts: torch.Tensor, pts: torch.Tensor,
...@@ -285,9 +1044,8 @@ class T: ...@@ -285,9 +1044,8 @@ class T:
Returns: Returns:
The transformed points. The transformed points.
""" """
r, t = self.rots, self.trans rotated = self._rots.apply(pts)
rotated = rot_vec_mul(r, pts) return rotated + self._trans
return rotated + t
def invert_apply(self, def invert_apply(self,
pts: torch.Tensor pts: torch.Tensor
...@@ -300,99 +1058,60 @@ class T: ...@@ -300,99 +1058,60 @@ class T:
Returns: Returns:
The transformed points. The transformed points.
""" """
r, t = self.rots, self.trans pts = pts - self._trans
pts = pts - t return self._rots.invert_apply(pts)
return rot_vec_mul(r.transpose(-1, -2), pts)
def invert(self) -> T: def invert(self) -> Rigid:
""" """
Inverts the transformation. Inverts the transformation.
Returns: Returns:
The inverse transformation. The inverse transformation.
""" """
rot_inv = self.rots.transpose(-1, -2) rot_inv = self._rots.invert()
trn_inv = rot_vec_mul(rot_inv, self.trans) trn_inv = rot_inv.apply(self._trans)
return T(rot_inv, -1 * trn_inv) return Rigid(rot_inv, -1 * trn_inv)
def unsqueeze(self, def map_tensor_fn(self,
dim: int, fn: Callable[tensor.Tensor, tensor.Tensor]
) -> T: ) -> Rigid:
""" """
Analogous to torch.unsqueeze. The dimension is relative to the Apply a Tensor -> Tensor function to underlying translation and
shared dimensions of the rotation/translation. rotation tensors, mapping over the translation/rotation dimensions
respectively.
Args: Args:
dim: A positive or negative dimension index. fn:
A Tensor -> Tensor function to be mapped over the Rigid
Returns: Returns:
The unsqueezed transformation. The transformed Rigid object
""" """
if dim >= len(self.shape): new_rots = self._rots.map_tensor_fn(fn)
raise ValueError("Invalid dimension") new_trans = torch.stack(
rots = self.rots.unsqueeze(dim if dim >= 0 else dim - 2) list(map(fn, torch.unbind(self._trans, dim=-1))),
trans = self.trans.unsqueeze(dim if dim >= 0 else dim - 1) dim=-1
return T(rots, trans)
@staticmethod
def _identity_rot(
shape: Tuple[int],
dtype: torch.dtype,
device: torch.device,
requires_grad: bool,
) -> torch.Tensor:
rots = torch.eye(
3, dtype=dtype, device=device, requires_grad=requires_grad
) )
rots = rots.view(*((1,) * len(shape)), 3, 3)
rots = rots.expand(*shape, -1, -1)
return rots
@staticmethod return Rigid(new_rots, new_trans)
def _identity_trans(
shape: Tuple[int],
dtype: torch.dtype,
device: torch.device,
requires_grad: bool
) -> torch.Tensor:
trans = torch.zeros(
(*shape, 3), dtype=dtype, device=device, requires_grad=requires_grad
)
return trans
@staticmethod def to_tensor_4x4(self) -> torch.Tensor:
def identity(
shape: Tuple[int],
dtype: torch.dtype,
device: torch.device,
requires_grad: bool = True
) -> T:
""" """
Constructs an identity transformation. Converts a transformation to a homogenous transformation tensor.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns: Returns:
The identity transformation A [*, 4, 4] homogenous transformation tensor
""" """
return T( tensor = self._trans.new_zeros((*self.shape, 4, 4))
T._identity_rot(shape, dtype, device, requires_grad), tensor[..., :3, :3] = self._rots.get_rot_mats()
T._identity_trans(shape, dtype, device, requires_grad), tensor[..., :3, 3] = self._trans
) tensor[..., 3, 3] = 1
return tensor
@staticmethod @staticmethod
def from_4x4( def from_tensor_4x4(
t: torch.Tensor t: torch.Tensor
) -> T: ) -> Rigid:
""" """
Constructs a transformation from a homogenous transformation Constructs a transformation from a homogenous transformation
tensor. tensor.
...@@ -402,35 +1121,45 @@ class T: ...@@ -402,35 +1121,45 @@ class T:
Returns: Returns:
T object with shape [*] T object with shape [*]
""" """
rots = t[..., :3, :3] if(t.shape[-2:] != (4, 4)):
raise ValueError("Incorrectly shaped input tensor")
rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
trans = t[..., :3, 3] trans = t[..., :3, 3]
return T(rots, trans)
return Rigid(rots, trans)
def to_4x4(self) -> torch.Tensor: def to_tensor_7(self) -> torch.Tensor:
""" """
Converts a transformation to a homogenous transformation tensor. Converts a transformation to a tensor with 7 final columns, four
for the quaternion followed by three for the translation.
Returns: Returns:
A [*, 4, 4] homogenous transformation tensor A [*, 7] tensor representation of the transformation
""" """
tensor = self.rots.new_zeros((*self.shape, 4, 4)) tensor = self._trans.new_zeros((*self.shape, 7))
tensor[..., :3, :3] = self.rots tensor[..., :4] = self._rots.get_quats()
tensor[..., :3, 3] = self.trans tensor[..., 4:] = self._trans
tensor[..., 3, 3] = 1
return tensor return tensor
@staticmethod @staticmethod
def from_tensor(t: torch.Tensor) -> T: def from_tensor_7(
""" t: torch.Tensor,
Constructs a transformation from a homogenous transformation normalize_quats: bool = False,
tensor. ) -> Rigid:
if(t.shape[-1] != 7):
raise ValueError("Incorrectly shaped input tensor")
quats, trans = t[..., :4], t[..., 4:]
rots = Rotation(
rot_mats=None,
quats=quats,
normalize_quats=normalize_quats
)
Args: return Rigid(rots, trans)
t: A [*, 4, 4] homogenous transformation tensor
Returns:
A transformation object with shape [*]
"""
return T.from_4x4(t)
@staticmethod @staticmethod
def from_3_points( def from_3_points(
...@@ -438,7 +1167,7 @@ class T: ...@@ -438,7 +1167,7 @@ class T:
origin: torch.Tensor, origin: torch.Tensor,
p_xy_plane: torch.Tensor, p_xy_plane: torch.Tensor,
eps: float = 1e-8 eps: float = 1e-8
) -> T: ) -> Rigid:
""" """
Implements algorithm 21. Constructs transformations from sets of 3 Implements algorithm 21. Constructs transformations from sets of 3
points using the Gram-Schmidt algorithm. points using the Gram-Schmidt algorithm.
...@@ -473,13 +1202,34 @@ class T: ...@@ -473,13 +1202,34 @@ class T:
rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
rots = rots.reshape(rots.shape[:-1] + (3, 3)) rots = rots.reshape(rots.shape[:-1] + (3, 3))
return T(rots, torch.stack(origin, dim=-1)) rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, torch.stack(origin, dim=-1))
def unsqueeze(self,
dim: int,
) -> Rigid:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shared dimensions of the rotation/translation.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed transformation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
rots = self._rots.unsqueeze(dim)
trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
return Rigid(rots, trans)
@staticmethod @staticmethod
def concat( def cat(
ts: Sequence[T], ts: Sequence[Rigid],
dim: int, dim: int,
) -> T: ) -> Rigid:
""" """
Concatenates transformations along a new dimension. Concatenates transformations along a new dimension.
...@@ -492,57 +1242,60 @@ class T: ...@@ -492,57 +1242,60 @@ class T:
Returns: Returns:
A concatenated transformation object A concatenated transformation object
""" """
rots = torch.cat([t.rots for t in ts], dim=dim if dim >= 0 else dim - 2) rots = Rotation.cat([t._rots for t in ts], dim)
trans = torch.cat( trans = torch.cat(
[t.trans for t in ts], dim=dim if dim >= 0 else dim - 1 [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
) )
return T(rots, trans) return Rigid(rots, trans)
def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> T: def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid:
""" """
Apply a function that takes a tensor as its only argument to the Applies a Rotation -> Rotation function to the stored rotation
rotations and translations, treating the final two/one object.
dimension(s), respectively, as batch dimensions.
E.g.: Given t, an instance of T of shape [N, M], this function can
be used to sum out the second dimension thereof as follows::
t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
The resulting object has rotations of shape [N, 3, 3] and
translations of shape [N, 3]
Args: Args:
fn: A function that takes only a tensor as its argument fn: A function of type Rotation -> Rotation
Returns: Returns:
The transformed transformation object. A transformation object with a transformed rotation.
""" """
rots = self.rots.view(*self.rots.shape[:-2], 9) return Rigid(fn(self._rots), self._trans)
rots = torch.stack(list(map(fn, torch.unbind(rots, -1))), dim=-1)
rots = rots.view(*rots.shape[:-1], 3, 3)
trans = torch.stack(list(map(fn, torch.unbind(self.trans, -1))), dim=-1) def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid:
"""
Applies a Tensor -> Tensor function to the stored translation.
return T(rots, trans) Args:
fn:
A function of type Tensor -> Tensor to be applied to the
translation
Returns:
A transformation object with a transformed translation.
"""
return Rigid(self._rots, fn(self._trans))
def stop_rot_gradient(self) -> T: def scale_translation(self, trans_scale_factor: float) -> Rigid:
""" """
Detaches the contained rotation tensor. Scales the translation by a constant factor.
Args:
trans_scale_factor:
The constant factor
Returns: Returns:
A version of the transformation with detached rotations A transformation object with a scaled translation.
""" """
return T(self.rots.detach(), self.trans) fn = lambda t: t * trans_scale_factor
return self.apply_trans_fn(fn)
def scale_translation(self, factor: int) -> T: def stop_rot_gradient(self) -> Rigid:
""" """
Scales the contained translation tensor by a constant factor. Detaches the underlying rotation object
Returns: Returns:
A version of the transformation with scaled translations A transformation object with detached rotations
""" """
return T(self.rots, self.trans * factor) fn = lambda r: r.detach()
return self.apply_rot_fn(fn)
@staticmethod @staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
...@@ -613,87 +1366,15 @@ class T: ...@@ -613,87 +1366,15 @@ class T:
rots = rots.transpose(-1, -2) rots = rots.transpose(-1, -2)
translation = -1 * translation translation = -1 * translation
return T(rots, translation) rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, translation)
def cuda(self) -> T: def cuda(self) -> Rigid:
""" """
Moves the transformation object to GPU memory Moves the transformation object to GPU memory
Returns: Returns:
A version of the transformation on GPU A version of the transformation on GPU
""" """
return T(self.rots.cuda(), self.trans.cuda()) return Rigid(self._rots.cuda(), self._trans.cuda())
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
def _to_mat(pairs):
mat = np.zeros((4, 4))
for pair in pairs:
key, value = pair
ind = _qtr_ind_dict[key]
mat[ind // 4][ind % 4] = value
return mat
_qtr_mat = np.zeros((4, 4, 3, 3))
_qtr_mat[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_qtr_mat[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_qtr_mat[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_qtr_mat[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_qtr_mat[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_qtr_mat[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_qtr_mat[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_qtr_mat[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_qtr_mat[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
"""
Converts a quaternion to a rotation matrix.
Args:
quat: [*, 4] quaternions
Returns:
[*, 3, 3] rotation matrices
"""
# [*, 4, 4]
quat = quat[..., None] * quat[..., None, :]
# [4, 4, 3, 3]
mat = quat.new_tensor(_qtr_mat)
# [*, 4, 4, 3, 3]
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
quat = quat[..., None, None] * shaped_qtr_mat
# [*, 3, 3]
return torch.sum(quat, dim=(-3, -4))
def affine_vector_to_4x4(vector: torch.Tensor) -> torch.Tensor:
"""
Transforms a tensor whose final dimension has the form:
[*quaternion, *translation]
into a homogenous transformation tensor.
Args:
vector: [*, 7] input tensor
Returns:
[*, 4, 4] homogenous transformation tensor
"""
quats = vector[..., :4]
trans = vector[..., 4:]
four_by_four = vector.new_zeros((*vector.shape[:-1], 4, 4))
four_by_four[..., :3, :3] = quat_to_rot(quats)
four_by_four[..., :3, 3] = trans
four_by_four[..., 3, 3] = 1
return four_by_four
../../../../openfold/resources/stereo_chemical_props.txt
\ No newline at end of file
...@@ -23,8 +23,8 @@ from openfold.np.residue_constants import ( ...@@ -23,8 +23,8 @@ from openfold.np.residue_constants import (
restype_atom14_mask, restype_atom14_mask,
restype_atom14_rigid_group_positions, restype_atom14_rigid_group_positions,
) )
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -187,7 +187,7 @@ class TestFeats(unittest.TestCase): ...@@ -187,7 +187,7 @@ class TestFeats(unittest.TestCase):
n = 5 n = 5
rots = torch.rand((batch_size, n, 3, 3)) rots = torch.rand((batch_size, n, 3, 3))
trans = torch.rand((batch_size, n, 3)) trans = torch.rand((batch_size, n, 3))
ts = T(rots, trans) ts = Rigid(Rotation(rot_mats=rots), trans)
angles = torch.rand((batch_size, n, 7, 2)) angles = torch.rand((batch_size, n, 7, 2))
...@@ -222,7 +222,9 @@ class TestFeats(unittest.TestCase): ...@@ -222,7 +222,9 @@ class TestFeats(unittest.TestCase):
affines = random_affines_4x4((n_res,)) affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines) rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float()) transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
torsion_angles_sin_cos = np.random.rand(n_res, 7, 2) torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)
...@@ -250,7 +252,7 @@ class TestFeats(unittest.TestCase): ...@@ -250,7 +252,7 @@ class TestFeats(unittest.TestCase):
bottom_row[..., 3] = 1 bottom_row[..., 3] = 1
transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2) transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)
transforms_repro = out.to_4x4().cpu() transforms_repro = out.to_tensor_4x4().cpu()
self.assertTrue( self.assertTrue(
torch.max(torch.abs(transforms_gt - transforms_repro) < consts.eps) torch.max(torch.abs(transforms_gt - transforms_repro) < consts.eps)
...@@ -262,7 +264,7 @@ class TestFeats(unittest.TestCase): ...@@ -262,7 +264,7 @@ class TestFeats(unittest.TestCase):
rots = torch.rand((batch_size, n_res, 8, 3, 3)) rots = torch.rand((batch_size, n_res, 8, 3, 3))
trans = torch.rand((batch_size, n_res, 8, 3)) trans = torch.rand((batch_size, n_res, 8, 3))
ts = T(rots, trans) ts = Rigid(Rotation(rot_mats=rots), trans)
f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long() f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()
...@@ -293,7 +295,9 @@ class TestFeats(unittest.TestCase): ...@@ -293,7 +295,9 @@ class TestFeats(unittest.TestCase):
affines = random_affines_4x4((n_res, 8)) affines = random_affines_4x4((n_res, 8))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines) rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float()) transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
out_gt = f.apply({}, None, aatype, rigids) out_gt = f.apply({}, None, aatype, rigids)
jax.tree_map(lambda x: x.block_until_ready(), out_gt) jax.tree_map(lambda x: x.block_until_ready(), out_gt)
......
...@@ -20,7 +20,10 @@ import unittest ...@@ -20,7 +20,10 @@ import unittest
import ml_collections as mlc import ml_collections as mlc
from openfold.data import data_transforms from openfold.data import data_transforms
from openfold.utils.affine_utils import T, affine_vector_to_4x4 from openfold.utils.rigid_utils import (
Rotation,
Rigid,
)
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.loss import ( from openfold.utils.loss import (
torsion_angle_loss, torsion_angle_loss,
...@@ -55,6 +58,11 @@ if compare_utils.alphafold_is_installed(): ...@@ -55,6 +58,11 @@ if compare_utils.alphafold_is_installed():
import haiku as hk import haiku as hk
def affine_vector_to_4x4(affine):
r = Rigid.from_tensor_7(affine)
return r.to_tensor_4x4()
class TestLoss(unittest.TestCase): class TestLoss(unittest.TestCase):
def test_run_torsion_angle_loss(self): def test_run_torsion_angle_loss(self):
batch_size = consts.batch_size batch_size = consts.batch_size
...@@ -77,8 +85,8 @@ class TestLoss(unittest.TestCase): ...@@ -77,8 +85,8 @@ class TestLoss(unittest.TestCase):
rots_gt = torch.rand((batch_size, n_frames, 3, 3)) rots_gt = torch.rand((batch_size, n_frames, 3, 3))
trans = torch.rand((batch_size, n_frames, 3)) trans = torch.rand((batch_size, n_frames, 3))
trans_gt = torch.rand((batch_size, n_frames, 3)) trans_gt = torch.rand((batch_size, n_frames, 3))
t = T(rots, trans) t = Rigid(Rotation(rot_mats=rots), trans)
t_gt = T(rots_gt, trans_gt) t_gt = Rigid(Rotation(rot_mats=rots_gt), trans_gt)
frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float() frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float()
positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float() positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float()
length_scale = 10 length_scale = 10
...@@ -686,11 +694,11 @@ class TestLoss(unittest.TestCase): ...@@ -686,11 +694,11 @@ class TestLoss(unittest.TestCase):
batch = tree_map(to_tensor, batch, np.ndarray) batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray) value = tree_map(to_tensor, value, np.ndarray)
batch["backbone_affine_tensor"] = affine_vector_to_4x4( batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
batch["backbone_affine_tensor"] batch["backbone_affine_tensor"]
) )
value["traj"] = affine_vector_to_4x4(value["traj"]) batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm}) out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
out_repro = out_repro.cpu() out_repro = out_repro.cpu()
...@@ -807,6 +815,8 @@ class TestLoss(unittest.TestCase): ...@@ -807,6 +815,8 @@ class TestLoss(unittest.TestCase):
f = hk.transform(run_tm_loss) f = hk.transform(run_tm_loss)
np.random.seed(42)
n_res = consts.n_res n_res = consts.n_res
representations = { representations = {
...@@ -839,12 +849,10 @@ class TestLoss(unittest.TestCase): ...@@ -839,12 +849,10 @@ class TestLoss(unittest.TestCase):
batch = tree_map(to_tensor, batch, np.ndarray) batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray) value = tree_map(to_tensor, value, np.ndarray)
batch["backbone_affine_tensor"] = affine_vector_to_4x4( batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
batch["backbone_affine_tensor"] batch["backbone_affine_tensor"]
) )
value["structure_module"]["final_affines"] = affine_vector_to_4x4( batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
value["structure_module"]["final_affines"]
)
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
logits = model.aux_heads.tm(representations["pair"]) logits = model.aux_heads.tm(representations["pair"])
......
...@@ -130,4 +130,5 @@ class TestModel(unittest.TestCase): ...@@ -130,4 +130,5 @@ class TestModel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1] out_repro = out_repro["sm"]["positions"][-1]
out_repro = out_repro.squeeze(0) out_repro = out_repro.squeeze(0)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 1e-3)) print(torch.max(torch.abs(out_gt - out_repro)))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
...@@ -31,8 +31,8 @@ from openfold.model.structure_module import ( ...@@ -31,8 +31,8 @@ from openfold.model.structure_module import (
AngleResnet, AngleResnet,
InvariantPointAttention, InvariantPointAttention,
) )
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
import tests.compare_utils as compare_utils import tests.compare_utils as compare_utils
from tests.config import consts from tests.config import consts
from tests.data_utils import ( from tests.data_utils import (
...@@ -89,7 +89,7 @@ class TestStructureModule(unittest.TestCase): ...@@ -89,7 +89,7 @@ class TestStructureModule(unittest.TestCase):
out = sm(s, z, f) out = sm(s, z, f)
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4)) self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
self.assertTrue( self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2) out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
) )
...@@ -177,23 +177,6 @@ class TestStructureModule(unittest.TestCase): ...@@ -177,23 +177,6 @@ class TestStructureModule(unittest.TestCase):
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05) self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
class TestBackboneUpdate(unittest.TestCase):
def test_shape(self):
batch_size = 2
n_res = 3
c_in = 5
bu = BackboneUpdate(c_in)
s = torch.rand((batch_size, n_res, c_in))
t = bu(s)
rot, tra = t.rots, t.trans
self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
self.assertTrue(tra.shape == (batch_size, n_res, 3))
class TestInvariantPointAttention(unittest.TestCase): class TestInvariantPointAttention(unittest.TestCase):
def test_shape(self): def test_shape(self):
c_m = 13 c_m = 13
...@@ -210,17 +193,18 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -210,17 +193,18 @@ class TestInvariantPointAttention(unittest.TestCase):
z = torch.rand((batch_size, n_res, n_res, c_z)) z = torch.rand((batch_size, n_res, n_res, c_z))
mask = torch.ones((batch_size, n_res)) mask = torch.ones((batch_size, n_res))
rots = torch.rand((batch_size, n_res, 3, 3)) rot_mats = torch.rand((batch_size, n_res, 3, 3))
rots = Rotation(rot_mats=rot_mats, quats=None)
trans = torch.rand((batch_size, n_res, 3)) trans = torch.rand((batch_size, n_res, 3))
t = T(rots, trans) r = Rigid(rots, trans)
ipa = InvariantPointAttention( ipa = InvariantPointAttention(
c_m, c_z, c_hidden, no_heads, no_qp, no_vp c_m, c_z, c_hidden, no_heads, no_qp, no_vp
) )
shape_before = s.shape shape_before = s.shape
s = ipa(s, z, t, mask) s = ipa(s, z, r, mask)
self.assertTrue(s.shape == shape_before) self.assertTrue(s.shape == shape_before)
...@@ -253,7 +237,9 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -253,7 +237,9 @@ class TestInvariantPointAttention(unittest.TestCase):
affines = random_affines_4x4((n_res,)) affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines) rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids) quats = alphafold.model.r3.rigids_to_quataffine(rigids)
transformations = T.from_4x4(torch.as_tensor(affines).float().cuda()) transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats sample_affine = quats
......
...@@ -13,11 +13,24 @@ ...@@ -13,11 +13,24 @@
# limitations under the License. # limitations under the License.
import math import math
import numpy as np
import torch import torch
import unittest import unittest
from openfold.utils.affine_utils import T, quat_to_rot from openfold.utils.rigid_utils import (
Rotation,
Rigid,
quat_to_rot,
rot_to_quat,
)
from openfold.utils.tensor_utils import chunk_layer, _chunk_slice from openfold.utils.tensor_utils import chunk_layer, _chunk_slice
import tests.compare_utils as compare_utils
from tests.config import consts
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
X_90_ROT = torch.tensor( X_90_ROT = torch.tensor(
...@@ -38,7 +51,7 @@ X_NEG_90_ROT = torch.tensor( ...@@ -38,7 +51,7 @@ X_NEG_90_ROT = torch.tensor(
class TestUtils(unittest.TestCase): class TestUtils(unittest.TestCase):
def test_T_from_3_points_shape(self): def test_rigid_from_3_points_shape(self):
batch_size = 2 batch_size = 2
n_res = 5 n_res = 5
...@@ -46,14 +59,14 @@ class TestUtils(unittest.TestCase): ...@@ -46,14 +59,14 @@ class TestUtils(unittest.TestCase):
x2 = torch.rand((batch_size, n_res, 3)) x2 = torch.rand((batch_size, n_res, 3))
x3 = torch.rand((batch_size, n_res, 3)) x3 = torch.rand((batch_size, n_res, 3))
t = T.from_3_points(x1, x2, x3) r = Rigid.from_3_points(x1, x2, x3)
rot, tra = t.rots, t.trans rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
self.assertTrue(rot.shape == (batch_size, n_res, 3, 3)) self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
self.assertTrue(torch.all(tra == x2)) self.assertTrue(torch.all(tra == x2))
def test_T_from_4x4(self): def test_rigid_from_4x4(self):
batch_size = 2 batch_size = 2
transf = [ transf = [
[1, 0, 0, 1], [1, 0, 0, 1],
...@@ -68,58 +81,79 @@ class TestUtils(unittest.TestCase): ...@@ -68,58 +81,79 @@ class TestUtils(unittest.TestCase):
transf = torch.stack([transf for _ in range(batch_size)], dim=0) transf = torch.stack([transf for _ in range(batch_size)], dim=0)
t = T.from_4x4(transf) r = Rigid.from_tensor_4x4(transf)
rot, tra = t.rots, t.trans rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
self.assertTrue(torch.all(rot == true_rot.unsqueeze(0))) self.assertTrue(torch.all(rot == true_rot.unsqueeze(0)))
self.assertTrue(torch.all(tra == true_trans.unsqueeze(0))) self.assertTrue(torch.all(tra == true_trans.unsqueeze(0)))
def test_T_shape(self): def test_rigid_shape(self):
batch_size = 2 batch_size = 2
n = 5 n = 5
transf = T( transf = Rigid(
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3)) Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
torch.rand((batch_size, n, 3))
) )
self.assertTrue(transf.shape == (batch_size, n)) self.assertTrue(transf.shape == (batch_size, n))
def test_T_concat(self): def test_rigid_cat(self):
batch_size = 2 batch_size = 2
n = 5 n = 5
transf = T( transf = Rigid(
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3)) Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
torch.rand((batch_size, n, 3))
) )
transf_concat = T.concat([transf, transf], dim=0) transf_cat = Rigid.cat([transf, transf], dim=0)
self.assertTrue(transf_concat.rots.shape == (batch_size * 2, n, 3, 3)) transf_rots = transf.get_rots().get_rot_mats()
transf_cat_rots = transf_cat.get_rots().get_rot_mats()
transf_concat = T.concat([transf, transf], dim=1) self.assertTrue(transf_cat_rots.shape == (batch_size * 2, n, 3, 3))
self.assertTrue(transf_concat.rots.shape == (batch_size, n * 2, 3, 3)) transf_cat = Rigid.cat([transf, transf], dim=1)
transf_cat_rots = transf_cat.get_rots().get_rot_mats()
self.assertTrue(torch.all(transf_concat.rots[:, :n] == transf.rots)) self.assertTrue(transf_cat_rots.shape == (batch_size, n * 2, 3, 3))
self.assertTrue(torch.all(transf_concat.trans[:, :n] == transf.trans))
def test_T_compose(self): self.assertTrue(torch.all(transf_cat_rots[:, :n] == transf_rots))
self.assertTrue(
torch.all(transf_cat.get_trans()[:, :n] == transf.get_trans())
)
def test_rigid_compose(self):
trans_1 = [0, 1, 0] trans_1 = [0, 1, 0]
trans_2 = [0, 0, 1] trans_2 = [0, 0, 1]
t1 = T(X_90_ROT, torch.tensor(trans_1)) r = Rotation(rot_mats=X_90_ROT)
t2 = T(X_NEG_90_ROT, torch.tensor(trans_2)) t = torch.tensor(trans_1)
t1 = Rigid(
Rotation(rot_mats=X_90_ROT),
torch.tensor(trans_1)
)
t2 = Rigid(
Rotation(rot_mats=X_NEG_90_ROT),
torch.tensor(trans_2)
)
t3 = t1.compose(t2) t3 = t1.compose(t2)
self.assertTrue(torch.all(t3.rots == torch.eye(3))) self.assertTrue(
self.assertTrue(torch.all(t3.trans == 0)) torch.all(t3.get_rots().get_rot_mats() == torch.eye(3))
)
self.assertTrue(
torch.all(t3.get_trans() == 0)
)
def test_T_apply(self): def test_rigid_apply(self):
rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0) rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0)
trans = torch.tensor([1, 1, 1]) trans = torch.tensor([1, 1, 1])
trans = torch.stack([trans, trans], dim=0) trans = torch.stack([trans, trans], dim=0)
t = T(rots, trans) t = Rigid(Rotation(rot_mats=rots), trans)
x = torch.arange(30) x = torch.arange(30)
x = torch.stack([x, x], dim=0) x = torch.stack([x, x], dim=0)
...@@ -141,6 +175,12 @@ class TestUtils(unittest.TestCase): ...@@ -141,6 +175,12 @@ class TestUtils(unittest.TestCase):
eps = 1e-07 eps = 1e-07
self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps)) self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps))
def test_rot_to_quat(self):
quat = rot_to_quat(X_90_ROT)
eps = 1e-07
ans = torch.tensor([math.sqrt(0.5), math.sqrt(0.5), 0., 0.])
self.assertTrue(torch.all(torch.abs(quat - ans) < eps))
def test_chunk_layer_tensor(self): def test_chunk_layer_tensor(self):
x = torch.rand(2, 4, 5, 15) x = torch.rand(2, 4, 5, 15)
l = torch.nn.Linear(15, 30) l = torch.nn.Linear(15, 30)
...@@ -180,3 +220,33 @@ class TestUtils(unittest.TestCase): ...@@ -180,3 +220,33 @@ class TestUtils(unittest.TestCase):
chunked_flattened = x_flat[i:j] chunked_flattened = x_flat[i:j]
self.assertTrue(torch.all(chunked == chunked_flattened)) self.assertTrue(torch.all(chunked == chunked_flattened))
@compare_utils.skip_unless_alphafold_installed()
def test_pre_compose_compare(self):
quat = np.random.rand(20, 4)
trans = [np.random.rand(20) for _ in range(3)]
quat_affine = alphafold.model.quat_affine.QuatAffine(
quat, translation=trans
)
update_vec = np.random.rand(20, 6)
new_gt = quat_affine.pre_compose(update_vec)
quat_t = torch.tensor(quat)
trans_t = torch.stack([torch.tensor(t) for t in trans], dim=-1)
rigid = Rigid(Rotation(quats=quat_t), trans_t)
new_repro = rigid.compose_q_update_vec(torch.tensor(update_vec))
new_gt_q = torch.tensor(np.array(new_gt.quaternion))
new_gt_t = torch.stack(
[torch.tensor(np.array(t)) for t in new_gt.translation], dim=-1
)
new_repro_q = new_repro.get_rots().get_quats()
new_repro_t = new_repro.get_trans()
self.assertTrue(
torch.max(torch.abs(new_gt_q - new_repro_q)) < consts.eps
)
self.assertTrue(
torch.max(torch.abs(new_gt_t - new_repro_t)) < consts.eps
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment