Commit e3daf724 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Overhaul transformation code for better parity w/ AlphaFold

parent 1f709b0d
...@@ -89,8 +89,8 @@ config = mlc.ConfigDict( ...@@ -89,8 +89,8 @@ config = mlc.ConfigDict(
"atom14_gt_exists": [NUM_RES, None], "atom14_gt_exists": [NUM_RES, None],
"atom14_gt_positions": [NUM_RES, None, None], "atom14_gt_positions": [NUM_RES, None, None],
"atom37_atom_exists": [NUM_RES, None], "atom37_atom_exists": [NUM_RES, None],
"backbone_affine_mask": [NUM_RES], "backbone_rigid_mask": [NUM_RES],
"backbone_affine_tensor": [NUM_RES, None, None], "backbone_rigid_tensor": [NUM_RES, None, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES], "bert_mask": [NUM_MSA_SEQ, NUM_RES],
"chi_angles_sin_cos": [NUM_RES, None, None], "chi_angles_sin_cos": [NUM_RES, None, None],
"chi_mask": [NUM_RES, None], "chi_mask": [NUM_RES, None],
...@@ -126,8 +126,8 @@ config = mlc.ConfigDict( ...@@ -126,8 +126,8 @@ config = mlc.ConfigDict(
"template_alt_torsion_angles_sin_cos": [ "template_alt_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None, NUM_TEMPLATES, NUM_RES, None, None,
], ],
"template_backbone_affine_mask": [NUM_TEMPLATES, NUM_RES], "template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
"template_backbone_affine_tensor": [ "template_backbone_rigid_tensor": [
NUM_TEMPLATES, NUM_RES, None, None, NUM_TEMPLATES, NUM_RES, None, None,
], ],
"template_mask": [NUM_TEMPLATES], "template_mask": [NUM_TEMPLATES],
......
...@@ -22,7 +22,7 @@ import torch ...@@ -22,7 +22,7 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
from openfold.utils.affine_utils import T from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -752,7 +752,7 @@ def make_atom14_positions(protein): ...@@ -752,7 +752,7 @@ def make_atom14_positions(protein):
return protein return protein
def atom37_to_frames(protein): def atom37_to_frames(protein, eps=1e-8):
aatype = protein["aatype"] aatype = protein["aatype"]
all_atom_positions = protein["all_atom_positions"] all_atom_positions = protein["all_atom_positions"]
all_atom_mask = protein["all_atom_mask"] all_atom_mask = protein["all_atom_mask"]
...@@ -810,11 +810,11 @@ def atom37_to_frames(protein): ...@@ -810,11 +810,11 @@ def atom37_to_frames(protein):
no_batch_dims=len(all_atom_positions.shape[:-2]), no_batch_dims=len(all_atom_positions.shape[:-2]),
) )
gt_frames = T.from_3_points( gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :], p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :], origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :], p_xy_plane=base_atom_pos[..., 2, :],
eps=1e-8, eps=eps,
) )
group_exists = batched_gather( group_exists = batched_gather(
...@@ -836,8 +836,9 @@ def atom37_to_frames(protein): ...@@ -836,8 +836,9 @@ def atom37_to_frames(protein):
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1)) rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1 rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1 rots[..., 0, 2, 2] = -1
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(T(rots, None)) gt_frames = gt_frames.compose(Rigid(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros( restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8 *((1,) * batch_dims), 21, 8
...@@ -871,10 +872,15 @@ def atom37_to_frames(protein): ...@@ -871,10 +872,15 @@ def atom37_to_frames(protein):
no_batch_dims=batch_dims, no_batch_dims=batch_dims,
) )
alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None)) residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot
)
alt_gt_frames = gt_frames.compose(
Rigid(residx_rigidgroup_ambiguity_rot, None)
)
gt_frames_tensor = gt_frames.to_4x4() gt_frames_tensor = gt_frames.to_tensor_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_4x4() alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
protein["rigidgroups_gt_frames"] = gt_frames_tensor protein["rigidgroups_gt_frames"] = gt_frames_tensor
protein["rigidgroups_gt_exists"] = gt_exists protein["rigidgroups_gt_exists"] = gt_exists
...@@ -1028,7 +1034,7 @@ def atom37_to_torsion_angles( ...@@ -1028,7 +1034,7 @@ def atom37_to_torsion_angles(
dim=-1, dim=-1,
) )
torsion_frames = T.from_3_points( torsion_frames = Rigid.from_3_points(
torsions_atom_pos[..., 1, :], torsions_atom_pos[..., 1, :],
torsions_atom_pos[..., 2, :], torsions_atom_pos[..., 2, :],
torsions_atom_pos[..., 0, :], torsions_atom_pos[..., 0, :],
...@@ -1082,11 +1088,11 @@ def atom37_to_torsion_angles( ...@@ -1082,11 +1088,11 @@ def atom37_to_torsion_angles(
def get_backbone_frames(protein): def get_backbone_frames(protein):
# TODO: Verify that this is correct # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
protein["backbone_affine_tensor"] = protein["rigidgroups_gt_frames"][ protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
..., 0, :, : ..., 0, :, :
] ]
protein["backbone_affine_mask"] = protein["rigidgroups_gt_exists"][..., 0] protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]
return protein return protein
......
...@@ -430,7 +430,9 @@ def _is_set(data: str) -> bool: ...@@ -430,7 +430,9 @@ def _is_set(data: str) -> bool:
def get_atom_coords( def get_atom_coords(
mmcif_object: MmcifObject, chain_id: str, zero_center: bool = True mmcif_object: MmcifObject,
chain_id: str,
_zero_center_positions: bool = True
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain # Locate the right chain
chains = list(mmcif_object.structure.get_chains()) chains = list(mmcif_object.structure.get_chains())
...@@ -475,7 +477,7 @@ def get_atom_coords( ...@@ -475,7 +477,7 @@ def get_atom_coords(
all_atom_positions[res_index] = pos all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask all_atom_mask[res_index] = mask
if zero_center: if _zero_center_positions:
binary_mask = all_atom_mask.astype(bool) binary_mask = all_atom_mask.astype(bool)
translation_vec = all_atom_positions[binary_mask].mean(axis=0) translation_vec = all_atom_positions[binary_mask].mean(axis=0)
all_atom_positions[binary_mask] -= translation_vec all_atom_positions[binary_mask] -= translation_vec
......
...@@ -503,10 +503,13 @@ def _get_atom_positions( ...@@ -503,10 +503,13 @@ def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject, mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str, auth_chain_id: str,
max_ca_ca_distance: float, max_ca_ca_distance: float,
_zero_center_positions: bool = True,
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues.""" """Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords( coords_with_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=auth_chain_id mmcif_object=mmcif_object,
chain_id=auth_chain_id,
_zero_center_positions=_zero_center_positions,
) )
all_atom_positions, all_atom_mask = coords_with_mask all_atom_positions, all_atom_mask = coords_with_mask
_check_residue_distances( _check_residue_distances(
...@@ -523,6 +526,7 @@ def _extract_template_features( ...@@ -523,6 +526,7 @@ def _extract_template_features(
query_sequence: str, query_sequence: str,
template_chain_id: str, template_chain_id: str,
kalign_binary_path: str, kalign_binary_path: str,
_zero_center_positions: bool = True,
) -> Tuple[Dict[str, Any], Optional[str]]: ) -> Tuple[Dict[str, Any], Optional[str]]:
"""Parses atom positions in the target structure and aligns with the query. """Parses atom positions in the target structure and aligns with the query.
...@@ -607,7 +611,10 @@ def _extract_template_features( ...@@ -607,7 +611,10 @@ def _extract_template_features(
# Essentially set to infinity - we don't want to reject templates unless # Essentially set to infinity - we don't want to reject templates unless
# they're really really bad. # they're really really bad.
all_atom_positions, all_atom_mask = _get_atom_positions( all_atom_positions, all_atom_mask = _get_atom_positions(
mmcif_object, chain_id, max_ca_ca_distance=150.0 mmcif_object,
chain_id,
max_ca_ca_distance=150.0,
_zero_center_positions=_zero_center_positions,
) )
except (CaDistanceError, KeyError) as ex: except (CaDistanceError, KeyError) as ex:
raise NoAtomDataInTemplateError( raise NoAtomDataInTemplateError(
...@@ -795,6 +802,7 @@ def _process_single_hit( ...@@ -795,6 +802,7 @@ def _process_single_hit(
obsolete_pdbs: Mapping[str, str], obsolete_pdbs: Mapping[str, str],
kalign_binary_path: str, kalign_binary_path: str,
strict_error_check: bool = False, strict_error_check: bool = False,
_zero_center_positions: bool = True,
) -> SingleHitResult: ) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit.""" """Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit. # Fail hard if we can't get the PDB ID and chain name from the hit.
...@@ -856,6 +864,7 @@ def _process_single_hit( ...@@ -856,6 +864,7 @@ def _process_single_hit(
query_sequence=query_sequence, query_sequence=query_sequence,
template_chain_id=hit_chain_id, template_chain_id=hit_chain_id,
kalign_binary_path=kalign_binary_path, kalign_binary_path=kalign_binary_path,
_zero_center_positions=_zero_center_positions,
) )
features["template_sum_probs"] = [hit.sum_probs] features["template_sum_probs"] = [hit.sum_probs]
...@@ -913,7 +922,6 @@ class TemplateSearchResult: ...@@ -913,7 +922,6 @@ class TemplateSearchResult:
class TemplateHitFeaturizer: class TemplateHitFeaturizer:
"""A class for turning hhr hits to template features.""" """A class for turning hhr hits to template features."""
def __init__( def __init__(
self, self,
mmcif_dir: str, mmcif_dir: str,
...@@ -924,6 +932,7 @@ class TemplateHitFeaturizer: ...@@ -924,6 +932,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs_path: Optional[str] = None, obsolete_pdbs_path: Optional[str] = None,
strict_error_check: bool = False, strict_error_check: bool = False,
_shuffle_top_k_prefiltered: Optional[int] = None, _shuffle_top_k_prefiltered: Optional[int] = None,
_zero_center_positions: bool = True,
): ):
"""Initializes the Template Search. """Initializes the Template Search.
...@@ -982,6 +991,7 @@ class TemplateHitFeaturizer: ...@@ -982,6 +991,7 @@ class TemplateHitFeaturizer:
self._obsolete_pdbs = {} self._obsolete_pdbs = {}
self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
self._zero_center_positions = _zero_center_positions
def get_templates( def get_templates(
self, self,
...@@ -1057,6 +1067,7 @@ class TemplateHitFeaturizer: ...@@ -1057,6 +1067,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs=self._obsolete_pdbs, obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check, strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path, kalign_binary_path=self._kalign_binary_path,
_zero_center_positions=self._zero_center_positions,
) )
if result.error: if result.error:
......
...@@ -198,6 +198,7 @@ class RecyclingEmbedder(nn.Module): ...@@ -198,6 +198,7 @@ class RecyclingEmbedder(nn.Module):
self.no_bins, self.no_bins,
dtype=x.dtype, dtype=x.dtype,
device=x.device, device=x.device,
requires_grad=False,
) )
# [*, N, C_m] # [*, N, C_m]
......
...@@ -25,11 +25,11 @@ from openfold.np.residue_constants import ( ...@@ -25,11 +25,11 @@ from openfold.np.residue_constants import (
restype_atom14_mask, restype_atom14_mask,
restype_atom14_rigid_group_positions, restype_atom14_rigid_group_positions,
) )
from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.feats import ( from openfold.utils.feats import (
frames_and_literature_positions_to_atom14_pos, frames_and_literature_positions_to_atom14_pos,
torsion_angles_to_frames, torsion_angles_to_frames,
) )
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
dict_multimap, dict_multimap,
permute_final_dims, permute_final_dims,
...@@ -225,7 +225,7 @@ class InvariantPointAttention(nn.Module): ...@@ -225,7 +225,7 @@ class InvariantPointAttention(nn.Module):
self, self,
s: torch.Tensor, s: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
t: T, r: Rigid,
mask: torch.Tensor, mask: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -234,8 +234,8 @@ class InvariantPointAttention(nn.Module): ...@@ -234,8 +234,8 @@ class InvariantPointAttention(nn.Module):
[*, N_res, C_s] single representation [*, N_res, C_s] single representation
z: z:
[*, N_res, N_res, C_z] pair representation [*, N_res, N_res, C_z] pair representation
t: r:
[*, N_res] affine transformation object [*, N_res] transformation object
mask: mask:
[*, N_res] mask [*, N_res] mask
Returns: Returns:
...@@ -264,7 +264,7 @@ class InvariantPointAttention(nn.Module): ...@@ -264,7 +264,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_q, 3] # [*, N_res, H * P_q, 3]
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
q_pts = torch.stack(q_pts, dim=-1) q_pts = torch.stack(q_pts, dim=-1)
q_pts = t[..., None].apply(q_pts) q_pts = r[..., None].apply(q_pts)
# [*, N_res, H, P_q, 3] # [*, N_res, H, P_q, 3]
q_pts = q_pts.view( q_pts = q_pts.view(
...@@ -277,7 +277,7 @@ class InvariantPointAttention(nn.Module): ...@@ -277,7 +277,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * (P_q + P_v), 3] # [*, N_res, H * (P_q + P_v), 3]
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1) kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = t[..., None].apply(kv_pts) kv_pts = r[..., None].apply(kv_pts)
# [*, N_res, H, (P_q + P_v), 3] # [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3)) kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
...@@ -349,7 +349,7 @@ class InvariantPointAttention(nn.Module): ...@@ -349,7 +349,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, P_v, 3] # [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = t[..., None, None].invert_apply(o_pt) o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v] # [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims( o_pt_norm = flatten_final_dims(
...@@ -377,7 +377,7 @@ class InvariantPointAttention(nn.Module): ...@@ -377,7 +377,7 @@ class InvariantPointAttention(nn.Module):
class BackboneUpdate(nn.Module): class BackboneUpdate(nn.Module):
""" """
Implements Algorithm 23. Implements part of Algorithm 23.
""" """
def __init__(self, c_s): def __init__(self, c_s):
...@@ -392,36 +392,17 @@ class BackboneUpdate(nn.Module): ...@@ -392,36 +392,17 @@ class BackboneUpdate(nn.Module):
self.linear = Linear(self.c_s, 6, init="final") self.linear = Linear(self.c_s, 6, init="final")
def forward(self, s): def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
[*, N_res, C_s] single representation [*, N_res, C_s] single representation
Returns: Returns:
[*, N_res] affine transformation object [*, N_res, 6] update vector
""" """
# [*, 6] # [*, 6]
params = self.linear(s) update = self.linear(s)
# [*, 3]
quats, trans = params[..., :3], params[..., 3:]
# [*]
# norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
norm_denom = torch.sqrt(torch.sum(quats ** 2, dim=-1) + 1)
# [*, 3]
ones = s.new_ones((1,) * len(quats.shape)).expand(
quats.shape[:-1] + (1,)
)
# [*, 4]
quats = torch.cat([ones, quats], dim=-1)
quats = quats / norm_denom[..., None]
# [*, 3, 3] return update
rots = quat_to_rot(quats)
return T(rots, trans)
class StructureModuleTransitionLayer(nn.Module): class StructureModuleTransitionLayer(nn.Module):
...@@ -592,7 +573,7 @@ class StructureModule(nn.Module): ...@@ -592,7 +573,7 @@ class StructureModule(nn.Module):
self, self,
s, s,
z, z,
f, aatype,
mask=None, mask=None,
): ):
""" """
...@@ -601,7 +582,7 @@ class StructureModule(nn.Module): ...@@ -601,7 +582,7 @@ class StructureModule(nn.Module):
[*, N_res, C_s] single representation [*, N_res, C_s] single representation
z: z:
[*, N_res, N_res, C_z] pair representation [*, N_res, N_res, C_z] pair representation
f: aatype:
[*, N_res] amino acid indices [*, N_res] amino acid indices
mask: mask:
Optional [*, N_res] sequence mask Optional [*, N_res] sequence mask
...@@ -623,44 +604,67 @@ class StructureModule(nn.Module): ...@@ -623,44 +604,67 @@ class StructureModule(nn.Module):
s = self.linear_in(s) s = self.linear_in(s)
# [*, N] # [*, N]
t = T.identity(s.shape[:-1], s.dtype, s.device, self.training) rigids = Rigid.identity(
s.shape[:-1],
s.dtype,
s.device,
self.training,
fmt="quat",
)
outputs = [] outputs = []
for i in range(self.no_blocks): for i in range(self.no_blocks):
# [*, N, C_s] # [*, N, C_s]
s = s + self.ipa(s, z, t, mask) s = s + self.ipa(s, z, rigids, mask)
s = self.ipa_dropout(s) s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s) s = self.layer_norm_ipa(s)
s = self.transition(s) s = self.transition(s)
# [*, N] # [*, N]
t = t.compose(self.bb_update(s)) rigids = rigids.compose_q_update_vec(self.bb_update(s))
# To hew as closely as possible to AlphaFold, we convert our
# quaternion-based transformations to rotation-matrix ones
# here
backb_to_global = Rigid(
Rotation(
rot_mats=rigids.get_rots().get_rot_mats(),
quats=None
),
rigids.get_trans(),
)
backb_to_global = backb_to_global.scale_translation(
self.trans_scale_factor
)
# [*, N, 7, 2] # [*, N, 7, 2]
unnormalized_a, a = self.angle_resnet(s, s_initial) unnormalized_angles, angles = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames( all_frames_to_global = self.torsion_angles_to_frames(
t.scale_translation(self.trans_scale_factor), backb_to_global,
a, angles,
f, aatype,
) )
pred_xyz = self.frames_and_literature_positions_to_atom14_pos( pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
all_frames_to_global, all_frames_to_global,
f, aatype,
) )
scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
preds = { preds = {
"frames": t.scale_translation(self.trans_scale_factor).to_4x4(), "frames": scaled_rigids.to_tensor_7(),
"sidechain_frames": all_frames_to_global.to_4x4(), "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
"unnormalized_angles": unnormalized_a, "unnormalized_angles": unnormalized_angles,
"angles": a, "angles": angles,
"positions": pred_xyz, "positions": pred_xyz,
} }
outputs.append(preds) outputs.append(preds)
if i < (self.no_blocks - 1): if i < (self.no_blocks - 1):
t = t.stop_rot_gradient() rigids = rigids.stop_rot_gradient()
outputs = dict_multimap(torch.stack, outputs) outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s outputs["single"] = s
...@@ -673,38 +677,42 @@ class StructureModule(nn.Module): ...@@ -673,38 +677,42 @@ class StructureModule(nn.Module):
restype_rigid_group_default_frame, restype_rigid_group_default_frame,
dtype=float_dtype, dtype=float_dtype,
device=device, device=device,
requires_grad=False,
) )
if self.group_idx is None: if self.group_idx is None:
self.group_idx = torch.tensor( self.group_idx = torch.tensor(
restype_atom14_to_rigid_group, restype_atom14_to_rigid_group,
device=device, device=device,
requires_grad=False,
) )
if self.atom_mask is None: if self.atom_mask is None:
self.atom_mask = torch.tensor( self.atom_mask = torch.tensor(
restype_atom14_mask, restype_atom14_mask,
dtype=float_dtype, dtype=float_dtype,
device=device, device=device,
requires_grad=False,
) )
if self.lit_positions is None: if self.lit_positions is None:
self.lit_positions = torch.tensor( self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions, restype_atom14_rigid_group_positions,
dtype=float_dtype, dtype=float_dtype,
device=device, device=device,
requires_grad=False,
) )
def torsion_angles_to_frames(self, t, alpha, f): def torsion_angles_to_frames(self, r, alpha, f):
# Lazily initialize the residue constants on the correct device # Lazily initialize the residue constants on the correct device
self._init_residue_constants(alpha.dtype, alpha.device) self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying # Separated purely to make testing less annoying
return torsion_angles_to_frames(t, alpha, f, self.default_frames) return torsion_angles_to_frames(r, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos( def frames_and_literature_positions_to_atom14_pos(
self, t, f # [*, N, 8] # [*, N] self, r, f # [*, N, 8] # [*, N]
): ):
# Lazily initialize the residue constants on the correct device # Lazily initialize the residue constants on the correct device
self._init_residue_constants(t.rots.dtype, t.rots.device) self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
return frames_and_literature_positions_to_atom14_pos( return frames_and_literature_positions_to_atom14_pos(
t, r,
f, f,
self.default_frames, self.default_frames,
self.group_idx, self.group_idx,
......
...@@ -22,7 +22,7 @@ from typing import Dict ...@@ -22,7 +22,7 @@ from typing import Dict
from openfold.np import protein from openfold.np import protein
import openfold.np.residue_constants as rc import openfold.np.residue_constants as rc
from openfold.utils.affine_utils import T from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
batched_gather, batched_gather,
one_hot, one_hot,
...@@ -124,18 +124,16 @@ def build_template_pair_feat( ...@@ -124,18 +124,16 @@ def build_template_pair_feat(
) )
n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]] n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
# TODO: Consider running this in double precision rigids = Rigid.make_transform_from_reference(
affines = T.make_transform_from_reference(
n_xyz=batch["template_all_atom_positions"][..., n, :], n_xyz=batch["template_all_atom_positions"][..., n, :],
ca_xyz=batch["template_all_atom_positions"][..., ca, :], ca_xyz=batch["template_all_atom_positions"][..., ca, :],
c_xyz=batch["template_all_atom_positions"][..., c, :], c_xyz=batch["template_all_atom_positions"][..., c, :],
eps=eps, eps=eps,
) )
points = rigids.get_trans()[..., None, :, :]
rigid_vec = rigids[..., None].invert_apply(points)
points = affines.get_trans()[..., None, :, :] inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))
affine_vec = affines[..., None].invert_apply(points)
inv_distance_scalar = torch.rsqrt(eps + torch.sum(affine_vec ** 2, dim=-1))
t_aa_masks = batch["template_all_atom_mask"] t_aa_masks = batch["template_all_atom_mask"]
template_mask = ( template_mask = (
...@@ -144,7 +142,7 @@ def build_template_pair_feat( ...@@ -144,7 +142,7 @@ def build_template_pair_feat(
template_mask_2d = template_mask[..., None] * template_mask[..., None, :] template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
inv_distance_scalar = inv_distance_scalar * template_mask_2d inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = affine_vec * inv_distance_scalar[..., None] unit_vector = rigid_vec * inv_distance_scalar[..., None]
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1)) to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None]) to_concat.append(template_mask_2d[..., None])
...@@ -165,7 +163,7 @@ def build_extra_msa_feat(batch): ...@@ -165,7 +163,7 @@ def build_extra_msa_feat(batch):
def torsion_angles_to_frames( def torsion_angles_to_frames(
t: T, r: Rigid,
alpha: torch.Tensor, alpha: torch.Tensor,
aatype: torch.Tensor, aatype: torch.Tensor,
rrgdf: torch.Tensor, rrgdf: torch.Tensor,
...@@ -176,13 +174,15 @@ def torsion_angles_to_frames( ...@@ -176,13 +174,15 @@ def torsion_angles_to_frames(
# [*, N, 8] transformations, i.e. # [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and # One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix # One [*, N, 8, 3] translation matrix
default_t = T.from_4x4(default_4x4) default_r = r.from_tensor_4x4(default_4x4)
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1 bb_rot[..., 1] = 1
# [*, N, 8, 2] # [*, N, 8, 2]
alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2) alpha = torch.cat(
[bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
)
# [*, N, 8, 3, 3] # [*, N, 8, 3, 3]
# Produces rotation matrices of the form: # Produces rotation matrices of the form:
...@@ -194,15 +194,15 @@ def torsion_angles_to_frames( ...@@ -194,15 +194,15 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses # This follows the original code rather than the supplement, which uses
# different indices. # different indices.
all_rots = alpha.new_zeros(default_t.rots.shape) all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
all_rots[..., 0, 0] = 1 all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1] all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0] all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha all_rots[..., 2, 1:] = alpha
all_rots = T(all_rots, None) all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_t.compose(all_rots) all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5] chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6] chi3_frame_to_frame = all_frames[..., 6]
...@@ -213,7 +213,7 @@ def torsion_angles_to_frames( ...@@ -213,7 +213,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = T.concat( all_frames_to_bb = Rigid.cat(
[ [
all_frames[..., :5], all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1), chi2_frame_to_bb.unsqueeze(-1),
...@@ -223,13 +223,13 @@ def torsion_angles_to_frames( ...@@ -223,13 +223,13 @@ def torsion_angles_to_frames(
dim=-1, dim=-1,
) )
all_frames_to_global = t[..., None].compose(all_frames_to_bb) all_frames_to_global = r[..., None].compose(all_frames_to_bb)
return all_frames_to_global return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos( def frames_and_literature_positions_to_atom14_pos(
t: T, r: Rigid,
aatype: torch.Tensor, aatype: torch.Tensor,
default_frames, default_frames,
group_idx, group_idx,
...@@ -249,7 +249,7 @@ def frames_and_literature_positions_to_atom14_pos( ...@@ -249,7 +249,7 @@ def frames_and_literature_positions_to_atom14_pos(
) )
# [*, N, 14, 8] # [*, N, 14, 8]
t_atoms_to_global = t[..., None, :] * group_mask t_atoms_to_global = r[..., None, :] * group_mask
# [*, N, 14] # [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn( t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
......
...@@ -24,7 +24,7 @@ from typing import Dict, Optional, Tuple ...@@ -24,7 +24,7 @@ from typing import Dict, Optional, Tuple
from openfold.np import residue_constants from openfold.np import residue_constants
from openfold.utils import feats from openfold.utils import feats
from openfold.utils.affine_utils import T from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -74,8 +74,8 @@ def torsion_angle_loss( ...@@ -74,8 +74,8 @@ def torsion_angle_loss(
def compute_fape( def compute_fape(
pred_frames: T, pred_frames: Rigid,
target_frames: T, target_frames: Rigid,
frames_mask: torch.Tensor, frames_mask: torch.Tensor,
pred_positions: torch.Tensor, pred_positions: torch.Tensor,
target_positions: torch.Tensor, target_positions: torch.Tensor,
...@@ -111,7 +111,7 @@ def compute_fape( ...@@ -111,7 +111,7 @@ def compute_fape(
# ) # )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor) # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
# #
# ("roughly" because eps is necessarily duplicated in the latter # ("roughly" because eps is necessarily duplicated in the latter)
normed_error = torch.sum(normed_error, dim=-1) normed_error = torch.sum(normed_error, dim=-1)
normed_error = ( normed_error = (
normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None] normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
...@@ -123,8 +123,8 @@ def compute_fape( ...@@ -123,8 +123,8 @@ def compute_fape(
def backbone_loss( def backbone_loss(
backbone_affine_tensor: torch.Tensor, backbone_rigid_tensor: torch.Tensor,
backbone_affine_mask: torch.Tensor, backbone_rigid_mask: torch.Tensor,
traj: torch.Tensor, traj: torch.Tensor,
use_clamped_fape: Optional[torch.Tensor] = None, use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.0, clamp_distance: float = 10.0,
...@@ -132,16 +132,27 @@ def backbone_loss( ...@@ -132,16 +132,27 @@ def backbone_loss(
eps: float = 1e-4, eps: float = 1e-4,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
pred_aff = T.from_tensor(traj) pred_aff = Rigid.from_tensor_7(traj)
gt_aff = T.from_tensor(backbone_affine_tensor) pred_aff = Rigid(
Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
pred_aff.get_trans(),
)
# DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
# backbone tensor, normalizes it, and then turns it back to a rotation
# matrix. To avoid a potentially numerically unstable rotation matrix
# to quaternion conversion, we just use the original rotation matrix
# outright. This one hasn't been composed a bunch of times, though, so
# it might be fine.
gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)
fape_loss = compute_fape( fape_loss = compute_fape(
pred_aff, pred_aff,
gt_aff[None], gt_aff[None],
backbone_affine_mask[None], backbone_rigid_mask[None],
pred_aff.get_trans(), pred_aff.get_trans(),
gt_aff[None].get_trans(), gt_aff[None].get_trans(),
backbone_affine_mask[None], backbone_rigid_mask[None],
l1_clamp_distance=clamp_distance, l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance, length_scale=loss_unit_distance,
eps=eps, eps=eps,
...@@ -150,10 +161,10 @@ def backbone_loss( ...@@ -150,10 +161,10 @@ def backbone_loss(
unclamped_fape_loss = compute_fape( unclamped_fape_loss = compute_fape(
pred_aff, pred_aff,
gt_aff[None], gt_aff[None],
backbone_affine_mask[None], backbone_rigid_mask[None],
pred_aff.get_trans(), pred_aff.get_trans(),
gt_aff[None].get_trans(), gt_aff[None].get_trans(),
backbone_affine_mask[None], backbone_rigid_mask[None],
l1_clamp_distance=None, l1_clamp_distance=None,
length_scale=loss_unit_distance, length_scale=loss_unit_distance,
eps=eps, eps=eps,
...@@ -193,9 +204,9 @@ def sidechain_loss( ...@@ -193,9 +204,9 @@ def sidechain_loss(
sidechain_frames = sidechain_frames[-1] sidechain_frames = sidechain_frames[-1]
batch_dims = sidechain_frames.shape[:-4] batch_dims = sidechain_frames.shape[:-4]
sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4) sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
sidechain_frames = T.from_4x4(sidechain_frames) sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4) renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
renamed_gt_frames = T.from_4x4(renamed_gt_frames) renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1) rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
sidechain_atom_pos = sidechain_atom_pos[-1] sidechain_atom_pos = sidechain_atom_pos[-1]
sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3) sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
...@@ -422,7 +433,7 @@ def distogram_loss( ...@@ -422,7 +433,7 @@ def distogram_loss(
device=logits.device, device=logits.device,
) )
boundaries = boundaries ** 2 boundaries = boundaries ** 2
dists = torch.sum( dists = torch.sum(
(pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2, (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
dim=-1, dim=-1,
...@@ -550,8 +561,8 @@ def compute_tm( ...@@ -550,8 +561,8 @@ def compute_tm(
def tm_loss( def tm_loss(
logits, logits,
final_affine_tensor, final_affine_tensor,
backbone_affine_tensor, backbone_rigid_tensor,
backbone_affine_mask, backbone_rigid_mask,
resolution, resolution,
max_bin=31, max_bin=31,
no_bins=64, no_bins=64,
...@@ -560,16 +571,17 @@ def tm_loss( ...@@ -560,16 +571,17 @@ def tm_loss(
eps=1e-8, eps=1e-8,
**kwargs, **kwargs,
): ):
pred_affine = T.from_4x4(final_affine_tensor) pred_affine = Rigid.from_tensor_7(final_affine_tensor)
backbone_affine = T.from_4x4(backbone_affine_tensor) backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
def _points(affine): def _points(affine):
pts = affine.get_trans()[..., None, :, :] pts = affine.get_trans()[..., None, :, :]
return affine.invert()[..., None].apply(pts) return affine.invert()[..., None].apply(pts)
sq_diff = torch.sum( sq_diff = torch.sum(
(_points(pred_affine) - _points(backbone_affine)) ** 2, dim=-1 (_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
) )
sq_diff = sq_diff.detach() sq_diff = sq_diff.detach()
boundaries = torch.linspace( boundaries = torch.linspace(
...@@ -583,7 +595,7 @@ def tm_loss( ...@@ -583,7 +595,7 @@ def tm_loss(
) )
square_mask = ( square_mask = (
backbone_affine_mask[..., None] * backbone_affine_mask[..., None, :] backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
) )
loss = torch.sum(errors * square_mask, dim=-1) loss = torch.sum(errors * square_mask, dim=-1)
...@@ -1503,11 +1515,12 @@ class AlphaFoldLoss(nn.Module): ...@@ -1503,11 +1515,12 @@ class AlphaFoldLoss(nn.Module):
), ),
} }
cum_loss = 0 cum_loss = 0.
for loss_name, loss_fn in loss_fns.items(): for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight weight = self.config[loss_name].weight
if weight: if weight:
loss = loss_fn() loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)): if(torch.isnan(loss) or torch.isinf(loss)):
logging.warning(f"{loss_name} loss is NaN. Skipping...") logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True) loss = loss.new_tensor(0., requires_grad=True)
......
../../../../openfold/resources/stereo_chemical_props.txt
\ No newline at end of file
...@@ -23,8 +23,8 @@ from openfold.np.residue_constants import ( ...@@ -23,8 +23,8 @@ from openfold.np.residue_constants import (
restype_atom14_mask, restype_atom14_mask,
restype_atom14_rigid_group_positions, restype_atom14_rigid_group_positions,
) )
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -187,7 +187,7 @@ class TestFeats(unittest.TestCase): ...@@ -187,7 +187,7 @@ class TestFeats(unittest.TestCase):
n = 5 n = 5
rots = torch.rand((batch_size, n, 3, 3)) rots = torch.rand((batch_size, n, 3, 3))
trans = torch.rand((batch_size, n, 3)) trans = torch.rand((batch_size, n, 3))
ts = T(rots, trans) ts = Rigid(Rotation(rot_mats=rots), trans)
angles = torch.rand((batch_size, n, 7, 2)) angles = torch.rand((batch_size, n, 7, 2))
...@@ -222,7 +222,9 @@ class TestFeats(unittest.TestCase): ...@@ -222,7 +222,9 @@ class TestFeats(unittest.TestCase):
affines = random_affines_4x4((n_res,)) affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines) rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float()) transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
torsion_angles_sin_cos = np.random.rand(n_res, 7, 2) torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)
...@@ -250,7 +252,7 @@ class TestFeats(unittest.TestCase): ...@@ -250,7 +252,7 @@ class TestFeats(unittest.TestCase):
bottom_row[..., 3] = 1 bottom_row[..., 3] = 1
transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2) transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)
transforms_repro = out.to_4x4().cpu() transforms_repro = out.to_tensor_4x4().cpu()
self.assertTrue( self.assertTrue(
torch.max(torch.abs(transforms_gt - transforms_repro) < consts.eps) torch.max(torch.abs(transforms_gt - transforms_repro) < consts.eps)
...@@ -262,7 +264,7 @@ class TestFeats(unittest.TestCase): ...@@ -262,7 +264,7 @@ class TestFeats(unittest.TestCase):
rots = torch.rand((batch_size, n_res, 8, 3, 3)) rots = torch.rand((batch_size, n_res, 8, 3, 3))
trans = torch.rand((batch_size, n_res, 8, 3)) trans = torch.rand((batch_size, n_res, 8, 3))
ts = T(rots, trans) ts = Rigid(Rotation(rot_mats=rots), trans)
f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long() f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()
...@@ -293,7 +295,9 @@ class TestFeats(unittest.TestCase): ...@@ -293,7 +295,9 @@ class TestFeats(unittest.TestCase):
affines = random_affines_4x4((n_res, 8)) affines = random_affines_4x4((n_res, 8))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines) rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float()) transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
out_gt = f.apply({}, None, aatype, rigids) out_gt = f.apply({}, None, aatype, rigids)
jax.tree_map(lambda x: x.block_until_ready(), out_gt) jax.tree_map(lambda x: x.block_until_ready(), out_gt)
......
...@@ -20,7 +20,10 @@ import unittest ...@@ -20,7 +20,10 @@ import unittest
import ml_collections as mlc import ml_collections as mlc
from openfold.data import data_transforms from openfold.data import data_transforms
from openfold.utils.affine_utils import T, affine_vector_to_4x4 from openfold.utils.rigid_utils import (
Rotation,
Rigid,
)
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.loss import ( from openfold.utils.loss import (
torsion_angle_loss, torsion_angle_loss,
...@@ -55,6 +58,11 @@ if compare_utils.alphafold_is_installed(): ...@@ -55,6 +58,11 @@ if compare_utils.alphafold_is_installed():
import haiku as hk import haiku as hk
def affine_vector_to_4x4(affine):
r = Rigid.from_tensor_7(affine)
return r.to_tensor_4x4()
class TestLoss(unittest.TestCase): class TestLoss(unittest.TestCase):
def test_run_torsion_angle_loss(self): def test_run_torsion_angle_loss(self):
batch_size = consts.batch_size batch_size = consts.batch_size
...@@ -77,8 +85,8 @@ class TestLoss(unittest.TestCase): ...@@ -77,8 +85,8 @@ class TestLoss(unittest.TestCase):
rots_gt = torch.rand((batch_size, n_frames, 3, 3)) rots_gt = torch.rand((batch_size, n_frames, 3, 3))
trans = torch.rand((batch_size, n_frames, 3)) trans = torch.rand((batch_size, n_frames, 3))
trans_gt = torch.rand((batch_size, n_frames, 3)) trans_gt = torch.rand((batch_size, n_frames, 3))
t = T(rots, trans) t = Rigid(Rotation(rot_mats=rots), trans)
t_gt = T(rots_gt, trans_gt) t_gt = Rigid(Rotation(rot_mats=rots_gt), trans_gt)
frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float() frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float()
positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float() positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float()
length_scale = 10 length_scale = 10
...@@ -686,11 +694,11 @@ class TestLoss(unittest.TestCase): ...@@ -686,11 +694,11 @@ class TestLoss(unittest.TestCase):
batch = tree_map(to_tensor, batch, np.ndarray) batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray) value = tree_map(to_tensor, value, np.ndarray)
batch["backbone_affine_tensor"] = affine_vector_to_4x4( batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
batch["backbone_affine_tensor"] batch["backbone_affine_tensor"]
) )
value["traj"] = affine_vector_to_4x4(value["traj"]) batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm}) out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
out_repro = out_repro.cpu() out_repro = out_repro.cpu()
...@@ -807,6 +815,8 @@ class TestLoss(unittest.TestCase): ...@@ -807,6 +815,8 @@ class TestLoss(unittest.TestCase):
f = hk.transform(run_tm_loss) f = hk.transform(run_tm_loss)
np.random.seed(42)
n_res = consts.n_res n_res = consts.n_res
representations = { representations = {
...@@ -839,12 +849,10 @@ class TestLoss(unittest.TestCase): ...@@ -839,12 +849,10 @@ class TestLoss(unittest.TestCase):
batch = tree_map(to_tensor, batch, np.ndarray) batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray) value = tree_map(to_tensor, value, np.ndarray)
batch["backbone_affine_tensor"] = affine_vector_to_4x4( batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
batch["backbone_affine_tensor"] batch["backbone_affine_tensor"]
) )
value["structure_module"]["final_affines"] = affine_vector_to_4x4( batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
value["structure_module"]["final_affines"]
)
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
logits = model.aux_heads.tm(representations["pair"]) logits = model.aux_heads.tm(representations["pair"])
......
...@@ -130,4 +130,5 @@ class TestModel(unittest.TestCase): ...@@ -130,4 +130,5 @@ class TestModel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1] out_repro = out_repro["sm"]["positions"][-1]
out_repro = out_repro.squeeze(0) out_repro = out_repro.squeeze(0)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 1e-3)) print(torch.max(torch.abs(out_gt - out_repro)))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
...@@ -31,8 +31,8 @@ from openfold.model.structure_module import ( ...@@ -31,8 +31,8 @@ from openfold.model.structure_module import (
AngleResnet, AngleResnet,
InvariantPointAttention, InvariantPointAttention,
) )
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
import tests.compare_utils as compare_utils import tests.compare_utils as compare_utils
from tests.config import consts from tests.config import consts
from tests.data_utils import ( from tests.data_utils import (
...@@ -89,7 +89,7 @@ class TestStructureModule(unittest.TestCase): ...@@ -89,7 +89,7 @@ class TestStructureModule(unittest.TestCase):
out = sm(s, z, f) out = sm(s, z, f)
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4)) self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
self.assertTrue( self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2) out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
) )
...@@ -177,23 +177,6 @@ class TestStructureModule(unittest.TestCase): ...@@ -177,23 +177,6 @@ class TestStructureModule(unittest.TestCase):
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05) self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
class TestBackboneUpdate(unittest.TestCase):
def test_shape(self):
batch_size = 2
n_res = 3
c_in = 5
bu = BackboneUpdate(c_in)
s = torch.rand((batch_size, n_res, c_in))
t = bu(s)
rot, tra = t.rots, t.trans
self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
self.assertTrue(tra.shape == (batch_size, n_res, 3))
class TestInvariantPointAttention(unittest.TestCase): class TestInvariantPointAttention(unittest.TestCase):
def test_shape(self): def test_shape(self):
c_m = 13 c_m = 13
...@@ -210,17 +193,18 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -210,17 +193,18 @@ class TestInvariantPointAttention(unittest.TestCase):
z = torch.rand((batch_size, n_res, n_res, c_z)) z = torch.rand((batch_size, n_res, n_res, c_z))
mask = torch.ones((batch_size, n_res)) mask = torch.ones((batch_size, n_res))
rots = torch.rand((batch_size, n_res, 3, 3)) rot_mats = torch.rand((batch_size, n_res, 3, 3))
rots = Rotation(rot_mats=rot_mats, quats=None)
trans = torch.rand((batch_size, n_res, 3)) trans = torch.rand((batch_size, n_res, 3))
t = T(rots, trans) r = Rigid(rots, trans)
ipa = InvariantPointAttention( ipa = InvariantPointAttention(
c_m, c_z, c_hidden, no_heads, no_qp, no_vp c_m, c_z, c_hidden, no_heads, no_qp, no_vp
) )
shape_before = s.shape shape_before = s.shape
s = ipa(s, z, t, mask) s = ipa(s, z, r, mask)
self.assertTrue(s.shape == shape_before) self.assertTrue(s.shape == shape_before)
...@@ -253,7 +237,9 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -253,7 +237,9 @@ class TestInvariantPointAttention(unittest.TestCase):
affines = random_affines_4x4((n_res,)) affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines) rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids) quats = alphafold.model.r3.rigids_to_quataffine(rigids)
transformations = T.from_4x4(torch.as_tensor(affines).float().cuda()) transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats sample_affine = quats
......
...@@ -13,11 +13,24 @@ ...@@ -13,11 +13,24 @@
# limitations under the License. # limitations under the License.
import math import math
import numpy as np
import torch import torch
import unittest import unittest
from openfold.utils.affine_utils import T, quat_to_rot from openfold.utils.rigid_utils import (
Rotation,
Rigid,
quat_to_rot,
rot_to_quat,
)
from openfold.utils.tensor_utils import chunk_layer, _chunk_slice from openfold.utils.tensor_utils import chunk_layer, _chunk_slice
import tests.compare_utils as compare_utils
from tests.config import consts
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
X_90_ROT = torch.tensor( X_90_ROT = torch.tensor(
...@@ -38,7 +51,7 @@ X_NEG_90_ROT = torch.tensor( ...@@ -38,7 +51,7 @@ X_NEG_90_ROT = torch.tensor(
class TestUtils(unittest.TestCase): class TestUtils(unittest.TestCase):
def test_T_from_3_points_shape(self): def test_rigid_from_3_points_shape(self):
batch_size = 2 batch_size = 2
n_res = 5 n_res = 5
...@@ -46,14 +59,14 @@ class TestUtils(unittest.TestCase): ...@@ -46,14 +59,14 @@ class TestUtils(unittest.TestCase):
x2 = torch.rand((batch_size, n_res, 3)) x2 = torch.rand((batch_size, n_res, 3))
x3 = torch.rand((batch_size, n_res, 3)) x3 = torch.rand((batch_size, n_res, 3))
t = T.from_3_points(x1, x2, x3) r = Rigid.from_3_points(x1, x2, x3)
rot, tra = t.rots, t.trans rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
self.assertTrue(rot.shape == (batch_size, n_res, 3, 3)) self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
self.assertTrue(torch.all(tra == x2)) self.assertTrue(torch.all(tra == x2))
def test_T_from_4x4(self): def test_rigid_from_4x4(self):
batch_size = 2 batch_size = 2
transf = [ transf = [
[1, 0, 0, 1], [1, 0, 0, 1],
...@@ -68,58 +81,79 @@ class TestUtils(unittest.TestCase): ...@@ -68,58 +81,79 @@ class TestUtils(unittest.TestCase):
transf = torch.stack([transf for _ in range(batch_size)], dim=0) transf = torch.stack([transf for _ in range(batch_size)], dim=0)
t = T.from_4x4(transf) r = Rigid.from_tensor_4x4(transf)
rot, tra = t.rots, t.trans rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
self.assertTrue(torch.all(rot == true_rot.unsqueeze(0))) self.assertTrue(torch.all(rot == true_rot.unsqueeze(0)))
self.assertTrue(torch.all(tra == true_trans.unsqueeze(0))) self.assertTrue(torch.all(tra == true_trans.unsqueeze(0)))
def test_T_shape(self): def test_rigid_shape(self):
batch_size = 2 batch_size = 2
n = 5 n = 5
transf = T( transf = Rigid(
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3)) Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
torch.rand((batch_size, n, 3))
) )
self.assertTrue(transf.shape == (batch_size, n)) self.assertTrue(transf.shape == (batch_size, n))
def test_T_concat(self): def test_rigid_cat(self):
batch_size = 2 batch_size = 2
n = 5 n = 5
transf = T( transf = Rigid(
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3)) Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
torch.rand((batch_size, n, 3))
) )
transf_concat = T.concat([transf, transf], dim=0) transf_cat = Rigid.cat([transf, transf], dim=0)
self.assertTrue(transf_concat.rots.shape == (batch_size * 2, n, 3, 3)) transf_rots = transf.get_rots().get_rot_mats()
transf_cat_rots = transf_cat.get_rots().get_rot_mats()
transf_concat = T.concat([transf, transf], dim=1) self.assertTrue(transf_cat_rots.shape == (batch_size * 2, n, 3, 3))
self.assertTrue(transf_concat.rots.shape == (batch_size, n * 2, 3, 3)) transf_cat = Rigid.cat([transf, transf], dim=1)
transf_cat_rots = transf_cat.get_rots().get_rot_mats()
self.assertTrue(torch.all(transf_concat.rots[:, :n] == transf.rots)) self.assertTrue(transf_cat_rots.shape == (batch_size, n * 2, 3, 3))
self.assertTrue(torch.all(transf_concat.trans[:, :n] == transf.trans))
def test_T_compose(self): self.assertTrue(torch.all(transf_cat_rots[:, :n] == transf_rots))
self.assertTrue(
torch.all(transf_cat.get_trans()[:, :n] == transf.get_trans())
)
def test_rigid_compose(self):
trans_1 = [0, 1, 0] trans_1 = [0, 1, 0]
trans_2 = [0, 0, 1] trans_2 = [0, 0, 1]
t1 = T(X_90_ROT, torch.tensor(trans_1)) r = Rotation(rot_mats=X_90_ROT)
t2 = T(X_NEG_90_ROT, torch.tensor(trans_2)) t = torch.tensor(trans_1)
t1 = Rigid(
Rotation(rot_mats=X_90_ROT),
torch.tensor(trans_1)
)
t2 = Rigid(
Rotation(rot_mats=X_NEG_90_ROT),
torch.tensor(trans_2)
)
t3 = t1.compose(t2) t3 = t1.compose(t2)
self.assertTrue(torch.all(t3.rots == torch.eye(3))) self.assertTrue(
self.assertTrue(torch.all(t3.trans == 0)) torch.all(t3.get_rots().get_rot_mats() == torch.eye(3))
)
self.assertTrue(
torch.all(t3.get_trans() == 0)
)
def test_T_apply(self): def test_rigid_apply(self):
rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0) rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0)
trans = torch.tensor([1, 1, 1]) trans = torch.tensor([1, 1, 1])
trans = torch.stack([trans, trans], dim=0) trans = torch.stack([trans, trans], dim=0)
t = T(rots, trans) t = Rigid(Rotation(rot_mats=rots), trans)
x = torch.arange(30) x = torch.arange(30)
x = torch.stack([x, x], dim=0) x = torch.stack([x, x], dim=0)
...@@ -141,6 +175,12 @@ class TestUtils(unittest.TestCase): ...@@ -141,6 +175,12 @@ class TestUtils(unittest.TestCase):
eps = 1e-07 eps = 1e-07
self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps)) self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps))
def test_rot_to_quat(self):
quat = rot_to_quat(X_90_ROT)
eps = 1e-07
ans = torch.tensor([math.sqrt(0.5), math.sqrt(0.5), 0., 0.])
self.assertTrue(torch.all(torch.abs(quat - ans) < eps))
def test_chunk_layer_tensor(self): def test_chunk_layer_tensor(self):
x = torch.rand(2, 4, 5, 15) x = torch.rand(2, 4, 5, 15)
l = torch.nn.Linear(15, 30) l = torch.nn.Linear(15, 30)
...@@ -180,3 +220,33 @@ class TestUtils(unittest.TestCase): ...@@ -180,3 +220,33 @@ class TestUtils(unittest.TestCase):
chunked_flattened = x_flat[i:j] chunked_flattened = x_flat[i:j]
self.assertTrue(torch.all(chunked == chunked_flattened)) self.assertTrue(torch.all(chunked == chunked_flattened))
@compare_utils.skip_unless_alphafold_installed()
def test_pre_compose_compare(self):
quat = np.random.rand(20, 4)
trans = [np.random.rand(20) for _ in range(3)]
quat_affine = alphafold.model.quat_affine.QuatAffine(
quat, translation=trans
)
update_vec = np.random.rand(20, 6)
new_gt = quat_affine.pre_compose(update_vec)
quat_t = torch.tensor(quat)
trans_t = torch.stack([torch.tensor(t) for t in trans], dim=-1)
rigid = Rigid(Rotation(quats=quat_t), trans_t)
new_repro = rigid.compose_q_update_vec(torch.tensor(update_vec))
new_gt_q = torch.tensor(np.array(new_gt.quaternion))
new_gt_t = torch.stack(
[torch.tensor(np.array(t)) for t in new_gt.translation], dim=-1
)
new_repro_q = new_repro.get_rots().get_quats()
new_repro_t = new_repro.get_trans()
self.assertTrue(
torch.max(torch.abs(new_gt_q - new_repro_q)) < consts.eps
)
self.assertTrue(
torch.max(torch.abs(new_gt_t - new_repro_t)) < consts.eps
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment