"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "59202300f1a774068608caff0351cc084fe79e7c"
Commit e3daf724 authored by Gustaf Ahdritz

Overhaul transformation code for better parity w/ AlphaFold

parent 1f709b0d
......@@ -89,8 +89,8 @@ config = mlc.ConfigDict(
"atom14_gt_exists": [NUM_RES, None],
"atom14_gt_positions": [NUM_RES, None, None],
"atom37_atom_exists": [NUM_RES, None],
"backbone_affine_mask": [NUM_RES],
"backbone_affine_tensor": [NUM_RES, None, None],
"backbone_rigid_mask": [NUM_RES],
"backbone_rigid_tensor": [NUM_RES, None, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
"chi_angles_sin_cos": [NUM_RES, None, None],
"chi_mask": [NUM_RES, None],
......@@ -126,8 +126,8 @@ config = mlc.ConfigDict(
"template_alt_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_backbone_affine_mask": [NUM_TEMPLATES, NUM_RES],
"template_backbone_affine_tensor": [
"template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
"template_backbone_rigid_tensor": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_mask": [NUM_TEMPLATES],
......
......@@ -22,7 +22,7 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc
from openfold.utils.affine_utils import T
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
......@@ -752,7 +752,7 @@ def make_atom14_positions(protein):
return protein
def atom37_to_frames(protein):
def atom37_to_frames(protein, eps=1e-8):
aatype = protein["aatype"]
all_atom_positions = protein["all_atom_positions"]
all_atom_mask = protein["all_atom_mask"]
......@@ -810,11 +810,11 @@ def atom37_to_frames(protein):
no_batch_dims=len(all_atom_positions.shape[:-2]),
)
gt_frames = T.from_3_points(
gt_frames = Rigid.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :],
eps=1e-8,
eps=eps,
)
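
For reference, a minimal sketch of the Gram-Schmidt construction that `Rigid.from_3_points` performs (AlphaFold Algorithm 21); the function name and exact normalization here are illustrative, not the library code:

import torch

# Illustrative sketch: e0 points from p_neg_x_axis toward the frame
# origin (CA), e1 is fixed by the point in the xy-plane, e2 completes
# a right-handed basis.
def rigid_from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
    e0 = origin - p_neg_x_axis
    e1 = p_xy_plane - origin
    e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdim=True) + eps)
    e1 = e1 - torch.sum(e0 * e1, dim=-1, keepdim=True) * e0
    e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdim=True) + eps)
    e2 = torch.cross(e0, e1, dim=-1)
    rots = torch.stack([e0, e1, e2], dim=-1)  # basis vectors as columns
    return rots, origin  # rotations [..., 3, 3], translations [..., 3]
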
group_exists = batched_gather(
......@@ -836,8 +836,9 @@ def atom37_to_frames(protein):
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
rots = Rotation(rot_mats=rots)
gt_frames = gt_frames.compose(T(rots, None))
gt_frames = gt_frames.compose(Rigid(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8
......@@ -871,10 +872,15 @@ def atom37_to_frames(protein):
no_batch_dims=batch_dims,
)
alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None))
residx_rigidgroup_ambiguity_rot = Rotation(
rot_mats=residx_rigidgroup_ambiguity_rot
)
alt_gt_frames = gt_frames.compose(
Rigid(residx_rigidgroup_ambiguity_rot, None)
)
gt_frames_tensor = gt_frames.to_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_4x4()
gt_frames_tensor = gt_frames.to_tensor_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
protein["rigidgroups_gt_frames"] = gt_frames_tensor
protein["rigidgroups_gt_exists"] = gt_exists
......@@ -1028,7 +1034,7 @@ def atom37_to_torsion_angles(
dim=-1,
)
torsion_frames = T.from_3_points(
torsion_frames = Rigid.from_3_points(
torsions_atom_pos[..., 1, :],
torsions_atom_pos[..., 2, :],
torsions_atom_pos[..., 0, :],
......@@ -1082,11 +1088,11 @@ def atom37_to_torsion_angles(
def get_backbone_frames(protein):
# TODO: Verify that this is correct
protein["backbone_affine_tensor"] = protein["rigidgroups_gt_frames"][
# DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
..., 0, :, :
]
protein["backbone_affine_mask"] = protein["rigidgroups_gt_exists"][..., 0]
protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]
return protein
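
The `tensor_7` format mentioned in the DISCREPANCY comment packs a frame as a quaternion plus a translation. A hedged sketch of the conversion that the test helper `affine_vector_to_4x4` below performs, assuming the layout `[qw, qx, qy, qz, tx, ty, tz]`:

import torch

def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    # Normalized quaternion [w, x, y, z] -> rotation matrix [..., 3, 3]
    w, x, y, z = torch.unbind(quat, dim=-1)
    return torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
    ], dim=-1).reshape(quat.shape[:-1] + (3, 3))

def tensor_7_to_4x4(t7: torch.Tensor) -> torch.Tensor:
    quat, trans = t7[..., :4], t7[..., 4:]
    quat = quat / torch.linalg.norm(quat, dim=-1, keepdim=True)
    out = t7.new_zeros(t7.shape[:-1] + (4, 4))
    out[..., :3, :3] = quat_to_rot(quat)
    out[..., :3, 3] = trans
    out[..., 3, 3] = 1.0
    return out
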
......
......@@ -430,7 +430,9 @@ def _is_set(data: str) -> bool:
def get_atom_coords(
mmcif_object: MmcifObject, chain_id: str, zero_center: bool = True
mmcif_object: MmcifObject,
chain_id: str,
_zero_center_positions: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain
chains = list(mmcif_object.structure.get_chains())
......@@ -475,7 +477,7 @@ def get_atom_coords(
all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask
if zero_center:
if _zero_center_positions:
binary_mask = all_atom_mask.astype(bool)
translation_vec = all_atom_positions[binary_mask].mean(axis=0)
all_atom_positions[binary_mask] -= translation_vec
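
A toy check of the masked centering above: only observed atoms contribute to the centroid, and padded slots are neither counted nor shifted.

import numpy as np

all_atom_positions = np.array([[1.0, 0.0, 0.0],
                               [3.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0]])   # last row is padding
all_atom_mask = np.array([1.0, 1.0, 0.0])
binary_mask = all_atom_mask.astype(bool)
translation_vec = all_atom_positions[binary_mask].mean(axis=0)
all_atom_positions[binary_mask] -= translation_vec
print(all_atom_positions)  # rows 0/1 centered at x = -1/+1, padding at 0
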
......
......@@ -503,10 +503,13 @@ def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str,
max_ca_ca_distance: float,
_zero_center_positions: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=auth_chain_id
mmcif_object=mmcif_object,
chain_id=auth_chain_id,
_zero_center_positions=_zero_center_positions,
)
all_atom_positions, all_atom_mask = coords_with_mask
_check_residue_distances(
......@@ -523,6 +526,7 @@ def _extract_template_features(
query_sequence: str,
template_chain_id: str,
kalign_binary_path: str,
_zero_center_positions: bool = True,
) -> Tuple[Dict[str, Any], Optional[str]]:
"""Parses atom positions in the target structure and aligns with the query.
......@@ -607,7 +611,10 @@ def _extract_template_features(
# Essentially set to infinity - we don't want to reject templates unless
# they're really really bad.
all_atom_positions, all_atom_mask = _get_atom_positions(
mmcif_object, chain_id, max_ca_ca_distance=150.0
mmcif_object,
chain_id,
max_ca_ca_distance=150.0,
_zero_center_positions=_zero_center_positions,
)
except (CaDistanceError, KeyError) as ex:
raise NoAtomDataInTemplateError(
......@@ -795,6 +802,7 @@ def _process_single_hit(
obsolete_pdbs: Mapping[str, str],
kalign_binary_path: str,
strict_error_check: bool = False,
_zero_center_positions: bool = True,
) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
......@@ -856,6 +864,7 @@ def _process_single_hit(
query_sequence=query_sequence,
template_chain_id=hit_chain_id,
kalign_binary_path=kalign_binary_path,
_zero_center_positions=_zero_center_positions,
)
features["template_sum_probs"] = [hit.sum_probs]
......@@ -913,7 +922,6 @@ class TemplateSearchResult:
class TemplateHitFeaturizer:
"""A class for turning hhr hits to template features."""
def __init__(
self,
mmcif_dir: str,
......@@ -924,6 +932,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs_path: Optional[str] = None,
strict_error_check: bool = False,
_shuffle_top_k_prefiltered: Optional[int] = None,
_zero_center_positions: bool = True,
):
"""Initializes the Template Search.
......@@ -982,6 +991,7 @@ class TemplateHitFeaturizer:
self._obsolete_pdbs = {}
self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
self._zero_center_positions = _zero_center_positions
def get_templates(
self,
......@@ -1057,6 +1067,7 @@ class TemplateHitFeaturizer:
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path,
_zero_center_positions=self._zero_center_positions,
)
if result.error:
......
......@@ -198,6 +198,7 @@ class RecyclingEmbedder(nn.Module):
self.no_bins,
dtype=x.dtype,
device=x.device,
requires_grad=False,
)
# [*, N, C_m]
......
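
The hunk above only shows the tail of the bin constructor. For context, a sketch of the recycling distogram those `requires_grad=False` bins presumably feed; the bin range and count are AlphaFold's recycling defaults, assumed here:

import torch

def recycling_dgram(x, min_bin=3.25, max_bin=20.75, no_bins=15, inf=1e8):
    # One-hot squared-distance binning of pairwise atom distances.
    bins = torch.linspace(
        min_bin, max_bin, no_bins,
        dtype=x.dtype, device=x.device, requires_grad=False,
    )
    squared_bins = bins ** 2
    upper = torch.cat(
        [squared_bins[1:], squared_bins.new_tensor([inf])], dim=-1
    )
    d = torch.sum(
        (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    return ((d > squared_bins) * (d < upper)).type(x.dtype)
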
......@@ -25,11 +25,11 @@ from openfold.np.residue_constants import (
restype_atom14_mask,
restype_atom14_rigid_group_positions,
)
from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.feats import (
frames_and_literature_positions_to_atom14_pos,
torsion_angles_to_frames,
)
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
dict_multimap,
permute_final_dims,
......@@ -225,7 +225,7 @@ class InvariantPointAttention(nn.Module):
self,
s: torch.Tensor,
z: torch.Tensor,
t: T,
r: Rigid,
mask: torch.Tensor,
) -> torch.Tensor:
"""
......@@ -234,8 +234,8 @@ class InvariantPointAttention(nn.Module):
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
t:
[*, N_res] affine transformation object
r:
[*, N_res] transformation object
mask:
[*, N_res] mask
Returns:
......@@ -264,7 +264,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_q, 3]
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
q_pts = torch.stack(q_pts, dim=-1)
q_pts = t[..., None].apply(q_pts)
q_pts = r[..., None].apply(q_pts)
# [*, N_res, H, P_q, 3]
q_pts = q_pts.view(
......@@ -277,7 +277,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * (P_q + P_v), 3]
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = t[..., None].apply(kv_pts)
kv_pts = r[..., None].apply(kv_pts)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
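
For readers new to the API, `r[..., None].apply(pts)` broadcasts the rigid action x -> Rx + t over the extra point dimensions; a minimal stand-in for the assumed semantics:

import torch

def apply_rigid(rot_mats, trans, pts):
    # rot_mats: [*, 3, 3], trans: [*, 3], pts: [*, 3]
    return torch.einsum("...ij,...j->...i", rot_mats, pts) + trans
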
......@@ -349,7 +349,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = t[..., None, None].invert_apply(o_pt)
o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims(
......@@ -377,7 +377,7 @@ class InvariantPointAttention(nn.Module):
class BackboneUpdate(nn.Module):
"""
Implements Algorithm 23.
Implements part of Algorithm 23.
"""
def __init__(self, c_s):
......@@ -392,36 +392,17 @@ class BackboneUpdate(nn.Module):
self.linear = Linear(self.c_s, 6, init="final")
def forward(self, s):
def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
[*, N_res, C_s] single representation
Returns:
[*, N_res] affine transformation object
[*, N_res, 6] update vector
"""
# [*, 6]
params = self.linear(s)
# [*, 3]
quats, trans = params[..., :3], params[..., 3:]
# [*]
# norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
norm_denom = torch.sqrt(torch.sum(quats ** 2, dim=-1) + 1)
# [*, 3]
ones = s.new_ones((1,) * len(quats.shape)).expand(
quats.shape[:-1] + (1,)
)
# [*, 4]
quats = torch.cat([ones, quats], dim=-1)
quats = quats / norm_denom[..., None]
update = self.linear(s)
# [*, 3, 3]
rots = quat_to_rot(quats)
return T(rots, trans)
return update
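
The 6-vector returned here is consumed by `Rigid.compose_q_update_vec` (exercised in `test_pre_compose_compare` further down). A hedged sketch of that composition, following AlphaFold's `QuatAffine.pre_compose`; the standalone function below is illustrative:

import torch

def quat_multiply_by_vec(quat, vec):
    # Full quaternion product quat * (0, b, c, d), quat = (w, x, y, z).
    w, x, y, z = torch.unbind(quat, dim=-1)
    b, c, d = torch.unbind(vec, dim=-1)
    return torch.stack([
        -x * b - y * c - z * d,
        w * b + y * d - z * c,
        w * c + z * b - x * d,
        w * d + x * c - y * b,
    ], dim=-1)

def compose_q_update_vec(quat, trans, rot_mats, update):
    # update[..., :3] is an unnormalized quaternion update with implicit
    # scalar part 1; update[..., 3:] is a translation in the local frame.
    q_vec, t_vec = update[..., :3], update[..., 3:]
    new_quat = quat + quat_multiply_by_vec(quat, q_vec)
    new_quat = new_quat / torch.linalg.norm(new_quat, dim=-1, keepdim=True)
    new_trans = trans + torch.einsum("...ij,...j->...i", rot_mats, t_vec)
    return new_quat, new_trans
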
class StructureModuleTransitionLayer(nn.Module):
......@@ -592,7 +573,7 @@ class StructureModule(nn.Module):
self,
s,
z,
f,
aatype,
mask=None,
):
"""
......@@ -601,7 +582,7 @@ class StructureModule(nn.Module):
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
f:
aatype:
[*, N_res] amino acid indices
mask:
Optional [*, N_res] sequence mask
......@@ -623,44 +604,67 @@ class StructureModule(nn.Module):
s = self.linear_in(s)
# [*, N]
t = T.identity(s.shape[:-1], s.dtype, s.device, self.training)
rigids = Rigid.identity(
s.shape[:-1],
s.dtype,
s.device,
self.training,
fmt="quat",
)
outputs = []
for i in range(self.no_blocks):
# [*, N, C_s]
s = s + self.ipa(s, z, t, mask)
s = s + self.ipa(s, z, rigids, mask)
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)
# [*, N]
t = t.compose(self.bb_update(s))
rigids = rigids.compose_q_update_vec(self.bb_update(s))
# To hew as closely as possible to AlphaFold, we convert our
# quaternion-based transformations to rotation-matrix ones
# here
backb_to_global = Rigid(
Rotation(
rot_mats=rigids.get_rots().get_rot_mats(),
quats=None
),
rigids.get_trans(),
)
backb_to_global = backb_to_global.scale_translation(
self.trans_scale_factor
)
# [*, N, 7, 2]
unnormalized_a, a = self.angle_resnet(s, s_initial)
unnormalized_angles, angles = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames(
t.scale_translation(self.trans_scale_factor),
a,
f,
backb_to_global,
angles,
aatype,
)
pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
all_frames_to_global,
f,
aatype,
)
scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
preds = {
"frames": t.scale_translation(self.trans_scale_factor).to_4x4(),
"sidechain_frames": all_frames_to_global.to_4x4(),
"unnormalized_angles": unnormalized_a,
"angles": a,
"frames": scaled_rigids.to_tensor_7(),
"sidechain_frames": all_frames_to_global.to_tensor_4x4(),
"unnormalized_angles": unnormalized_angles,
"angles": angles,
"positions": pred_xyz,
}
outputs.append(preds)
if i < (self.no_blocks - 1):
t = t.stop_rot_gradient()
rigids = rigids.stop_rot_gradient()
outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s
......@@ -673,38 +677,42 @@ class StructureModule(nn.Module):
restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.group_idx is None:
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
)
if self.atom_mask is None:
self.atom_mask = torch.tensor(
restype_atom14_mask,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.lit_positions is None:
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
requires_grad=False,
)
def torsion_angles_to_frames(self, t, alpha, f):
def torsion_angles_to_frames(self, r, alpha, f):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying
return torsion_angles_to_frames(t, alpha, f, self.default_frames)
return torsion_angles_to_frames(r, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos(
self, t, f # [*, N, 8] # [*, N]
self, r, f # [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(t.rots.dtype, t.rots.device)
self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
return frames_and_literature_positions_to_atom14_pos(
t,
r,
f,
self.default_frames,
self.group_idx,
......
......@@ -22,7 +22,7 @@ from typing import Dict
from openfold.np import protein
import openfold.np.residue_constants as rc
from openfold.utils.affine_utils import T
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
batched_gather,
one_hot,
......@@ -124,18 +124,16 @@ def build_template_pair_feat(
)
n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
# TODO: Consider running this in double precision
affines = T.make_transform_from_reference(
rigids = Rigid.make_transform_from_reference(
n_xyz=batch["template_all_atom_positions"][..., n, :],
ca_xyz=batch["template_all_atom_positions"][..., ca, :],
c_xyz=batch["template_all_atom_positions"][..., c, :],
eps=eps,
)
points = rigids.get_trans()[..., None, :, :]
rigid_vec = rigids[..., None].invert_apply(points)
points = affines.get_trans()[..., None, :, :]
affine_vec = affines[..., None].invert_apply(points)
inv_distance_scalar = torch.rsqrt(eps + torch.sum(affine_vec ** 2, dim=-1))
inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))
t_aa_masks = batch["template_all_atom_mask"]
template_mask = (
......@@ -144,7 +142,7 @@ def build_template_pair_feat(
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = affine_vec * inv_distance_scalar[..., None]
unit_vector = rigid_vec * inv_distance_scalar[..., None]
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None])
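
A compact restatement of the unit-vector feature above, with raw rotation matrices and translations standing in for the `Rigid` object; the helper name and eps value are illustrative:

import torch

def unit_vectors(rot_mats, trans, mask_2d, eps=1e-20):
    # rot_mats: [*, N, 3, 3], trans: [*, N, 3], mask_2d: [*, N, N]
    vec = trans[..., None, :, :] - trans[..., :, None, :]  # t_m - t_n
    local = torch.einsum("...nji,...nmj->...nmi", rot_mats, vec)  # R_n^T
    inv_d = torch.rsqrt(eps + torch.sum(local ** 2, dim=-1)) * mask_2d
    return local * inv_d[..., None]
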
......@@ -165,7 +163,7 @@ def build_extra_msa_feat(batch):
def torsion_angles_to_frames(
t: T,
r: Rigid,
alpha: torch.Tensor,
aatype: torch.Tensor,
rrgdf: torch.Tensor,
......@@ -176,13 +174,15 @@ def torsion_angles_to_frames(
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
default_t = T.from_4x4(default_4x4)
default_r = r.from_tensor_4x4(default_4x4)
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1
# [*, N, 8, 2]
alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
alpha = torch.cat(
[bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
)
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
......@@ -194,15 +194,15 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots = alpha.new_zeros(default_t.rots.shape)
all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots = T(all_rots, None)
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_t.compose(all_rots)
all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
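
The rotation assembled above, in isolation: a rotation about the x-axis parameterized directly by `(sin a, cos a)` rather than the angle itself.

import torch

def x_rotation_from_sin_cos(alpha):
    # alpha: [*, 2] with alpha[..., 0] = sin(a), alpha[..., 1] = cos(a)
    sin_a, cos_a = alpha[..., 0], alpha[..., 1]
    rot = alpha.new_zeros(alpha.shape[:-1] + (3, 3))
    rot[..., 0, 0] = 1
    rot[..., 1, 1] = cos_a
    rot[..., 1, 2] = -sin_a
    rot[..., 2, 1] = sin_a
    rot[..., 2, 2] = cos_a
    return rot
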
......@@ -213,7 +213,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = T.concat(
all_frames_to_bb = Rigid.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
......@@ -223,13 +223,13 @@ def torsion_angles_to_frames(
dim=-1,
)
all_frames_to_global = t[..., None].compose(all_frames_to_bb)
all_frames_to_global = r[..., None].compose(all_frames_to_bb)
return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
t: T,
r: Rigid,
aatype: torch.Tensor,
default_frames,
group_idx,
......@@ -249,7 +249,7 @@ def frames_and_literature_positions_to_atom14_pos(
)
# [*, N, 14, 8]
t_atoms_to_global = t[..., None, :] * group_mask
t_atoms_to_global = r[..., None, :] * group_mask
# [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
......
......@@ -24,7 +24,7 @@ from typing import Dict, Optional, Tuple
from openfold.np import residue_constants
from openfold.utils import feats
from openfold.utils.affine_utils import T
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
......@@ -74,8 +74,8 @@ def torsion_angle_loss(
def compute_fape(
pred_frames: T,
target_frames: T,
pred_frames: Rigid,
target_frames: Rigid,
frames_mask: torch.Tensor,
pred_positions: torch.Tensor,
target_positions: torch.Tensor,
......@@ -111,7 +111,7 @@ def compute_fape(
# )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
# ("roughly" because eps is necessarily duplicated in the latter
# ("roughly" because eps is necessarily duplicated in the latter)
normed_error = torch.sum(normed_error, dim=-1)
normed_error = (
normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
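
For orientation, a minimal sketch of `compute_fape` as a whole under the new `Rigid` signatures, including the two-stage masked normalization shown above; rotations and translations are passed as raw tensors for clarity:

import torch

def fape(R_pred, t_pred, R_gt, t_gt, frames_mask, x_pred, x_gt,
         positions_mask, l1_clamp_distance=10.0, length_scale=10.0,
         eps=1e-4):
    def to_local(R, t, x):
        # [..., n_frames, n_pts, 3]: R^T (x - t) for every (frame, point)
        return torch.einsum(
            "...fji,...fpj->...fpi", R, x[..., None, :, :] - t[..., None, :]
        )

    err = torch.sqrt(
        torch.sum((to_local(R_pred, t_pred, x_pred)
                   - to_local(R_gt, t_gt, x_gt)) ** 2, dim=-1) + eps
    )
    err = torch.clamp(err, max=l1_clamp_distance) / length_scale
    err = err * frames_mask[..., None] * positions_mask[..., None, :]
    err = torch.sum(err, dim=-1)
    err = err / (eps + torch.sum(frames_mask, dim=-1))[..., None]
    return torch.sum(err, dim=-1) / (eps + torch.sum(positions_mask, dim=-1))
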
......@@ -123,8 +123,8 @@ def compute_fape(
def backbone_loss(
backbone_affine_tensor: torch.Tensor,
backbone_affine_mask: torch.Tensor,
backbone_rigid_tensor: torch.Tensor,
backbone_rigid_mask: torch.Tensor,
traj: torch.Tensor,
use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.0,
......@@ -132,16 +132,27 @@ def backbone_loss(
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
pred_aff = T.from_tensor(traj)
gt_aff = T.from_tensor(backbone_affine_tensor)
pred_aff = Rigid.from_tensor_7(traj)
pred_aff = Rigid(
Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
pred_aff.get_trans(),
)
# DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
# backbone tensor, normalizes it, and then turns it back to a rotation
# matrix. To avoid a potentially numerically unstable rotation matrix
# to quaternion conversion, we just use the original rotation matrix
# outright. This one hasn't been composed a bunch of times, though, so
# it might be fine.
gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)
fape_loss = compute_fape(
pred_aff,
gt_aff[None],
backbone_affine_mask[None],
backbone_rigid_mask[None],
pred_aff.get_trans(),
gt_aff[None].get_trans(),
backbone_affine_mask[None],
backbone_rigid_mask[None],
l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance,
eps=eps,
......@@ -150,10 +161,10 @@ def backbone_loss(
unclamped_fape_loss = compute_fape(
pred_aff,
gt_aff[None],
backbone_affine_mask[None],
backbone_rigid_mask[None],
pred_aff.get_trans(),
gt_aff[None].get_trans(),
backbone_affine_mask[None],
backbone_rigid_mask[None],
l1_clamp_distance=None,
length_scale=loss_unit_distance,
eps=eps,
......@@ -193,9 +204,9 @@ def sidechain_loss(
sidechain_frames = sidechain_frames[-1]
batch_dims = sidechain_frames.shape[:-4]
sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
sidechain_frames = T.from_4x4(sidechain_frames)
sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
renamed_gt_frames = T.from_4x4(renamed_gt_frames)
renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
sidechain_atom_pos = sidechain_atom_pos[-1]
sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
......@@ -550,8 +561,8 @@ def compute_tm(
def tm_loss(
logits,
final_affine_tensor,
backbone_affine_tensor,
backbone_affine_mask,
backbone_rigid_tensor,
backbone_rigid_mask,
resolution,
max_bin=31,
no_bins=64,
......@@ -560,16 +571,17 @@ def tm_loss(
eps=1e-8,
**kwargs,
):
pred_affine = T.from_4x4(final_affine_tensor)
backbone_affine = T.from_4x4(backbone_affine_tensor)
pred_affine = Rigid.from_tensor_7(final_affine_tensor)
backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
def _points(affine):
pts = affine.get_trans()[..., None, :, :]
return affine.invert()[..., None].apply(pts)
sq_diff = torch.sum(
(_points(pred_affine) - _points(backbone_affine)) ** 2, dim=-1
(_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
)
sq_diff = sq_diff.detach()
boundaries = torch.linspace(
......@@ -583,7 +595,7 @@ def tm_loss(
)
square_mask = (
backbone_affine_mask[..., None] * backbone_affine_mask[..., None, :]
backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
)
loss = torch.sum(errors * square_mask, dim=-1)
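
A hedged sketch of how the binned error targets above are assumed to be produced: squared frame-aligned deviations are bucketed over `[0, max_bin]`, and `errors` is the cross-entropy between `logits` and a one-hot encoding of the resulting indices.

import torch

def tm_true_bins(sq_diff, max_bin=31, no_bins=64):
    boundaries = torch.linspace(
        0, max_bin, steps=no_bins - 1, device=sq_diff.device
    ) ** 2
    return torch.sum(sq_diff[..., None] > boundaries, dim=-1)
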
......@@ -1503,11 +1515,12 @@ class AlphaFoldLoss(nn.Module):
),
}
cum_loss = 0
cum_loss = 0.
for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight
if weight:
loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)):
logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True)
......
../../../../openfold/resources/stereo_chemical_props.txt
\ No newline at end of file
......@@ -23,8 +23,8 @@ from openfold.np.residue_constants import (
restype_atom14_mask,
restype_atom14_rigid_group_positions,
)
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
......@@ -187,7 +187,7 @@ class TestFeats(unittest.TestCase):
n = 5
rots = torch.rand((batch_size, n, 3, 3))
trans = torch.rand((batch_size, n, 3))
ts = T(rots, trans)
ts = Rigid(Rotation(rot_mats=rots), trans)
angles = torch.rand((batch_size, n, 7, 2))
......@@ -222,7 +222,9 @@ class TestFeats(unittest.TestCase):
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float())
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)
......@@ -250,7 +252,7 @@ class TestFeats(unittest.TestCase):
bottom_row[..., 3] = 1
transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)
transforms_repro = out.to_4x4().cpu()
transforms_repro = out.to_tensor_4x4().cpu()
self.assertTrue(
torch.max(torch.abs(transforms_gt - transforms_repro) < consts.eps)
......@@ -262,7 +264,7 @@ class TestFeats(unittest.TestCase):
rots = torch.rand((batch_size, n_res, 8, 3, 3))
trans = torch.rand((batch_size, n_res, 8, 3))
ts = T(rots, trans)
ts = Rigid(Rotation(rot_mats=rots), trans)
f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()
......@@ -293,7 +295,9 @@ class TestFeats(unittest.TestCase):
affines = random_affines_4x4((n_res, 8))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float())
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
out_gt = f.apply({}, None, aatype, rigids)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
......
......@@ -20,7 +20,10 @@ import unittest
import ml_collections as mlc
from openfold.data import data_transforms
from openfold.utils.affine_utils import T, affine_vector_to_4x4
from openfold.utils.rigid_utils import (
Rotation,
Rigid,
)
import openfold.utils.feats as feats
from openfold.utils.loss import (
torsion_angle_loss,
......@@ -55,6 +58,11 @@ if compare_utils.alphafold_is_installed():
import haiku as hk
def affine_vector_to_4x4(affine):
r = Rigid.from_tensor_7(affine)
return r.to_tensor_4x4()
class TestLoss(unittest.TestCase):
def test_run_torsion_angle_loss(self):
batch_size = consts.batch_size
......@@ -77,8 +85,8 @@ class TestLoss(unittest.TestCase):
rots_gt = torch.rand((batch_size, n_frames, 3, 3))
trans = torch.rand((batch_size, n_frames, 3))
trans_gt = torch.rand((batch_size, n_frames, 3))
t = T(rots, trans)
t_gt = T(rots_gt, trans_gt)
t = Rigid(Rotation(rot_mats=rots), trans)
t_gt = Rigid(Rotation(rot_mats=rots_gt), trans_gt)
frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float()
positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float()
length_scale = 10
......@@ -686,10 +694,10 @@ class TestLoss(unittest.TestCase):
batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray)
batch["backbone_affine_tensor"] = affine_vector_to_4x4(
batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
batch["backbone_affine_tensor"]
)
value["traj"] = affine_vector_to_4x4(value["traj"])
batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
out_repro = out_repro.cpu()
......@@ -807,6 +815,8 @@ class TestLoss(unittest.TestCase):
f = hk.transform(run_tm_loss)
np.random.seed(42)
n_res = consts.n_res
representations = {
......@@ -839,12 +849,10 @@ class TestLoss(unittest.TestCase):
batch = tree_map(to_tensor, batch, np.ndarray)
value = tree_map(to_tensor, value, np.ndarray)
batch["backbone_affine_tensor"] = affine_vector_to_4x4(
batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
batch["backbone_affine_tensor"]
)
value["structure_module"]["final_affines"] = affine_vector_to_4x4(
value["structure_module"]["final_affines"]
)
batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
model = compare_utils.get_global_pretrained_openfold()
logits = model.aux_heads.tm(representations["pair"])
......
......@@ -130,4 +130,5 @@ class TestModel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1]
out_repro = out_repro.squeeze(0)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 1e-3))
print(torch.max(torch.abs(out_gt - out_repro)))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
......@@ -31,8 +31,8 @@ from openfold.model.structure_module import (
AngleResnet,
InvariantPointAttention,
)
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import (
......@@ -89,7 +89,7 @@ class TestStructureModule(unittest.TestCase):
out = sm(s, z, f)
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
)
......@@ -177,23 +177,6 @@ class TestStructureModule(unittest.TestCase):
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
class TestBackboneUpdate(unittest.TestCase):
def test_shape(self):
batch_size = 2
n_res = 3
c_in = 5
bu = BackboneUpdate(c_in)
s = torch.rand((batch_size, n_res, c_in))
t = bu(s)
rot, tra = t.rots, t.trans
self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
self.assertTrue(tra.shape == (batch_size, n_res, 3))
class TestInvariantPointAttention(unittest.TestCase):
def test_shape(self):
c_m = 13
......@@ -210,17 +193,18 @@ class TestInvariantPointAttention(unittest.TestCase):
z = torch.rand((batch_size, n_res, n_res, c_z))
mask = torch.ones((batch_size, n_res))
rots = torch.rand((batch_size, n_res, 3, 3))
rot_mats = torch.rand((batch_size, n_res, 3, 3))
rots = Rotation(rot_mats=rot_mats, quats=None)
trans = torch.rand((batch_size, n_res, 3))
t = T(rots, trans)
r = Rigid(rots, trans)
ipa = InvariantPointAttention(
c_m, c_z, c_hidden, no_heads, no_qp, no_vp
)
shape_before = s.shape
s = ipa(s, z, t, mask)
s = ipa(s, z, r, mask)
self.assertTrue(s.shape == shape_before)
......@@ -253,7 +237,9 @@ class TestInvariantPointAttention(unittest.TestCase):
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids)
transformations = T.from_4x4(torch.as_tensor(affines).float().cuda())
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats
......
......@@ -13,11 +13,24 @@
# limitations under the License.
import math
import numpy as np
import torch
import unittest
from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.rigid_utils import (
Rotation,
Rigid,
quat_to_rot,
rot_to_quat,
)
from openfold.utils.tensor_utils import chunk_layer, _chunk_slice
import tests.compare_utils as compare_utils
from tests.config import consts
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
X_90_ROT = torch.tensor(
......@@ -38,7 +51,7 @@ X_NEG_90_ROT = torch.tensor(
class TestUtils(unittest.TestCase):
def test_T_from_3_points_shape(self):
def test_rigid_from_3_points_shape(self):
batch_size = 2
n_res = 5
......@@ -46,14 +59,14 @@ class TestUtils(unittest.TestCase):
x2 = torch.rand((batch_size, n_res, 3))
x3 = torch.rand((batch_size, n_res, 3))
t = T.from_3_points(x1, x2, x3)
r = Rigid.from_3_points(x1, x2, x3)
rot, tra = t.rots, t.trans
rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
self.assertTrue(torch.all(tra == x2))
def test_T_from_4x4(self):
def test_rigid_from_4x4(self):
batch_size = 2
transf = [
[1, 0, 0, 1],
......@@ -68,58 +81,79 @@ class TestUtils(unittest.TestCase):
transf = torch.stack([transf for _ in range(batch_size)], dim=0)
t = T.from_4x4(transf)
r = Rigid.from_tensor_4x4(transf)
rot, tra = t.rots, t.trans
rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
self.assertTrue(torch.all(rot == true_rot.unsqueeze(0)))
self.assertTrue(torch.all(tra == true_trans.unsqueeze(0)))
def test_T_shape(self):
def test_rigid_shape(self):
batch_size = 2
n = 5
transf = T(
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3))
transf = Rigid(
Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
torch.rand((batch_size, n, 3))
)
self.assertTrue(transf.shape == (batch_size, n))
def test_T_concat(self):
def test_rigid_cat(self):
batch_size = 2
n = 5
transf = T(
torch.rand((batch_size, n, 3, 3)), torch.rand((batch_size, n, 3))
transf = Rigid(
Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
torch.rand((batch_size, n, 3))
)
transf_concat = T.concat([transf, transf], dim=0)
transf_cat = Rigid.cat([transf, transf], dim=0)
self.assertTrue(transf_concat.rots.shape == (batch_size * 2, n, 3, 3))
transf_rots = transf.get_rots().get_rot_mats()
transf_cat_rots = transf_cat.get_rots().get_rot_mats()
transf_concat = T.concat([transf, transf], dim=1)
self.assertTrue(transf_cat_rots.shape == (batch_size * 2, n, 3, 3))
self.assertTrue(transf_concat.rots.shape == (batch_size, n * 2, 3, 3))
transf_cat = Rigid.cat([transf, transf], dim=1)
transf_cat_rots = transf_cat.get_rots().get_rot_mats()
self.assertTrue(torch.all(transf_concat.rots[:, :n] == transf.rots))
self.assertTrue(torch.all(transf_concat.trans[:, :n] == transf.trans))
self.assertTrue(transf_cat_rots.shape == (batch_size, n * 2, 3, 3))
def test_T_compose(self):
self.assertTrue(torch.all(transf_cat_rots[:, :n] == transf_rots))
self.assertTrue(
torch.all(transf_cat.get_trans()[:, :n] == transf.get_trans())
)
def test_rigid_compose(self):
trans_1 = [0, 1, 0]
trans_2 = [0, 0, 1]
t1 = T(X_90_ROT, torch.tensor(trans_1))
t2 = T(X_NEG_90_ROT, torch.tensor(trans_2))
r = Rotation(rot_mats=X_90_ROT)
t = torch.tensor(trans_1)
t1 = Rigid(
Rotation(rot_mats=X_90_ROT),
torch.tensor(trans_1)
)
t2 = Rigid(
Rotation(rot_mats=X_NEG_90_ROT),
torch.tensor(trans_2)
)
t3 = t1.compose(t2)
self.assertTrue(torch.all(t3.rots == torch.eye(3)))
self.assertTrue(torch.all(t3.trans == 0))
self.assertTrue(
torch.all(t3.get_rots().get_rot_mats() == torch.eye(3))
)
self.assertTrue(
torch.all(t3.get_trans() == 0)
)
def test_T_apply(self):
def test_rigid_apply(self):
rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0)
trans = torch.tensor([1, 1, 1])
trans = torch.stack([trans, trans], dim=0)
t = T(rots, trans)
t = Rigid(Rotation(rot_mats=rots), trans)
x = torch.arange(30)
x = torch.stack([x, x], dim=0)
......@@ -141,6 +175,12 @@ class TestUtils(unittest.TestCase):
eps = 1e-07
self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps))
def test_rot_to_quat(self):
quat = rot_to_quat(X_90_ROT)
eps = 1e-07
ans = torch.tensor([math.sqrt(0.5), math.sqrt(0.5), 0., 0.])
self.assertTrue(torch.all(torch.abs(quat - ans) < eps))
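
A quick check of the convention this test expects: a rotation by angle a about unit axis n has quaternion (cos(a/2), sin(a/2) * n), so 90 degrees about x gives the `ans` above.

import math
import torch

a = math.pi / 2
axis = torch.tensor([1.0, 0.0, 0.0])
quat = torch.cat([torch.tensor([math.cos(a / 2)]), math.sin(a / 2) * axis])
print(quat)  # tensor([0.7071, 0.7071, 0.0000, 0.0000])
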
def test_chunk_layer_tensor(self):
x = torch.rand(2, 4, 5, 15)
l = torch.nn.Linear(15, 30)
......@@ -180,3 +220,33 @@ class TestUtils(unittest.TestCase):
chunked_flattened = x_flat[i:j]
self.assertTrue(torch.all(chunked == chunked_flattened))
@compare_utils.skip_unless_alphafold_installed()
def test_pre_compose_compare(self):
quat = np.random.rand(20, 4)
trans = [np.random.rand(20) for _ in range(3)]
quat_affine = alphafold.model.quat_affine.QuatAffine(
quat, translation=trans
)
update_vec = np.random.rand(20, 6)
new_gt = quat_affine.pre_compose(update_vec)
quat_t = torch.tensor(quat)
trans_t = torch.stack([torch.tensor(t) for t in trans], dim=-1)
rigid = Rigid(Rotation(quats=quat_t), trans_t)
new_repro = rigid.compose_q_update_vec(torch.tensor(update_vec))
new_gt_q = torch.tensor(np.array(new_gt.quaternion))
new_gt_t = torch.stack(
[torch.tensor(np.array(t)) for t in new_gt.translation], dim=-1
)
new_repro_q = new_repro.get_rots().get_quats()
new_repro_t = new_repro.get_trans()
self.assertTrue(
torch.max(torch.abs(new_gt_q - new_repro_q)) < consts.eps
)
self.assertTrue(
torch.max(torch.abs(new_gt_t - new_repro_t)) < consts.eps
)