Commit e3daf724 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Overhaul transformation code for better parity w/ AlphaFold

parent 1f709b0d
......@@ -89,8 +89,8 @@ config = mlc.ConfigDict(
"atom14_gt_exists": [NUM_RES, None],
"atom14_gt_positions": [NUM_RES, None, None],
"atom37_atom_exists": [NUM_RES, None],
"backbone_affine_mask": [NUM_RES],
"backbone_affine_tensor": [NUM_RES, None, None],
"backbone_rigid_mask": [NUM_RES],
"backbone_rigid_tensor": [NUM_RES, None, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
"chi_angles_sin_cos": [NUM_RES, None, None],
"chi_mask": [NUM_RES, None],
......@@ -126,8 +126,8 @@ config = mlc.ConfigDict(
"template_alt_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_backbone_affine_mask": [NUM_TEMPLATES, NUM_RES],
"template_backbone_affine_tensor": [
"template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
"template_backbone_rigid_tensor": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_mask": [NUM_TEMPLATES],
......
......@@ -22,7 +22,7 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc
from openfold.utils.affine_utils import T
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
......@@ -752,7 +752,7 @@ def make_atom14_positions(protein):
return protein
def atom37_to_frames(protein):
def atom37_to_frames(protein, eps=1e-8):
aatype = protein["aatype"]
all_atom_positions = protein["all_atom_positions"]
all_atom_mask = protein["all_atom_mask"]
......@@ -810,11 +810,11 @@ def atom37_to_frames(protein):
no_batch_dims=len(all_atom_positions.shape[:-2]),
)
gt_frames = T.from_3_points(
gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :],
eps=1e-8,
eps=eps,
)
group_exists = batched_gather(
......@@ -836,8 +836,9 @@ def atom37_to_frames(protein):
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(T(rots, None))
gt_frames = gt_frames.compose(Rigid(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8
......@@ -871,10 +872,15 @@ def atom37_to_frames(protein):
no_batch_dims=batch_dims,
)
alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None))
residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot
)
alt_gt_frames = gt_frames.compose(
Rigid(residx_rigidgroup_ambiguity_rot, None)
)
gt_frames_tensor = gt_frames.to_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_4x4()
gt_frames_tensor = gt_frames.to_tensor_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
protein["rigidgroups_gt_frames"] = gt_frames_tensor
protein["rigidgroups_gt_exists"] = gt_exists
......@@ -1028,7 +1034,7 @@ def atom37_to_torsion_angles(
dim=-1,
)
torsion_frames = T.from_3_points(
torsion_frames = Rigid.from_3_points(
torsions_atom_pos[..., 1, :],
torsions_atom_pos[..., 2, :],
torsions_atom_pos[..., 0, :],
......@@ -1082,11 +1088,11 @@ def atom37_to_torsion_angles(
def get_backbone_frames(protein):
# TODO: Verify that this is correct
protein["backbone_affine_tensor"] = protein["rigidgroups_gt_frames"][
# DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
..., 0, :, :
]
protein["backbone_affine_mask"] = protein["rigidgroups_gt_exists"][..., 0]
protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]
return protein
......
......@@ -430,7 +430,9 @@ def _is_set(data: str) -> bool:
def get_atom_coords(
mmcif_object: MmcifObject, chain_id: str, zero_center: bool = True
mmcif_object: MmcifObject,
chain_id: str,
_zero_center_positions: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain
chains = list(mmcif_object.structure.get_chains())
......@@ -475,7 +477,7 @@ def get_atom_coords(
all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask
if zero_center:
if _zero_center_positions:
binary_mask = all_atom_mask.astype(bool)
translation_vec = all_atom_positions[binary_mask].mean(axis=0)
all_atom_positions[binary_mask] -= translation_vec
......
......@@ -503,10 +503,13 @@ def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str,
max_ca_ca_distance: float,
_zero_center_positions: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=auth_chain_id
mmcif_object=mmcif_object,
chain_id=auth_chain_id,
_zero_center_positions=_zero_center_positions,
)
all_atom_positions, all_atom_mask = coords_with_mask
_check_residue_distances(
......@@ -523,6 +526,7 @@ def _extract_template_features(
query_sequence: str,
template_chain_id: str,
kalign_binary_path: str,
_zero_center_positions: bool = True,
) -> Tuple[Dict[str, Any], Optional[str]]:
"""Parses atom positions in the target structure and aligns with the query.
......@@ -607,7 +611,10 @@ def _extract_template_features(
# Essentially set to infinity - we don't want to reject templates unless
# they're really really bad.
all_atom_positions, all_atom_mask = _get_atom_positions(
mmcif_object, chain_id, max_ca_ca_distance=150.0
mmcif_object,
chain_id,
max_ca_ca_distance=150.0,
_zero_center_positions=_zero_center_positions,
)
except (CaDistanceError, KeyError) as ex:
raise NoAtomDataInTemplateError(
......@@ -795,6 +802,7 @@ def _process_single_hit(
obsolete_pdbs: Mapping[str, str],
kalign_binary_path: str,
strict_error_check: bool = False,
_zero_center_positions: bool = True,
) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
......@@ -856,6 +864,7 @@ def _process_single_hit(
query_sequence=query_sequence,
template_chain_id=hit_chain_id,
kalign_binary_path=kalign_binary_path,
_zero_center_positions=_zero_center_positions,
)
features["template_sum_probs"] = [hit.sum_probs]
......@@ -913,7 +922,6 @@ class TemplateSearchResult:
class TemplateHitFeaturizer:
"""A class for turning hhr hits to template features."""
def __init__(
self,
mmcif_dir: str,
......@@ -924,6 +932,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs_path: Optional[str] = None,
strict_error_check: bool = False,
_shuffle_top_k_prefiltered: Optional[int] = None,
_zero_center_positions: bool = True,
):
"""Initializes the Template Search.
......@@ -982,6 +991,7 @@ class TemplateHitFeaturizer:
self._obsolete_pdbs = {}
self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
self._zero_center_positions = _zero_center_positions
def get_templates(
self,
......@@ -1057,6 +1067,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path,
_zero_center_positions=self._zero_center_positions,
)
if result.error:
......
......@@ -198,6 +198,7 @@ class RecyclingEmbedder(nn.Module):
self.no_bins,
dtype=x.dtype,
device=x.device,
requires_grad=False,
)
# [*, N, C_m]
......
......@@ -25,11 +25,11 @@ from openfold.np.residue_constants import (
restype_atom14_mask,
restype_atom14_rigid_group_positions,
)
from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.feats import (
frames_and_literature_positions_to_atom14_pos,
torsion_angles_to_frames,
)
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
dict_multimap,
permute_final_dims,
......@@ -225,7 +225,7 @@ class InvariantPointAttention(nn.Module):
self,
s: torch.Tensor,
z: torch.Tensor,
t: T,
r: Rigid,
mask: torch.Tensor,
) -> torch.Tensor:
"""
......@@ -234,8 +234,8 @@ class InvariantPointAttention(nn.Module):
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
t:
[*, N_res] affine transformation object
r:
[*, N_res] transformation object
mask:
[*, N_res] mask
Returns:
......@@ -264,7 +264,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_q, 3]
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
q_pts = torch.stack(q_pts, dim=-1)
q_pts = t[..., None].apply(q_pts)
q_pts = r[..., None].apply(q_pts)
# [*, N_res, H, P_q, 3]
q_pts = q_pts.view(
......@@ -277,7 +277,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * (P_q + P_v), 3]
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = t[..., None].apply(kv_pts)
kv_pts = r[..., None].apply(kv_pts)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
......@@ -349,7 +349,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = t[..., None, None].invert_apply(o_pt)
o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims(
......@@ -377,7 +377,7 @@ class InvariantPointAttention(nn.Module):
class BackboneUpdate(nn.Module):
"""
Implements Algorithm 23.
Implements part of Algorithm 23.
"""
def __init__(self, c_s):
......@@ -392,36 +392,17 @@ class BackboneUpdate(nn.Module):
self.linear = Linear(self.c_s, 6, init="final")
def forward(self, s):
def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
[*, N_res, C_s] single representation
Returns:
[*, N_res] affine transformation object
[*, N_res, 6] update vector
"""
# [*, 6]
params = self.linear(s)
# [*, 3]
quats, trans = params[..., :3], params[..., 3:]
# [*]
# norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
norm_denom = torch.sqrt(torch.sum(quats ** 2, dim=-1) + 1)
# [*, 3]
ones = s.new_ones((1,) * len(quats.shape)).expand(
quats.shape[:-1] + (1,)
)
# [*, 4]
quats = torch.cat([ones, quats], dim=-1)
quats = quats / norm_denom[..., None]
update = self.linear(s)
# [*, 3, 3]
rots = quat_to_rot(quats)
return T(rots, trans)
return update
class StructureModuleTransitionLayer(nn.Module):
......@@ -592,7 +573,7 @@ class StructureModule(nn.Module):
self,
s,
z,
f,
aatype,
mask=None,
):
"""
......@@ -601,7 +582,7 @@ class StructureModule(nn.Module):
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
f:
aatype:
[*, N_res] amino acid indices
mask:
Optional [*, N_res] sequence mask
......@@ -623,44 +604,67 @@ class StructureModule(nn.Module):
s = self.linear_in(s)
# [*, N]
t = T.identity(s.shape[:-1], s.dtype, s.device, self.training)
rigids = Rigid.identity(
s.shape[:-1],
s.dtype,
s.device,
self.training,
fmt="quat",
)
outputs = []
for i in range(self.no_blocks):
# [*, N, C_s]
s = s + self.ipa(s, z, t, mask)
s = s + self.ipa(s, z, rigids, mask)
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)
# [*, N]
t = t.compose(self.bb_update(s))
rigids = rigids.compose_q_update_vec(self.bb_update(s))
# To hew as closely as possible to AlphaFold, we convert our
# quaternion-based transformations to rotation-matrix ones
# here
backb_to_global = Rigid(
Rotation(
rot_mats=rigids.get_rots().get_rot_mats(),
quats=None
),
rigids.get_trans(),
)
backb_to_global = backb_to_global.scale_translation(
self.trans_scale_factor
)
# [*, N, 7, 2]
unnormalized_a, a = self.angle_resnet(s, s_initial)
unnormalized_angles, angles = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames(
t.scale_translation(self.trans_scale_factor),
a,
f,
backb_to_global,
angles,
aatype,
)
pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
all_frames_to_global,
f,
aatype,
)
scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
preds = {
"frames": t.scale_translation(self.trans_scale_factor).to_4x4(),
"sidechain_frames": all_frames_to_global.to_4x4(),
"unnormalized_angles": unnormalized_a,
"angles": a,
"frames": scaled_rigids.to_tensor_7(),
"sidechain_frames": all_frames_to_global.to_tensor_4x4(),
"unnormalized_angles": unnormalized_angles,
"angles": angles,
"positions": pred_xyz,
}
outputs.append(preds)
if i < (self.no_blocks - 1):
t = t.stop_rot_gradient()
rigids = rigids.stop_rot_gradient()
outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s
......@@ -673,38 +677,42 @@ class StructureModule(nn.Module):
restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.group_idx is None:
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
)
if self.atom_mask is None:
self.atom_mask = torch.tensor(
restype_atom14_mask,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.lit_positions is None:
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
requires_grad=False,
)
def torsion_angles_to_frames(self, t, alpha, f):
def torsion_angles_to_frames(self, r, alpha, f):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying
return torsion_angles_to_frames(t, alpha, f, self.default_frames)
return torsion_angles_to_frames(r, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos(
self, t, f # [*, N, 8] # [*, N]
self, r, f # [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(t.rots.dtype, t.rots.device)
self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
return frames_and_literature_positions_to_atom14_pos(
t,
r,
f,
self.default_frames,
self.group_idx,
......
......@@ -22,7 +22,7 @@ from typing import Dict
from openfold.np import protein
import openfold.np.residue_constants as rc
from openfold.utils.affine_utils import T
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
batched_gather,
one_hot,
......@@ -124,18 +124,16 @@ def build_template_pair_feat(
)
n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
# TODO: Consider running this in double precision
affines = T.make_transform_from_reference(
rigids = Rigid.make_transform_from_reference(
n_xyz=batch["template_all_atom_positions"][..., n, :],
ca_xyz=batch["template_all_atom_positions"][..., ca, :],
c_xyz=batch["template_all_atom_positions"][..., c, :],
eps=eps,
)
points = rigids.get_trans()[..., None, :, :]
rigid_vec = rigids[..., None].invert_apply(points)
points = affines.get_trans()[..., None, :, :]
affine_vec = affines[..., None].invert_apply(points)
inv_distance_scalar = torch.rsqrt(eps + torch.sum(affine_vec ** 2, dim=-1))
inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))
t_aa_masks = batch["template_all_atom_mask"]
template_mask = (
......@@ -144,7 +142,7 @@ def build_template_pair_feat(
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = affine_vec * inv_distance_scalar[..., None]
unit_vector = rigid_vec * inv_distance_scalar[..., None]
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None])
......@@ -165,7 +163,7 @@ def build_extra_msa_feat(batch):
def torsion_angles_to_frames(
t: T,
r: Rigid,
alpha: torch.Tensor,
aatype: torch.Tensor,
rrgdf: torch.Tensor,
......@@ -176,13 +174,15 @@ def torsion_angles_to_frames(
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
default_t = T.from_4x4(default_4x4)
default_r = r.from_tensor_4x4(default_4x4)
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1
# [*, N, 8, 2]
alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
alpha = torch.cat(
[bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
)
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
......@@ -194,15 +194,15 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots = alpha.new_zeros(default_t.rots.shape)
all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots = T(all_rots, None)
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_t.compose(all_rots)
all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
......@@ -213,7 +213,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = T.concat(
all_frames_to_bb = Rigid.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
......@@ -223,13 +223,13 @@ def torsion_angles_to_frames(
dim=-1,
)
all_frames_to_global = t[..., None].compose(all_frames_to_bb)
all_frames_to_global = r[..., None].compose(all_frames_to_bb)
return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
t: T,
r: Rigid,
aatype: torch.Tensor,
default_frames,
group_idx,
......@@ -249,7 +249,7 @@ def frames_and_literature_positions_to_atom14_pos(
)
# [*, N, 14, 8]
t_atoms_to_global = t[..., None, :] * group_mask
t_atoms_to_global = r[..., None, :] * group_mask
# [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
......
......@@ -24,7 +24,7 @@ from typing import Dict, Optional, Tuple
from openfold.np import residue_constants
from openfold.utils import feats
from openfold.utils.affine_utils import T
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
......@@ -74,8 +74,8 @@ def torsion_angle_loss(
def compute_fape(
pred_frames: T,
target_frames: T,
pred_frames: Rigid,
target_frames: Rigid,
frames_mask: torch.Tensor,
pred_positions: torch.Tensor,
target_positions: torch.Tensor,
......@@ -111,7 +111,7 @@ def compute_fape(
# )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
# ("roughly" because eps is necessarily duplicated in the latter
# ("roughly" because eps is necessarily duplicated in the latter)
normed_error = torch.sum(normed_error, dim=-1)
normed_error = (
normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
......@@ -123,8 +123,8 @@ def compute_fape(
def backbone_loss(
backbone_affine_tensor: torch.Tensor,
backbone_affine_mask: torch.Tensor,
backbone_rigid_tensor: torch.Tensor,
backbone_rigid_mask: torch.Tensor,
traj: torch.Tensor,
use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.0,
......@@ -132,16 +132,27 @@ def backbone_loss(
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
pred_aff = T.from_tensor(traj)
gt_aff = T.from_tensor(backbone_affine_tensor)
pred_aff = Rigid.from_tensor_7(traj)
pred_aff = Rigid(
Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
pred_aff.get_trans(),
)
# DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
# backbone tensor, normalizes it, and then turns it back to a rotation
# matrix. To avoid a potentially numerically unstable rotation matrix
# to quaternion conversion, we just use the original rotation matrix
# outright. This one hasn't been composed a bunch of times, though, so
# it might be fine.
gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)
fape_loss = compute_fape(
pred_aff,
gt_aff[None],
backbone_affine_mask[None],
backbone_rigid_mask[None],
pred_aff.get_trans(),
gt_aff[None].get_trans(),
backbone_affine_mask[None],
backbone_rigid_mask[None],
l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance,
eps=eps,
......@@ -150,10 +161,10 @@ def backbone_loss(
unclamped_fape_loss = compute_fape(
pred_aff,
gt_aff[None],
backbone_affine_mask[None],
backbone_rigid_mask[None],
pred_aff.get_trans(),
gt_aff[None].get_trans(),
backbone_affine_mask[None],
backbone_rigid_mask[None],
l1_clamp_distance=None,
length_scale=loss_unit_distance,
eps=eps,
......@@ -193,9 +204,9 @@ def sidechain_loss(
sidechain_frames = sidechain_frames[-1]
batch_dims = sidechain_frames.shape[:-4]
sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
sidechain_frames = T.from_4x4(sidechain_frames)
sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
renamed_gt_frames = T.from_4x4(renamed_gt_frames)
renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
sidechain_atom_pos = sidechain_atom_pos[-1]
sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
......@@ -422,7 +433,7 @@ def distogram_loss(
device=logits.device,
)
boundaries = boundaries ** 2
dists = torch.sum(
(pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
dim=-1,
......@@ -550,8 +561,8 @@ def compute_tm(
def tm_loss(
logits,
final_affine_tensor,
backbone_affine_tensor,
backbone_affine_mask,
backbone_rigid_tensor,
backbone_rigid_mask,
resolution,
max_bin=31,
no_bins=64,
......@@ -560,16 +571,17 @@ def tm_loss(
eps=1e-8,
**kwargs,
):
pred_affine = T.from_4x4(final_affine_tensor)
backbone_affine = T.from_4x4(backbone_affine_tensor)
pred_affine = Rigid.from_tensor_7(final_affine_tensor)
backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
def _points(affine):
pts = affine.get_trans()[..., None, :, :]
return affine.invert()[..., None].apply(pts)
sq_diff = torch.sum(
(_points(pred_affine) - _points(backbone_affine)) ** 2, dim=-1
(_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
)
sq_diff = sq_diff.detach()
boundaries = torch.linspace(
......@@ -583,7 +595,7 @@ def tm_loss(
)
square_mask = (
backbone_affine_mask[..., None] * backbone_affine_mask[..., None, :]
backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
)
loss = torch.sum(errors * square_mask, dim=-1)
......@@ -1503,11 +1515,12 @@ class AlphaFoldLoss(nn.Module):
),
}
cum_loss = 0
cum_loss = 0.
for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight
if weight:
loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)):
logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True)
......
......@@ -106,53 +106,791 @@ def rot_vec_mul(
dim=-1,
)
def identity_rot_mats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
rots = torch.eye(
3, dtype=dtype, device=device, requires_grad=requires_grad
)
rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
rots = rots.expand(*batch_dims, -1, -1)
return rots
def identity_trans(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
trans = torch.zeros(
(*batch_dims, 3),
dtype=dtype,
device=device,
requires_grad=requires_grad
)
return trans
def identity_quats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
quat = torch.zeros(
(*batch_dims, 4),
dtype=dtype,
device=device,
requires_grad=requires_grad
)
with torch.no_grad():
quat[..., 0] = 1
return quat
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
def _to_mat(pairs):
mat = np.zeros((4, 4))
for pair in pairs:
key, value = pair
ind = _qtr_ind_dict[key]
mat[ind // 4][ind % 4] = value
return mat
_QTR_MAT = np.zeros((4, 4, 3, 3))
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
"""
Converts a quaternion to a rotation matrix.
Args:
quat: [*, 4] quaternions
Returns:
[*, 3, 3] rotation matrices
"""
# [*, 4, 4]
quat = quat[..., None] * quat[..., None, :]
# [4, 4, 3, 3]
mat = quat.new_tensor(_QTR_MAT, requires_grad=False)
# [*, 4, 4, 3, 3]
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
quat = quat[..., None, None] * shaped_qtr_mat
# [*, 3, 3]
return torch.sum(quat, dim=(-3, -4))
def rot_to_quat(
rot: torch.Tensor,
):
if(rot.shape[-2:] != (3, 3)):
raise ValueError("Input rotation is incorrectly shaped")
rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
k = [
[ xx + yy + zz, zy - yz, xz - zx, yx - xy,],
[ zy - yz, xx - yy - zz, xy + yx, xz + zx,],
[ xz - zx, xy + yx, yy - xx - zz, yz + zy,],
[ yx - xy, xz + zx, yz + zy, zz - xx - yy,]
]
k = (1./3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)
_, vectors = torch.linalg.eigh(k)
return vectors[..., -1]
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
[ 0,-1, 0, 0],
[ 0, 0,-1, 0],
[ 0, 0, 0,-1]]
_QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
[ 1, 0, 0, 0],
[ 0, 0, 0, 1],
[ 0, 0,-1, 0]]
_QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
[ 0, 0, 0,-1],
[ 1, 0, 0, 0],
[ 0, 1, 0, 0]]
_QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
[ 0, 0, 1, 0],
[ 0,-1, 0, 0],
[ 1, 0, 0, 0]]
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
def quat_multiply(quat1, quat2):
"""Multiply a quaternion by another quaternion."""
mat = quat1.new_tensor(_QUAT_MULTIPLY)
reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
return torch.sum(
reshaped_mat *
quat1[..., :, None, None] *
quat2[..., None, :, None],
dim=(-3, -2)
)
def quat_multiply_by_vec(quat, vec):
"""Multiply a quaternion by a pure-vector quaternion."""
mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
return torch.sum(
reshaped_mat *
quat[..., :, None, None] *
vec[..., None, :, None],
dim=(-3, -2)
)
def invert_rot_mat(rot_mat: torch.Tensor):
return rot_mat.transpose(-1, -2)
class T:
def invert_quat(quat: torch.Tensor):
quat_prime = quat.clone()
quat_prime[..., 1:] *= -1
inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True)
return inv
class Rotation:
"""
A class representing an affine transformation. Essentially a wrapper
around two torch tensors: a [*, 3, 3] rotation and a [*, 3]
translation. Designed to behave approximately like a single torch
tensor with the shape of the shared dimensions of its component parts.
A 3D rotation. Depending on how the object is initialized, the
rotation is represented by either a rotation matrix or a
quaternion, though both formats are made available by helper functions.
To simplify gradient computation, the underlying format of the
rotation cannot be changed in-place. Like Rigid, the class is designed
to mimic the behavior of a torch Tensor, almost as if each Rotation
object were a tensor of rotations, in one format or another.
"""
def __init__(self,
rot_mats: Optional[torch.Tensor] = None,
quats: Optional[torch.Tensor] = None,
normalize_quats: bool = True,
):
"""
Args:
rot_mats:
A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
quats
quats:
A [*, 4] quaternion. Mutually exclusive with rot_mats. If
normalize_quats is not True, must be a unit quaternion
normalize_quats:
If quats is specified, whether to normalize quats
"""
if((rot_mats is None and quats is None) or
(rot_mats is not None and quats is not None)):
raise ValueError("Exactly one input argument must be specified")
if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
(quats is not None and quats.shape[-1] != 4)):
raise ValueError(
"Incorrectly shaped rotation matrix or quaternion"
)
if(quats is not None and normalize_quats):
quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
self._rot_mats = rot_mats
self._quats = quats
@staticmethod
def identity(
shape,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
fmt: str = "quat",
) -> Rotation:
"""
Returns an identity Rotation.
Args:
shape:
The "shape" of the resulting Rotation object. See documentation
for the shape property
dtype:
The torch dtype for the rotation
device:
The torch device for the new rotation
requires_grad:
Whether the underlying tensors in the new rotation object
should require gradient computation
fmt:
One of "quat" or "rot_mat". Determines the underlying format
of the new object's rotation
Returns:
A new identity rotation
"""
if(fmt == "rot_mat"):
rot_mats = identity_rot_mats(
shape, dtype, device, requires_grad,
)
return Rotation(rot_mats=rot_mats, quats=None)
elif(fmt == "quat"):
quats = identity_quats(shape, dtype, device, requires_grad)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError(f"Invalid format: f{fmt}")
# Magic methods
def __getitem__(self, index: Any) -> Rotation:
"""
Allows torch-style indexing over the virtual shape of the rotation
object. See documentation for the shape property.
Args:
index:
A torch index. E.g. (1, 3, 2), or (slice(None,))
Returns:
The indexed rotation
"""
if type(index) != tuple:
index = (index,)
if(self._rot_mats is not None):
rot_mats = self._rot_mats[index + (slice(None), slice(None))]
return Rotation(rot_mats=rot_mats)
elif(self._quats is not None):
quats = self._quats[index + (slice(None),)]
return Rotation(quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def __mul__(self,
right: torch.Tensor,
) -> Rotation:
"""
Pointwise left multiplication of the rotation with a tensor. Can be
used to e.g. mask the Rotation.
Args:
right:
The tensor multiplicand
Returns:
The product
"""
if not(isinstance(right, torch.Tensor)):
raise TypeError("The other multiplicand must be a Tensor")
if(self._rot_mats is not None):
rot_mats = self._rot_mats * right[..., None, None]
return Rotation(rot_mats=rot_mats, quats=None)
elif(self._quats is not None):
quats = self._quats * right[..., None]
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def __rmul__(self,
left: torch.Tensor,
) -> Rotation:
"""
Reverse pointwise multiplication of the rotation with a tensor.
Args:
left:
The left multiplicand
Returns:
The product
"""
return self.__mul__(left)
# Properties
@property
def shape(self) -> torch.Size:
"""
Returns the virtual shape of the rotation object. This shape is
defined as the batch dimensions of the underlying rotation matrix
or quaternion. If the Rotation was initialized with a [10, 3, 3]
rotation matrix tensor, for example, the resulting shape would be
[10].
Returns:
The virtual shape of the rotation object
"""
s = None
if(self._quats is not None):
s = self._quats.shape[:-1]
else:
s = self._rot_mats.shape[:-2]
return s
@property
def dtype(self) -> torch.dtype:
"""
Returns the dtype of the underlying rotation.
Returns:
The dtype of the underlying rotation
"""
if(self._rot_mats is not None):
return self._rot_mats.dtype
elif(self._quats is not None):
return self._quats.dtype
else:
raise ValueError("Both rotations are None")
@property
def device(self) -> torch.device:
"""
The device of the underlying rotation
Returns:
The device of the underlying rotation
"""
if(self._rot_mats is not None):
return self._rot_mats.device
elif(self._quats is not None):
return self._quats.device
else:
raise ValueError("Both rotations are None")
@property
def requires_grad(self) -> bool:
"""
Returns the requires_grad property of the underlying rotation
Returns:
The requires_grad property of the underlying tensor
"""
if(self._rot_mats is not None):
return self._rot_mats.requires_grad
elif(self._quats is not None):
return self._quats.requires_grad
else:
raise ValueError("Both rotations are None")
def get_rot_mats(self) -> torch.Tensor:
"""
Returns the underlying rotation as a rotation matrix tensor.
Returns:
The rotation as a rotation matrix tensor
"""
rot_mats = self._rot_mats
if(rot_mats is None):
if(self._quats is None):
raise ValueError("Both rotations are None")
else:
rot_mats = quat_to_rot(self._quats)
return rot_mats
def get_quats(self) -> torch.Tensor:
"""
Returns the underlying rotation as a quaternion tensor.
Depending on whether the Rotation was initialized with a
quaternion, this function may call torch.linalg.eigh.
Returns:
The rotation as a quaternion tensor.
"""
quats = self._quats
if(quats is None):
if(self._rot_mats is None):
raise ValueError("Both rotations are None")
else:
quats = rot_to_quat(self._rot_mats)
return quats
def get_cur_rot(self) -> torch.Tensor:
"""
Return the underlying rotation in its current form
Returns:
The stored rotation
"""
if(self._rot_mats is not None):
return self._rot_mats
elif(self._quats is not None):
return self._quats
else:
raise ValueError("Both rotations are None")
# Rotation functions
def compose_q_update_vec(self,
q_update_vec: torch.Tensor,
normalize_quats: bool = True
) -> Rotation:
"""
Returns a new quaternion Rotation after updating the current
object's underlying rotation with a quaternion update, formatted
as a [*, 3] tensor whose final three columns represent x, y, z such
that (1, x, y, z) is the desired (not necessarily unit) quaternion
update.
Args:
q_update_vec:
A [*, 3] quaternion update tensor
normalize_quats:
Whether to normalize the output quaternion
Returns:
An updated Rotation
"""
quats = self.get_quats()
new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
return Rotation(
rot_mats=None,
quats=new_quats,
normalize_quats=normalize_quats,
)
def compose_r(self, r: Rotation) -> Rotation:
"""
Compose the rotation matrices of the current Rotation object with
those of another.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
r1 = self.get_rot_mats()
r2 = r.get_rot_mats()
new_rot_mats = rot_matmul(r1, r2)
return Rotation(rot_mats=new_rot_mats, quats=None)
def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
"""
Compose the quaternions of the current Rotation object with those
of another.
Depending on whether either Rotation was initialized with
quaternions, this function may call torch.linalg.eigh.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
q1 = self.get_quats()
q2 = r.get_quats()
new_quats = quat_multiply(q1, q2)
return Rotation(
rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
)
def apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
Apply the current Rotation as a rotation matrix to a set of 3D
coordinates.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] rotated points
"""
rot_mats = self.get_rot_mats()
return rot_vec_mul(rot_mats, pts)
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
The inverse of the apply() method.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] inverse-rotated points
"""
rot_mats = self.get_rot_mats()
inv_rot_mats = invert_rot_mat(rot_mats)
return rot_vec_mul(inv_rot_mats, pts)
def invert(self) -> Rotation:
"""
Returns the inverse of the current Rotation.
Returns:
The inverse of the current Rotation
"""
if(self._rot_mats is not None):
return Rotation(
rot_mats=invert_rot_mat(self._rot_mats),
quats=None
)
elif(self._quats is not None):
return Rotation(
rot_mats=None,
quats=invert_quat(self._quats),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
# "Tensor" stuff
def unsqueeze(self,
dim: int,
) -> Rigid:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shape of the Rotation object.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed Rotation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
if(self._rot_mats is not None):
rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
return Rotation(rot_mats=rot_mats, quats=None)
elif(self._quats is not None):
quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
@staticmethod
def cat(
rs: Sequence[Rotation],
dim: int,
) -> Rigid:
"""
Concatenates rotations along one of the batch dimensions. Analogous
to torch.cat().
Note that the output of this operation is always a rotation matrix,
regardless of the format of input rotations.
Args:
rs:
A list of rotation objects
dim:
The dimension along which the rotations should be
concatenated
Returns:
A concatenated Rotation object in rotation matrix format
"""
rot_mats = [r.get_rot_mats() for r in rs]
rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
return Rotation(rot_mats=rot_mats, quats=None)
def map_tensor_fn(self,
fn: Callable[tensor.Tensor, tensor.Tensor]
) -> Rotation:
"""
Apply a Tensor -> Tensor function to underlying rotation tensors,
mapping over the rotation dimension(s). Can be used e.g. to sum out
a one-hot batch dimension.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rotation
Returns:
The transformed Rotation object
"""
if(self._rot_mats is not None):
rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
rot_mats = torch.stack(
list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
)
rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
return Rotation(rot_mats=rot_mats, quats=None)
elif(self._quats is not None):
quats = torch.stack(
list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def cuda(self) -> Rotation:
"""
Analogous to the cuda() method of torch Tensors
Returns:
A copy of the Rotation in CUDA memory
"""
if(self._rot_mats is not None):
return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
elif(self._quats is not None):
return Rotation(
rot_mats=None,
quats=self._quats.cuda(),
normalize_quats=False
)
else:
raise ValueError("Both rotations are None")
def to(self,
device: Optional[torch.device],
dtype: Optional[torch.dtype]
) -> Rotation:
"""
Analogous to the to() method of torch Tensors
Args:
device:
A torch device
dtype:
A torch dtype
Returns:
A copy of the Rotation using the new device and dtype
"""
if(self._rot_mats is not None):
return Rotation(
rot_mats=self._rot_mats.to(device=device, dtype=dtype),
quats=None,
)
elif(self._quats is not None):
return Rotation(
rot_mats=None,
quats=self._quats.to(device=device, dtype=dtype),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
def detach(self) -> Rotation:
"""
Returns a copy of the Rotation whose underlying Tensor has been
detached from its torch graph.
Returns:
A copy of the Rotation whose underlying Tensor has been detached
from its torch graph
"""
if(self._rot_mats is not None):
return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
elif(self._quats is not None):
return Rotation(
rot_mats=None,
quats=self._quats.detach(),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
class Rigid:
"""
A class representing a rigid transformation. Little more than a wrapper
around two objects: a Rotation object and a [*, 3] translation
Designed to behave approximately like a single torch tensor with the
shape of the shared batch dimensions of its component parts.
"""
def __init__(self,
rots: torch.Tensor,
trans: torch.Tensor
rots: Optional[Rotation],
trans: Optional[torch.Tensor],
):
"""
Args:
rots: A [*, 3, 3] rotation tensor
trans: A corresponding [*, 3] translation tensor
"""
self.rots = rots
self.trans = trans
if self.rots is None and self.trans is None:
raise ValueError("Only one of rots and trans can be None")
elif self.rots is None:
self.rots = T._identity_rot(
self.trans.shape[:-1],
self.trans.dtype,
self.trans.device,
self.trans.requires_grad,
# (we need device, dtype, etc. from at least one input)
batch_dims, dtype, device, requires_grad = None, None, None, None
if(trans is not None):
batch_dims = trans.shape[:-1]
dtype = trans.dtype
device = trans.device
requires_grad = trans.requires_grad
elif(rots is not None):
batch_dims = rots.shape
dtype = rots.dtype
device = rots.device
requires_grad = rots.requires_grad
else:
raise ValueError("At least one input argument must be specified")
if(rots is None):
rots = Rotation.identity(
batch_dims, dtype, device, requires_grad,
)
elif self.trans is None:
self.trans = T._identity_trans(
self.rots.shape[:-2],
self.rots.dtype,
self.rots.device,
self.rots.requires_grad,
elif(trans is None):
trans = identity_trans(
batch_dims, dtype, device, requires_grad,
)
if (
self.rots.shape[-2:] != (3, 3)
or self.trans.shape[-1] != 3
or self.rots.shape[:-2] != self.trans.shape[:-1]
):
raise ValueError("Incorrectly shaped input")
if((rots.shape != trans.shape[:-1]) or
(rots.device != trans.device)):
raise ValueError("Rots and trans incompatible")
self._rots = rots
self._trans = trans
@staticmethod
def identity(
shape: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
fmt: str = "quat",
) -> Rigid:
"""
Constructs an identity transformation.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns:
The identity transformation
"""
return Rigid(
Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
identity_trans(shape, dtype, device, requires_grad),
)
def __getitem__(self,
index: Any,
) -> T:
) -> Rigid:
"""
Indexes the affine transformation with PyTorch-style indices.
The index is applied to the shared dimensions of both the rotation
......@@ -160,11 +898,12 @@ class T:
E.g.::
t = T(torch.rand(10, 10, 3, 3), torch.rand(10, 10, 3))
r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
t = Rigid(r, torch.rand(10, 10, 3))
indexed = t[3, 4:6]
assert(indexed.shape == (2,))
assert(indexed.rots.shape == (2, 3, 3))
assert(indexed.trans.shape == (2, 3))
assert(indexed.get_rots().shape == (2,))
assert(indexed.get_trans().shape == (2, 3))
Args:
index: A standard torch tensor index. E.g. 8, (10, None, 3),
......@@ -174,54 +913,45 @@ class T:
"""
if type(index) != tuple:
index = (index,)
return T(
self.rots[index + (slice(None), slice(None))],
self.trans[index + (slice(None),)],
)
def __eq__(self,
obj: T,
) -> bool:
"""
Compares two affine transformations. Returns true iff the
transformations are pointwise identical. Does not account for
floating point imprecision.
"""
return bool(
torch.all(self.rots == obj.rots) and
torch.all(self.trans == obj.trans)
return Rigid(
self._rots[index],
self._trans[index + (slice(None),)],
)
def __mul__(self,
def __mul__(self,
right: torch.Tensor,
) -> T:
) -> Rigid:
"""
Pointwise right multiplication of the affine transformation with a
tensor. Multiplication is broadcast over the rotation/translation
dimensions.
Pointwise left multiplication of the transformation with a tensor.
Can be used to e.g. mask the Rigid.
Args:
right: The right multiplicand
right:
The tensor multiplicand
Returns:
The product transformation
The product
"""
rots = self.rots * right[..., None, None]
trans = self.trans * right[..., None]
if not(isinstance(right, torch.Tensor)):
raise TypeError("The other multiplicand must be a Tensor")
return T(rots, trans)
new_rots = self._rots * right
new_trans = self._trans * right[..., None]
def __rmul__(self,
return Rigid(new_rots, new_trans)
def __rmul__(self,
left: torch.Tensor,
) -> T:
) -> Rigid:
"""
Pointwise left multiplication of the affine transformation with a
tensor. Multiplication is broadcast over the rotation/translation
dimensions.
Reverse pointwise multiplication of the transformation with a
tensor.
Args:
left: The left multiplicand
left:
The left multiplicand
Returns:
The product transformation
The product
"""
return self.__mul__(left)
......@@ -234,45 +964,74 @@ class T:
Returns:
The shape of the transformation
"""
s = self.rots.shape[:-2]
return s if len(s) > 0 else torch.Size([1])
s = self._trans.shape[:-1]
return s
@property
def device(self) -> torch.device:
"""
Returns the device on which the Rigid's tensors are located.
def get_rots(self):
Returns:
The device on which the Rigid's tensors are located
"""
return self._trans.device
def get_rots(self) -> Rotation:
"""
Getter for the rotation.
Returns:
The stored rotation.
The rotation object
"""
return self.rots
return self._rots
def get_trans(self) -> torch.Tensor:
"""
Getter for the translation.
Returns:
The stored translation.
The stored translation
"""
return self.trans
return self._trans
def compose(self,
t: T,
) -> T:
def compose_q_update_vec(self,
q_update_vec: torch.Tensor,
) -> Rigid:
"""
Composes the transformation with another.
Composes the transformation with a quaternion update vector of
shape [*, 6], where the final 6 columns represent the x, y, and
z values of a quaternion of form (1, x, y, z) followed by a 3D
translation.
Args:
t: The inner transformation.
q_vec: The quaternion update vector.
Returns:
The composed transformation.
"""
rot_1, trn_1 = self.rots, self.trans
rot_2, trn_2 = t.rots, t.trans
q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
new_rots = self._rots.compose_q_update_vec(q_vec)
trans_update = self._rots.apply(t_vec)
new_translation = self._trans + trans_update
rot = rot_matmul(rot_1, rot_2)
trn = rot_vec_mul(rot_1, trn_2) + trn_1
return Rigid(new_rots, new_translation)
return T(rot, trn)
def compose(self,
r: Rigid,
) -> Rigid:
"""
Composes the current rigid object with another.
Args:
r:
Another Rigid object
Returns:
The composition of the two transformations
"""
new_rot = self._rots.compose_r(r._rots)
new_trans = self._rots.apply(r._trans) + self._trans
return Rigid(new_rot, new_trans)
def apply(self,
pts: torch.Tensor,
......@@ -285,9 +1044,8 @@ class T:
Returns:
The transformed points.
"""
r, t = self.rots, self.trans
rotated = rot_vec_mul(r, pts)
return rotated + t
rotated = self._rots.apply(pts)
return rotated + self._trans
def invert_apply(self,
pts: torch.Tensor
......@@ -300,99 +1058,60 @@ class T:
Returns:
The transformed points.
"""
r, t = self.rots, self.trans
pts = pts - t
return rot_vec_mul(r.transpose(-1, -2), pts)
pts = pts - self._trans
return self._rots.invert_apply(pts)
def invert(self) -> T:
def invert(self) -> Rigid:
"""
Inverts the transformation.
Returns:
The inverse transformation.
"""
rot_inv = self.rots.transpose(-1, -2)
trn_inv = rot_vec_mul(rot_inv, self.trans)
rot_inv = self._rots.invert()
trn_inv = rot_inv.apply(self._trans)
return T(rot_inv, -1 * trn_inv)
return Rigid(rot_inv, -1 * trn_inv)
def unsqueeze(self,
dim: int,
) -> T:
def map_tensor_fn(self,
fn: Callable[tensor.Tensor, tensor.Tensor]
) -> Rigid:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shared dimensions of the rotation/translation.
Apply a Tensor -> Tensor function to underlying translation and
rotation tensors, mapping over the translation/rotation dimensions
respectively.
Args:
dim: A positive or negative dimension index.
fn:
A Tensor -> Tensor function to be mapped over the Rigid
Returns:
The unsqueezed transformation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
rots = self.rots.unsqueeze(dim if dim >= 0 else dim - 2)
trans = self.trans.unsqueeze(dim if dim >= 0 else dim - 1)
return T(rots, trans)
@staticmethod
def _identity_rot(
shape: Tuple[int],
dtype: torch.dtype,
device: torch.device,
requires_grad: bool,
) -> torch.Tensor:
rots = torch.eye(
3, dtype=dtype, device=device, requires_grad=requires_grad
The transformed Rigid object
"""
new_rots = self._rots.map_tensor_fn(fn)
new_trans = torch.stack(
list(map(fn, torch.unbind(self._trans, dim=-1))),
dim=-1
)
rots = rots.view(*((1,) * len(shape)), 3, 3)
rots = rots.expand(*shape, -1, -1)
return rots
@staticmethod
def _identity_trans(
shape: Tuple[int],
dtype: torch.dtype,
device: torch.device,
requires_grad: bool
) -> torch.Tensor:
trans = torch.zeros(
(*shape, 3), dtype=dtype, device=device, requires_grad=requires_grad
)
return trans
return Rigid(new_rots, new_trans)
@staticmethod
def identity(
shape: Tuple[int],
dtype: torch.dtype,
device: torch.device,
requires_grad: bool = True
) -> T:
def to_tensor_4x4(self) -> torch.Tensor:
"""
Constructs an identity transformation.
Converts a transformation to a homogenous transformation tensor.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns:
The identity transformation
A [*, 4, 4] homogenous transformation tensor
"""
return T(
T._identity_rot(shape, dtype, device, requires_grad),
T._identity_trans(shape, dtype, device, requires_grad),
)
tensor = self._trans.new_zeros((*self.shape, 4, 4))
tensor[..., :3, :3] = self._rots.get_rot_mats()
tensor[..., :3, 3] = self._trans
tensor[..., 3, 3] = 1
return tensor
@staticmethod
def from_4x4(
def from_tensor_4x4(
t: torch.Tensor
) -> T:
) -> Rigid:
"""
Constructs a transformation from a homogenous transformation
tensor.
......@@ -402,35 +1121,45 @@ class T:
Returns:
T object with shape [*]
"""
rots = t[..., :3, :3]
if(t.shape[-2:] != (4, 4)):
raise ValueError("Incorrectly shaped input tensor")
rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
trans = t[..., :3, 3]
return T(rots, trans)
return Rigid(rots, trans)
def to_4x4(self) -> torch.Tensor:
def to_tensor_7(self) -> torch.Tensor:
"""
Converts a transformation to a homogenous transformation tensor.
Converts a transformation to a tensor with 7 final columns, four
for the quaternion followed by three for the translation.
Returns:
A [*, 4, 4] homogenous transformation tensor
A [*, 7] tensor representation of the transformation
"""
tensor = self.rots.new_zeros((*self.shape, 4, 4))
tensor[..., :3, :3] = self.rots
tensor[..., :3, 3] = self.trans
tensor[..., 3, 3] = 1
tensor = self._trans.new_zeros((*self.shape, 7))
tensor[..., :4] = self._rots.get_quats()
tensor[..., 4:] = self._trans
return tensor
@staticmethod
def from_tensor(t: torch.Tensor) -> T:
"""
Constructs a transformation from a homogenous transformation
tensor.
def from_tensor_7(
t: torch.Tensor,
normalize_quats: bool = False,
) -> Rigid:
if(t.shape[-1] != 7):
raise ValueError("Incorrectly shaped input tensor")
quats, trans = t[..., :4], t[..., 4:]
rots = Rotation(
rot_mats=None,
quats=quats,
normalize_quats=normalize_quats
)
Args:
t: A [*, 4, 4] homogenous transformation tensor
Returns:
A transformation object with shape [*]
"""
return T.from_4x4(t)
return Rigid(rots, trans)
@staticmethod
def from_3_points(
......@@ -438,7 +1167,7 @@ class T:
origin: torch.Tensor,
p_xy_plane: torch.Tensor,
eps: float = 1e-8
) -> T:
) -> Rigid:
"""
Implements algorithm 21. Constructs transformations from sets of 3
points using the Gram-Schmidt algorithm.
......@@ -473,13 +1202,34 @@ class T:
rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
rots = rots.reshape(rots.shape[:-1] + (3, 3))
return T(rots, torch.stack(origin, dim=-1))
rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, torch.stack(origin, dim=-1))
def unsqueeze(self,
dim: int,
) -> Rigid:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shared dimensions of the rotation/translation.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed transformation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
rots = self._rots.unsqueeze(dim)
trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
return Rigid(rots, trans)
@staticmethod
def concat(
ts: Sequence[T],
def cat(
ts: Sequence[Rigid],
dim: int,
) -> T:
) -> Rigid:
"""
Concatenates transformations along a new dimension.
......@@ -492,57 +1242,60 @@ class T:
Returns:
A concatenated transformation object
"""
rots = torch.cat([t.rots for t in ts], dim=dim if dim >= 0 else dim - 2)
rots = Rotation.cat([t._rots for t in ts], dim)
trans = torch.cat(
[t.trans for t in ts], dim=dim if dim >= 0 else dim - 1
[t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
)
return T(rots, trans)
return Rigid(rots, trans)
def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> T:
def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid:
"""
Apply a function that takes a tensor as its only argument to the
rotations and translations, treating the final two/one
dimension(s), respectively, as batch dimensions.
E.g.: Given t, an instance of T of shape [N, M], this function can
be used to sum out the second dimension thereof as follows::
t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
The resulting object has rotations of shape [N, 3, 3] and
translations of shape [N, 3]
Applies a Rotation -> Rotation function to the stored rotation
object.
Args:
fn: A function that takes only a tensor as its argument
fn: A function of type Rotation -> Rotation
Returns:
The transformed transformation object.
A transformation object with a transformed rotation.
"""
rots = self.rots.view(*self.rots.shape[:-2], 9)
rots = torch.stack(list(map(fn, torch.unbind(rots, -1))), dim=-1)
rots = rots.view(*rots.shape[:-1], 3, 3)
return Rigid(fn(self._rots), self._trans)
trans = torch.stack(list(map(fn, torch.unbind(self.trans, -1))), dim=-1)
def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid:
"""
Applies a Tensor -> Tensor function to the stored translation.
return T(rots, trans)
Args:
fn:
A function of type Tensor -> Tensor to be applied to the
translation
Returns:
A transformation object with a transformed translation.
"""
return Rigid(self._rots, fn(self._trans))
def stop_rot_gradient(self) -> T:
def scale_translation(self, trans_scale_factor: float) -> Rigid:
"""
Detaches the contained rotation tensor.
Scales the translation by a constant factor.
Args:
trans_scale_factor:
The constant factor
Returns:
A version of the transformation with detached rotations
A transformation object with a scaled translation.
"""
return T(self.rots.detach(), self.trans)
fn = lambda t: t * trans_scale_factor
return self.apply_trans_fn(fn)
def scale_translation(self, factor: int) -> T:
def stop_rot_gradient(self) -> Rigid:
"""
Scales the contained translation tensor by a constant factor.
Detaches the underlying rotation object
Returns:
A version of the transformation with scaled translations
A transformation object with detached rotations
"""
return T(self.rots, self.trans * factor)
fn = lambda r: r.detach()
return self.apply_rot_fn(fn)
@staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
......@@ -613,87 +1366,15 @@ class T:
rots = rots.transpose(-1, -2)
translation = -1 * translation
return T(rots, translation)
rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, translation)
def cuda(self) -> T:
def cuda(self) -> Rigid:
"""
Moves the transformation object to GPU memory
Returns:
A version of the transformation on GPU
"""
return T(self.rots.cuda(), self.trans.cuda())
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
def _to_mat(pairs):
mat = np.zeros((4, 4))
for pair in pairs:
key, value = pair
ind = _qtr_ind_dict[key]
mat[ind // 4][ind % 4] = value
return mat
_qtr_mat = np.zeros((4, 4, 3, 3))
_qtr_mat[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_qtr_mat[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_qtr_mat[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_qtr_mat[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_qtr_mat[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_qtr_mat[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_qtr_mat[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_qtr_mat[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_qtr_mat[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
"""
Converts a quaternion to a rotation matrix.
Args:
quat: [*, 4] quaternions
Returns:
[*, 3, 3] rotation matrices
"""
# [*, 4, 4]
quat = quat[..., None] * quat[..., None, :]
# [4, 4, 3, 3]
mat = quat.new_tensor(_qtr_mat)
# [*, 4, 4, 3, 3]
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
quat = quat[..., None, None] * shaped_qtr_mat
# [*, 3, 3]
return torch.sum(quat, dim=(-3, -4))
def affine_vector_to_4x4(vector: torch.Tensor) -> torch.Tensor:
"""
Transforms a tensor whose final dimension has the form:
[*quaternion, *translation]
into a homogenous transformation tensor.
Args:
vector: [*, 7] input tensor
Returns:
[*, 4, 4] homogenous transformation tensor
"""
quats = vector[..., :4]
trans = vector[..., 4:]
four_by_four = vector.new_zeros((*vector.shape[:-1], 4, 4))
four_by_four[..., :3, :3] = quat_to_rot(quats)
four_by_four[..., :3, 3] = trans
four_by_four[..., 3, 3] = 1
return four_by_four
return Rigid(self._rots.cuda(), self._trans.cuda())
../../../../openfold/resources/stereo_chemical_props.txt
\ No newline at end of file
......@@ -23,8 +23,8 @@ from openfold.np.residue_constants import (
restype_atom14_mask,
restype_atom14_rigid_group_positions,
)
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
......@@ -187,7 +187,7 @@ class TestFeats(unittest.TestCase):
n = 5
rots = torch.rand((batch_size, n, 3, 3))
trans = torch.rand((batch_size, n, 3))
ts = T(rots, trans)
ts = Rigid(Rotation(rot_mats=rots), trans)
angles = torch.rand((batch_size, n, 7, 2))
......@@ -222,7 +222,9 @@ class TestFeats(unittest.TestCase):
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float())
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)
......@@ -250,7 +252,7 @@ class TestFeats(unittest.TestCase):
bottom_row[..., 3] = 1
transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)
transforms_repro = out.to_4x4().cpu()
transforms_repro = out.to_tensor_4x4().cpu()
self.assertTrue(
torch.max(torch.abs(transforms_gt - transforms_repro) < consts.eps)
......@@ -262,7 +264,7 @@ class TestFeats(unittest.TestCase):
rots = torch.rand((batch_size, n_res, 8, 3, 3))
trans = torch.rand((batch_size, n_res, 8, 3))
ts = T(rots, trans)
ts = Rigid(Rotation(rot_mats=rots), trans)
f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()
......@@ -293,7 +295,9 @@ class TestFeats(unittest.TestCase):
affines = random_affines_4x4((n_res, 8))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float())
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
out_gt = f.apply({}, None, aatype, rigids)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
......
......@@ -20,7 +20,10 @@ import unittest
import ml_collections as mlc
from openfold.data import data_transforms
from openfold.utils.affine_utils import T, affine_vector_to_4x4
from openfold.utils.rigid_utils import (
Rotation,
Rigid,
)
import openfold.utils.feats as feats
from openfold.utils.loss import (
torsion_angle_loss,
......@@ -55,6 +58,11 @@ if compare_utils.alphafold_is_installed():
import haiku as hk
def affine_vector_to_4x4(affine):
r = Rigid.from_tensor_7(affine)
return r.to_tensor_4x4()
class TestLoss(unittest.TestCase):
def test_run_torsion_angle_loss(self):
batch_size = consts.batch_size
......@@ -77,8 +85,8 @@ class TestLoss(unittest.TestCase):
rots_gt = torch.rand((batch_size, n_frames, 3, 3))
trans = torch.rand((batch_size, n_frames, 3))
trans_gt = torch.rand((batch_size, n_frames, 3))
t = T(rots, trans)
t_gt = T(rots_gt, trans_gt)
t = Rigid(Rotation(rot_mats=rots), trans)
t_gt = Rigid(Rotation(rot_mats=rots_gt), trans_gt)
frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float()
positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float()
length_scale = 10
......@@ -686,11 +694,11 @@ class TestLoss(unittest.TestCase):
batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray)
batch["backbone_affine_tensor"] = affine_vector_to_4x4(
batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
batch["backbone_affine_tensor"]
)
value["traj"] = affine_vector_to_4x4(value["traj"])
batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
out_repro = out_repro.cpu()
......@@ -807,6 +815,8 @@ class TestLoss(unittest.TestCase):
f = hk.transform(run_tm_loss)
np.random.seed(42)
n_res = consts.n_res
representations = {
......@@ -839,12 +849,10 @@ class TestLoss(unittest.TestCase):
batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray)
batch["backbone_affine_tensor"] = affine_vector_to_4x4(
batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
batch["backbone_affine_tensor"]
)
value["structure_module"]["final_affines"] = affine_vector_to_4x4(
value["structure_module"]["final_affines"]
)
batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
model = compare_utils.get_global_pretrained_openfold()
logits = model.aux_heads.tm(representations["pair"])
......
......@@ -130,4 +130,5 @@ class TestModel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1]
out_repro = out_repro.squeeze(0)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 1e-3))
print(torch.max(torch.abs(out_gt - out_repro)))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
......@@ -31,8 +31,8 @@ from openfold.model.structure_module import (
AngleResnet,
InvariantPointAttention,
)
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import (
......@@ -89,7 +89,7 @@ class TestStructureModule(unittest.TestCase):
out = sm(s, z, f)
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
)
......@@ -177,23 +177,6 @@ class TestStructureModule(unittest.TestCase):
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
class TestBackboneUpdate(unittest.TestCase):
def test_shape(self):
batch_size = 2
n_res = 3
c_in = 5
bu = BackboneUpdate(c_in)
s = torch.rand((batch_size, n_res, c_in))
t = bu(s)
rot, tra = t.rots, t.trans
self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
self.assertTrue(tra.shape == (batch_size, n_res, 3))
class TestInvariantPointAttention(unittest.TestCase):
def test_shape(self):
c_m = 13
......@@ -210,17 +193,18 @@ class TestInvariantPointAttention(unittest.TestCase):
z = torch.rand((batch_size, n_res, n_res, c_z))
mask = torch.ones((batch_size, n_res))
rots = torch.rand((batch_size, n_res, 3, 3))
rot_mats = torch.rand((batch_size, n_res, 3, 3))
rots = Rotation(rot_mats=rot_mats, quats=None)
trans = torch.rand((batch_size, n_res, 3))
t = T(rots, trans)
r = Rigid(rots, trans)
ipa = InvariantPointAttention(
c_m, c_z, c_hidden, no_heads, no_qp, no_vp
)
shape_before = s.shape
s = ipa(s, z, t, mask)
s = ipa(s, z, r, mask)
self.assertTrue(s.shape == shape_before)
......@@ -253,7 +237,9 @@ class TestInvariantPointAttention(unittest.TestCase):
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids)
transformations = T.from_4x4(torch.as_tensor(affines).float().cuda())
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats
......
......@@ -13,11 +13,24 @@
# limitations under the License.
import math
import numpy as np
import torch
import unittest
from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.rigid_utils import (
Rotation,
Rigid,
quat_to_rot,
rot_to_quat,
)
from openfold.utils.tensor_utils import chunk_layer, _chunk_slice
import tests.compare_utils as compare_utils
from tests.config import consts
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
X_90_ROT = torch.tensor(
......@@ -38,7 +51,7 @@ X_NEG_90_ROT = torch.tensor(
class TestUtils(unittest.TestCase):
def test_T_from_3_points_shape(self):
def test_rigid_from_3_points_shape(self):
batch_size = 2
n_res = 5
......@@ -46,14 +59,14 @@ class TestUtils(unittest.TestCase):
x2 = torch.rand((batch_size, n_res, 3))
x3 = torch.rand((batch_size, n_res, 3))
t = T.from_3_points(x1, x2, x3)
r = Rigid.from_3_points(x1, x2, x3)
rot, tra = t.rots, t.trans
rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
self.assertTrue(torch.all(tra == x2))
def test_T_from_4x4(self):
def test_rigid_from_4x4(self):
batch_size = 2
transf = [
[1, 0, 0, 1],
......@@ -68,58 +81,79 @@ class TestUtils(unittest.TestCase):
transf = torch.stack([transf for _ in range(batch_size)], dim=0)
t = T.from_4x4(transf)
r = Rigid.from_tensor_4x4(transf)
rot, tra = t.rots, t.trans
rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
self.assertTrue(torch.all(rot == true_rot.unsqueeze(0)))
self.assertTrue(torch.all(tra == true_trans.unsqueeze(0)))
def test_T_shape(self):
def test_rigid_shape(self):
batch_size = 2
n = 5
transf = T(
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3))
transf = Rigid(
Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
torch.rand((batch_size, n, 3))
)
self.assertTrue(transf.shape == (batch_size, n))
def test_T_concat(self):
def test_rigid_cat(self):
batch_size = 2
n = 5
transf = T(
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3))
transf = Rigid(
Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
torch.rand((batch_size, n, 3))
)
transf_concat = T.concat([transf, transf], dim=0)
transf_cat = Rigid.cat([transf, transf], dim=0)
self.assertTrue(transf_concat.rots.shape == (batch_size * 2, n, 3, 3))
transf_rots = transf.get_rots().get_rot_mats()
transf_cat_rots = transf_cat.get_rots().get_rot_mats()
transf_concat = T.concat([transf, transf], dim=1)
self.assertTrue(transf_cat_rots.shape == (batch_size * 2, n, 3, 3))
self.assertTrue(transf_concat.rots.shape == (batch_size, n * 2, 3, 3))
transf_cat = Rigid.cat([transf, transf], dim=1)
transf_cat_rots = transf_cat.get_rots().get_rot_mats()
self.assertTrue(torch.all(transf_concat.rots[:, :n] == transf.rots))
self.assertTrue(torch.all(transf_concat.trans[:, :n] == transf.trans))
self.assertTrue(transf_cat_rots.shape == (batch_size, n * 2, 3, 3))
def test_T_compose(self):
self.assertTrue(torch.all(transf_cat_rots[:, :n] == transf_rots))
self.assertTrue(
torch.all(transf_cat.get_trans()[:, :n] == transf.get_trans())
)
def test_rigid_compose(self):
trans_1 = [0, 1, 0]
trans_2 = [0, 0, 1]
t1 = T(X_90_ROT, torch.tensor(trans_1))
t2 = T(X_NEG_90_ROT, torch.tensor(trans_2))
r = Rotation(rot_mats=X_90_ROT)
t = torch.tensor(trans_1)
t1 = Rigid(
Rotation(rot_mats=X_90_ROT),
torch.tensor(trans_1)
)
t2 = Rigid(
Rotation(rot_mats=X_NEG_90_ROT),
torch.tensor(trans_2)
)
t3 = t1.compose(t2)
self.assertTrue(torch.all(t3.rots == torch.eye(3)))
self.assertTrue(torch.all(t3.trans == 0))
self.assertTrue(
torch.all(t3.get_rots().get_rot_mats() == torch.eye(3))
)
self.assertTrue(
torch.all(t3.get_trans() == 0)
)
def test_T_apply(self):
def test_rigid_apply(self):
rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0)
trans = torch.tensor([1, 1, 1])
trans = torch.stack([trans, trans], dim=0)
t = T(rots, trans)
t = Rigid(Rotation(rot_mats=rots), trans)
x = torch.arange(30)
x = torch.stack([x, x], dim=0)
......@@ -141,6 +175,12 @@ class TestUtils(unittest.TestCase):
eps = 1e-07
self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps))
def test_rot_to_quat(self):
quat = rot_to_quat(X_90_ROT)
eps = 1e-07
ans = torch.tensor([math.sqrt(0.5), math.sqrt(0.5), 0., 0.])
self.assertTrue(torch.all(torch.abs(quat - ans) < eps))
def test_chunk_layer_tensor(self):
x = torch.rand(2, 4, 5, 15)
l = torch.nn.Linear(15, 30)
......@@ -180,3 +220,33 @@ class TestUtils(unittest.TestCase):
chunked_flattened = x_flat[i:j]
self.assertTrue(torch.all(chunked == chunked_flattened))
@compare_utils.skip_unless_alphafold_installed()
def test_pre_compose_compare(self):
quat = np.random.rand(20, 4)
trans = [np.random.rand(20) for _ in range(3)]
quat_affine = alphafold.model.quat_affine.QuatAffine(
quat, translation=trans
)
update_vec = np.random.rand(20, 6)
new_gt = quat_affine.pre_compose(update_vec)
quat_t = torch.tensor(quat)
trans_t = torch.stack([torch.tensor(t) for t in trans], dim=-1)
rigid = Rigid(Rotation(quats=quat_t), trans_t)
new_repro = rigid.compose_q_update_vec(torch.tensor(update_vec))
new_gt_q = torch.tensor(np.array(new_gt.quaternion))
new_gt_t = torch.stack(
[torch.tensor(np.array(t)) for t in new_gt.translation], dim=-1
)
new_repro_q = new_repro.get_rots().get_quats()
new_repro_t = new_repro.get_trans()
self.assertTrue(
torch.max(torch.abs(new_gt_q - new_repro_q)) < consts.eps
)
self.assertTrue(
torch.max(torch.abs(new_gt_t - new_repro_t)) < consts.eps
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment