Commit 68ba77e5 authored by Gustaf Ahdritz

Continue fixing loss bugs, clean up structure module docs

parent 33941e46
@@ -486,12 +486,12 @@ def _frames_and_literature_positions_to_atom14_pos(
 ):
     # [*, N, 14, 4, 4]
-    default_4x4 = default_frames[f,...]
+    default_4x4 = default_frames[f, ...]

     # [*, N, 14]
-    group_mask = group_idx[f,...]
+    group_mask = group_idx[f, ...]

-    # [N, 14, 8]
+    # [*, N, 14, 8]
     group_mask = nn.functional.one_hot(
         group_mask, num_classes=default_frames.shape[-3],
     )
@@ -504,11 +504,11 @@ def _frames_and_literature_positions_to_atom14_pos(
         lambda x: torch.sum(x, dim=-1)
     )

-    # [N, 14, 1]
+    # [*, N, 14, 1]
     atom_mask = atom_mask[f,...].unsqueeze(-1)

-    # [N, 14, 3]
+    # [*, N, 14, 3]
-    lit_positions = lit_positions[f,...]
+    lit_positions = lit_positions[f, ...]
     pred_positions = t_atoms_to_global.apply(lit_positions)
     pred_positions *= atom_mask
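The shape comments corrected above describe a one-hot frame-selection trick: each of a residue's 14 atoms picks one of its 8 rigid-group frames by one-hot-masking the frame stack and summing over the group dimension. A minimal, self-contained sketch with toy shapes (the names here are illustrative, not OpenFold's):

import torch
import torch.nn as nn

N, n_groups = 5, 8
frames = torch.randn(N, n_groups, 4, 4)          # per-residue rigid-group frames
group_idx = torch.randint(0, n_groups, (N, 14))  # which frame each atom uses

one_hot = nn.functional.one_hot(group_idx, num_classes=n_groups).float()
# [N, 14, 8, 1, 1] * [N, 1, 8, 4, 4] -> sum over the 8 groups -> [N, 14, 4, 4]
per_atom = torch.sum(one_hot[..., None, None] * frames[:, None], dim=-3)

# the same selection as a direct advanced-indexing gather
direct = frames[torch.arange(N)[:, None], group_idx]
assert torch.allclose(per_atom, direct)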
@@ -758,19 +758,27 @@ class StructureModule(nn.Module):
    def _init_residue_constants(self, device):
        if(self.default_frames is None):
            self.default_frames = torch.tensor(
-                restype_rigid_group_default_frame, device=device,
+                restype_rigid_group_default_frame,
+                device=device,
+                requires_grad=False,
            )
        if(self.group_idx is None):
            self.group_idx = torch.tensor(
-                restype_atom14_to_rigid_group, device=device,
+                restype_atom14_to_rigid_group,
+                device=device,
+                requires_grad=False,
            )
        if(self.atom_mask is None):
            self.atom_mask = torch.tensor(
-                restype_atom14_mask, device=device,
+                restype_atom14_mask,
+                device=device,
+                requires_grad=False,
            )
        if(self.lit_positions is None):
            self.lit_positions = torch.tensor(
-                restype_atom14_rigid_group_positions, device=device,
+                restype_atom14_rigid_group_positions,
+                device=device,
+                requires_grad=False,
            )

    def torsion_angles_to_frames(self, t, alpha, f):
...
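The constructors reformatted above follow a lazy-initialization pattern: constant tables are materialized on first use, on whatever device the module's inputs live on, and cached on the instance (requires_grad=False is torch.tensor's default, so spelling it out is documentation rather than a behavior change). A minimal sketch of the pattern, with a stand-in table:

import torch
import torch.nn as nn

class LazyConstants(nn.Module):
    def __init__(self):
        super().__init__()
        self.default_frames = None  # built lazily, as in the module above

    def _init_residue_constants(self, device):
        if self.default_frames is None:
            self.default_frames = torch.tensor(
                [[1.0, 0.0], [0.0, 1.0]],  # stand-in for the real table
                device=device,
                requires_grad=False,
            )

    def forward(self, x):
        self._init_residue_constants(x.device)
        return x @ self.default_frames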
@@ -366,6 +366,7 @@ residue_atoms = {
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
+# TODO: ^ interpret this
residue_atom_renaming_swaps = {
    'ASP': {'OD1': 'OD2'},
    'GLU': {'OE1': 'OE2'},
@@ -895,3 +896,25 @@ def make_atom14_dists_bounds(overlap_tolerance=1.5,
        'upper_bound': restype_atom14_bond_upper_bound,  # shape (21,14,14)
        'stddev': restype_atom14_bond_stddev,  # shape (21,14,14)
    }
+
+restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
+restype_atom14_ambiguous_atoms_swap_idx = (
+    np.tile(np.arange(14, dtype=np.int), (21, 1))
+)
+
+def _make_atom14_ambiguity_feats():
+    for res, pairs in residue_atom_renaming_swaps.items():
+        res_idx = restype_order[restype_3to1[res]]
+        for atom1, atom2 in pairs.items():
+            atom1_idx = restype_name_to_atom14_names[res].index(atom1)
+            atom2_idx = restype_name_to_atom14_names[res].index(atom2)
+            restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
+            restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
+            restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = (
+                atom2_idx
+            )
+            restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = (
+                atom1_idx
+            )
+
+_make_atom14_ambiguity_feats()
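The swap-index table added above encodes, per residue type, a permutation of the 14 atom slots that exchanges each ambiguous pair; indexing the identity matrix with it (as build_ambiguity_feats does later in this commit) yields a permutation matrix. A self-contained illustration, assuming ASP's atom14 layout puts OD1 and OD2 in slots 6 and 7 as in restype_name_to_atom14_names:

import numpy as np

swap_idx = np.arange(14)
od1, od2 = 6, 7  # assumed atom14 slots for ASP's OD1/OD2
swap_idx[od1], swap_idx[od2] = od2, od1

swap_mat = np.eye(14)[swap_idx]  # 14x14 permutation matrix
coords = np.arange(14 * 3, dtype=np.float32).reshape(14, 3)
swapped = swap_mat @ coords      # OD1/OD2 rows exchanged, rest fixed

assert (swapped[od1] == coords[od2]).all()
assert (swapped[od2] == coords[od1]).all()
assert (swapped[0] == coords[0]).all()

(One caveat: np.int in the tile call above was removed in NumPy 1.24; np.int64 is the drop-in replacement.)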
@@ -335,3 +335,15 @@ def quat_to_rot(
    # [*, 3, 3]
    return torch.sum(quat, dim=(-3, -4))
+
+def affine_vector_to_4x4(vector):
+    quats = vector[..., :4]
+    trans = vector[..., 4:]
+
+    four_by_four = torch.zeros(
+        (*vector.shape[:-1], 4, 4), device=vector.device
+    )
+    four_by_four[..., :3, :3] = quat_to_rot(quats)
+    four_by_four[..., :3, 3] = trans
+    four_by_four[..., 3, 3] = 1
+
+    return four_by_four
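A brief usage sketch for the new helper (hypothetical inputs; it assumes the first four components are a quaternion in the convention quat_to_rot expects, normalized by the caller, and the last three are a translation):

import torch

vec = torch.randn(2, 5, 7)
vec[..., :4] /= torch.linalg.norm(vec[..., :4], dim=-1, keepdim=True)

mats = affine_vector_to_4x4(vec)  # [2, 5, 4, 4] homogeneous transforms

points = torch.randn(2, 5, 4)
points[..., 3] = 1                # homogeneous coordinates
moved = torch.einsum("...ij,...j->...i", mats, points)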
@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from typing import Dict

-import openfold.np.residue_constants as residue_constants
+import openfold.np.residue_constants as rc
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import (
    batched_gather,
@@ -27,9 +27,9 @@ from openfold.utils.tensor_utils import (
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
-    is_gly = (aatype == residue_constants.restype_order['G'])
-    ca_idx = residue_constants.atom_order['CA']
-    cb_idx = residue_constants.atom_order['CB']
+    is_gly = (aatype == rc.restype_order['G'])
+    ca_idx = rc.atom_order['CA']
+    cb_idx = rc.atom_order['CB']
    pseudo_beta = torch.where(
        is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
        all_atom_positions[..., ca_idx, :],
@@ -52,18 +52,18 @@ def get_chi_atom_indices():
    Returns:
        A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
-        in the order specified in residue_constants.restypes + unknown residue type
+        in the order specified in rc.restypes + unknown residue type
        at the end. For chi angles which are not defined on the residue, the
        positions indices are by default set to 0.
    """
    chi_atom_indices = []
-    for residue_name in residue_constants.restypes:
-        residue_name = residue_constants.restype_1to3[residue_name]
-        residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
+    for residue_name in rc.restypes:
+        residue_name = rc.restype_1to3[residue_name]
+        residue_chi_angles = rc.chi_angles_atoms[residue_name]
        atom_indices = []
        for chi_angle in residue_chi_angles:
            atom_indices.append(
-                [residue_constants.atom_order[atom] for atom in chi_angle])
+                [rc.atom_order[atom] for atom in chi_angle])
        for _ in range(4 - len(atom_indices)):
            atom_indices.append([0, 0, 0, 0])  # For chi angles not defined on the AA.
        chi_atom_indices.append(atom_indices)
@@ -74,6 +74,7 @@ def get_chi_atom_indices():
def compute_residx(batch):
+    out = {}
    float_type = batch["seq_mask"].dtype
    aatype = batch["aatype"]
@@ -81,19 +82,20 @@ def compute_residx(batch):
    restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
    restype_atom14_mask = []

-    for rt in residue_constants.restypes:
-        atom_names = residue_constants.restype_name_to_atom14_names[
-            residue_constants.restype_1to3[rt]]
+    for rt in rc.restypes:
+        atom_names = rc.restype_name_to_atom14_names[
+            rc.restype_1to3[rt]
+        ]

        restype_atom14_to_atom37.append([
-            (residue_constants.atom_order[name] if name else 0)
+            (rc.atom_order[name] if name else 0)
            for name in atom_names
        ])

        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append([
            (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
-            for name in residue_constants.atom_types
+            for name in rc.atom_types
        ])

        restype_atom14_mask.append(
@@ -118,24 +120,27 @@ def compute_residx(batch):
    residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype]
    residx_atom14_mask = restype_atom14_mask[aatype]

-    batch['atom14_atom_exists'] = residx_atom14_mask
-    batch['residx_atom14_to_atom37'] = residx_atom14_to_atom37
+    out["residx_atom14_to_atom37"] = residx_atom14_to_atom37
+    out["atom14_atom_exists"] = residx_atom14_mask

    # create the gather indices for mapping back
    residx_atom37_to_atom14 = restype_atom37_to_atom14[aatype]
-    batch['residx_atom37_to_atom14'] = residx_atom37_to_atom14
+    out["residx_atom37_to_atom14"] = residx_atom37_to_atom14

    # create the corresponding mask
    restype_atom37_mask = torch.zeros([21, 37], dtype=float_type)
-    for restype, restype_letter in enumerate(residue_constants.restypes):
-        restype_name = residue_constants.restype_1to3[restype_letter]
-        atom_names = residue_constants.residue_atoms[restype_name]
+    for restype, restype_letter in enumerate(rc.restypes):
+        restype_name = rc.restype_1to3[restype_letter]
+        atom_names = rc.residue_atoms[restype_name]
        for atom_name in atom_names:
-            atom_type = residue_constants.atom_order[atom_name]
+            atom_type = rc.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1

    residx_atom37_mask = restype_atom37_mask[aatype]
-    batch['atom37_atom_exists'] = residx_atom37_mask
+    out["atom37_atom_exists"] = residx_atom37_mask
+
+    return out

def atom14_to_atom37(atom14, batch):
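The refactor above makes compute_residx return its index tables in a fresh dict instead of mutating the input batch. Those tables are consumed by gathers like the one sketched below (toy shapes; the stand-in gather mimics batched_gather along the atom dimension with one batch dim):

import torch

def gather_per_residue(data, inds):
    # stand-in for batched_gather with one batch dim over residues
    return torch.stack([d[i] for d, i in zip(data, inds)])

n_res = 3
atom14 = torch.randn(n_res, 14, 3)  # atom14 coordinates
residx_atom37_to_atom14 = torch.zeros(n_res, 37, dtype=torch.long)
atom37_exists = torch.zeros(n_res, 37)

# scatter atom14 coords into atom37 slots, then zero the absent atoms
atom37 = gather_per_residue(atom14, residx_atom37_to_atom14)
atom37 = atom37 * atom37_exists[..., None]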
@@ -225,9 +230,9 @@ def atom37_to_torsion_angles(
        all_atom_pos, atom_indices, -2, len(atom_indices.shape[:-2])
    )

-    chi_angles_mask = list(residue_constants.chi_angles_mask)
+    chi_angles_mask = list(rc.chi_angles_mask)
    chi_angles_mask.append([0., 0., 0., 0.])
-    chi_angles_mask = all_atom_pos.new_tensor(chi_angles_mask)
+    chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)

    chis_mask = chi_angles_mask[aatype, :]
@@ -282,7 +287,7 @@ def atom37_to_torsion_angles(
    )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]

    chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
-        residue_constants.chi_pi_periodic,
+        rc.chi_pi_periodic,
    )[aatype, ...]

    mirror_torsion_angles = torch.cat(
@@ -307,6 +312,7 @@ def atom37_to_frames(
    aatype: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
+    **kwargs,
) -> Dict[str, torch.Tensor]:
    batch_dims = len(aatype.shape[:-1])
@@ -314,13 +320,14 @@ def atom37_to_frames(
    restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
    restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']

-    for restype, restype_letter in enumerate(residue_constants.restypes):
-        resname = residue_constants.restype_1to3[restype_letter]
+    for restype, restype_letter in enumerate(rc.restypes):
+        resname = rc.restype_1to3[restype_letter]
        for chi_idx in range(4):
-            if(residue_constants.chi_angles_mask[restype][chi_idx]):
-                names = residue_constants.chi_angles_atoms[resname][chi_idx]
+            if(rc.chi_angles_mask[restype][chi_idx]):
+                names = rc.chi_angles_atoms[resname][chi_idx]
                restype_rigidgroup_base_atom_names[
-                    restype, chi_idx + 4, :] = atom_names[1:]
+                    restype, chi_idx + 4, :
+                ] = names[1:]

    restype_rigidgroup_mask = torch.zeros(
        (*aatype.shape[:-1], 21, 8),
@@ -330,9 +337,11 @@ def atom37_to_frames(
    )
    restype_rigidgroup_mask[:, 0] = 1
    restype_rigidgroup_mask[:, 3] = 1
-    restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask
+    restype_rigidgroup_mask[:20, 4:] = (
+        all_atom_mask.new_tensor(rc.chi_angles_mask)
+    )

-    lookuptable = residue_constants.atom_order.copy()
+    lookuptable = rc.atom_order.copy()
    lookuptable[''] = 0
    lookup = np.vectorize(lambda x: lookuptable[x])
    restype_rigidgroup_base_atom37_idx = lookup(
@@ -349,7 +358,7 @@ def atom37_to_frames(
    )

    residx_rigidgroup_base_atom37_idx = batched_gather(
-        residx_rigidgroup_base_atom37_idx,
+        restype_rigidgroup_base_atom37_idx,
        aatype,
        dim=-3,
        no_batch_dims=batch_dims,
@@ -363,9 +372,9 @@ def atom37_to_frames(
    )

    gt_frames = T.from_3_points(
-        point_on_neg_x_axis=base_atom_pos[..., 0, :],
+        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
-        point_on_xy_plane=base_atom_pos[..., 2, :],
+        p_xy_plane=base_atom_pos[..., 2, :],
    )

    group_exists = batched_gather(
@@ -381,33 +390,31 @@ def atom37_to_frames(
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1])
    )
-    gt_exists = torch.min(gt_atoms_exist, dim=-1) * group_exists
+    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

    rots = torch.eye(3, device=aatype.device, requires_grad=False)
-    rots = rots.view(*((1,) * batch_dims), 1, 3, 3)
-    rots = rots.expand(*((-1,) * batch_dims), 8, -1, -1)
+    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1

    gt_frames = gt_frames.compose(T(rots, None))

    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), 21, 8
    )
    restype_rigidgroup_rots = torch.eye(
        3, device=aatype.device, requires_grad=False
    )
-    restype_rigidgroup_rots = restype_rigidgroup_rots.view(
-        *((1,) * batch_dims), 1, 1, 3, 3
-    )
-    restype_rigidgroup_rots = restype_rigidgroup_rots.expand(
-        *((-1,) * batch_dims), 21, 8, 3, 3
-    )
+    restype_rigidgroup_rots = torch.tile(
+        restype_rigidgroup_rots,
+        (*((1,) * batch_dims), 21, 8, 1, 1),
+    )

-    for resname, _ in residue_constants.residue_atom_renaming_swaps.items():
-        restype = residue_constants.restype_order[
-            residue_constants.restype3to1[resname]
+    for resname, _ in rc.residue_atom_renaming_swaps.items():
+        restype = rc.restype_order[
+            rc.restype_3to1[resname]
        ]
-        chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
+        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
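The two -1 entries written per ambiguous residue above turn an identity block into diag(1, -1, -1), which is a 180-degree rotation about the frame's x-axis; composed with the group frame, it maps the chi-symmetric atoms onto their renamed counterparts. A quick standalone check:

import numpy as np

theta = np.pi
rot_x_180 = np.array([
    [1.0, 0.0, 0.0],
    [0.0, np.cos(theta), -np.sin(theta)],
    [0.0, np.sin(theta), np.cos(theta)],
])
assert np.allclose(rot_x_180, np.diag([1.0, -1.0, -1.0]))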
@@ -419,18 +426,17 @@ def atom37_to_frames(
        no_batch_dims=batch_dims,
    )

-    residx_rigidgroup_ambiguity_rot = utils.batched_gather(
+    residx_rigidgroup_ambiguity_rot = batched_gather(
        restype_rigidgroup_rots,
        aatype,
        dim=-4,
        no_batch_dims=batch_dims,
    )

-    alt_gt_frames = gt_frames.apply(T(residx_rigidgroup_ambiguity_rot, None))
+    alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None))

-    # TODO: Verify that I can get away with skipping the flat12 format
-    gt_frames_tensor = gt_frames.to_tensor()
-    alt_gt_frames_tensor = alt_gt_frames.to_tensor()
+    gt_frames_tensor = gt_frames.to_4x4()
+    alt_gt_frames_tensor = alt_gt_frames.to_4x4()

    return {
        'rigidgroups_gt_frames': gt_frames_tensor,
@@ -477,7 +483,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
    to_concat = [dgram, template_mask_2d[..., None]]

    aatype_one_hot = nn.functional.one_hot(
-        batch["template_aatype"], residue_constants.restype_num + 2,
+        batch["template_aatype"], rc.restype_num + 2,
    )

    n_res = batch["template_aatype"].shape[-1]
@@ -492,7 +498,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
        )
    )

-    n, ca, c = [residue_constants.atom_order[a] for a in ['N', 'CA', 'C']]
+    n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']]
    t_aa_masks = batch["template_all_atom_masks"]
    template_mask = (
@@ -522,7 +528,7 @@ def build_extra_msa_feat(batch):

# adapted from model/tf/data_transforms.py
-def build_msa_feat(protein):
+def build_msa_feat(batch):
    """Create and concatenate MSA features."""
    # Whether there is a domain break. Always zero for chains, but keeping
    # for compatibility with domain datasets.
@@ -544,7 +550,7 @@ def build_msa_feat(protein):
        deletion_value.unsqueeze(-1),
    ]

-    if 'cluster_profile' in protein:
+    if 'cluster_profile' in batch:
        deletion_mean_value = (
            tf.atan(batch['cluster_deletion_mean'] / 3.) * (2. / np.pi))
        msa_feat.extend([
@@ -560,4 +566,53 @@ def build_msa_feat(protein):
    batch['msa_feat'] = torch.cat(msa_feat, dim=-1)
    batch['target_feat'] = torch.cat(target_feat, dim=-1)

-    return protein
+    return batch
+
+
+def build_ambiguity_feats(
+    batch: Dict[str, torch.Tensor]
+) -> Dict[str, torch.Tensor]:
+    """
+    Compute features required by compute_renamed_ground_truth (Alg. 26)
+
+    Args:
+        batch:
+            str/tensor dictionary containing:
+                * atom14_gt_positions: [*, N, 14, 3] ground truth pos.
+                * atom14_gt_exists: [*, N, 14] atom mask
+                * aatype: [*, N] residue indices
+    Returns:
+        str/tensor dictionary containing:
+            * atom14_atom_is_ambiguous: [*, N, 14] mask of ambiguous atoms
+            * atom14_alt_gt_positions: [*, N, 14, 3] renamed positions
+            * atom14_alt_gt_exists: [*, N, 14] renamed atom mask
+    """
+    ambiguous_atoms = (
+        batch["atom14_gt_positions"].new_tensor(
+            rc.restype_atom14_ambiguous_atoms, requires_grad=False,
+        )
+    )
+    atom14_atom_is_ambiguous = ambiguous_atoms[batch["aatype"], ...]
+
+    # Swap pairs of ambiguous positions
+    swap_idx = rc.restype_atom14_ambiguous_atoms_swap_idx
+    swap_mat = np.eye(swap_idx.shape[-1])[swap_idx]  # one-hot swap_idx
+    swap_mat = batch["atom14_gt_positions"].new_tensor(
+        swap_mat, requires_grad=False
+    )
+    swap_mat = swap_mat[batch["aatype"], ...]
+    atom14_alt_gt_positions = (
+        torch.sum(
+            batch["atom14_gt_positions"][..., None, :] * swap_mat[..., None],
+            dim=-3
+        )
+    )
+    atom14_alt_gt_exists = (
+        torch.sum(
+            batch["atom14_gt_exists"][..., None] * swap_mat, dim=-2
+        )
+    )
+
+    return {
+        "atom14_atom_is_ambiguous": atom14_atom_is_ambiguous,
+        "atom14_alt_gt_positions": atom14_alt_gt_positions,
+        "atom14_alt_gt_exists": atom14_alt_gt_exists,
+    }
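A hedged usage sketch for the new function (toy batch; residue index 3 is ASP in restype_order, whose OD1/OD2 pair is ambiguous). Because each swap is a transposition, renaming twice should recover the original positions:

import torch

batch = {
    "atom14_gt_positions": torch.randn(1, 4, 14, 3),
    "atom14_gt_exists": torch.ones(1, 4, 14),
    "aatype": torch.full((1, 4), 3, dtype=torch.long),  # all ASP
}
feats = build_ambiguity_feats(batch)

batch_renamed = dict(batch, atom14_gt_positions=feats["atom14_alt_gt_positions"])
roundtrip = build_ambiguity_feats(batch_renamed)["atom14_alt_gt_positions"]
assert torch.allclose(roundtrip, batch["atom14_gt_positions"])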
@@ -89,15 +89,15 @@ def compute_fape(
        target_positions[..., None, :, :],
    )

    error_dist = torch.sqrt(
-        (pred_positions - target_positions)**2 + eps
+        torch.sum((local_pred_pos - local_target_pos)**2, dim=-1) + eps
    )

    if(l1_clamp_distance is not None):
        error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)

    normed_error = error_dist / length_scale
-    normed_error *= frames_mask.unsqueeze(-1)
-    normed_error *= positions_mask.unsqueeze(-2)
+    normed_error *= frames_mask[..., None]
+    normed_error *= positions_mask[..., None, :]

    norm_factor = (
        torch.sum(frames_mask, dim=-1) *
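The corrected expression above fixes two things at once: the distance now sums squared differences over the coordinate axis before the square root (previously it was an elementwise square, never reduced to a scalar distance), and it reads the locally-transformed positions rather than the global ones. A standalone check of the distance form:

import torch

eps = 1e-8
a, b = torch.randn(5, 3), torch.randn(5, 3)
dist = torch.sqrt(torch.sum((a - b) ** 2, dim=-1) + eps)
assert torch.allclose(dist, torch.linalg.norm(a - b, dim=-1), atol=1e-3)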
@@ -109,67 +109,71 @@ def compute_fape(

    return normed_error


+# DISCREPANCY: figure out if loss clamping happens in 90% of each batch or in 90% of batches
def backbone_loss(
-    batch: Dict[str, torch.Tensor],
-    pred_aff_tensor: torch.Tensor,
+    backbone_affine_tensor: torch.Tensor,
+    backbone_affine_mask: torch.Tensor,
+    traj: torch.Tensor,
+    use_clamped_fape: Optional[torch.Tensor] = None,
    clamp_distance: float = 10.,
    loss_unit_distance: float = 10.,
+    **kwargs,
) -> torch.Tensor:
-    pred_aff = T.from_tensor(pred_aff_tensor)
-    gt_aff = T.from_tensor(batch["backbone_affine_tensor"])
-    backbone_mask = batch["backbone_affine_mask"]
+    pred_aff = T.from_tensor(traj)
+    gt_aff = T.from_tensor(backbone_affine_tensor)

    fape_loss = compute_fape(
        pred_aff,
-        gt_aff,
-        backbone_mask,
+        gt_aff[..., None, :],
+        backbone_affine_mask[..., None, :],
        pred_aff.get_trans(),
-        gt_aff.get_trans(),
-        backbone_mask,
+        gt_aff[..., None, :].get_trans(),
+        backbone_affine_mask[..., None, :],
        l1_clamp_distance=clamp_distance,
        length_scale=loss_unit_distance,
    )

-    if('use_clamped_fape' in batch):
-        use_clamped_fape = batch["use_clamped_fape"]
+    if(use_clamped_fape is not None):
        unclamped_fape_loss = compute_fape(
            pred_aff,
-            gt_aff,
-            backbone_mask,
+            gt_aff[..., None, :],
+            backbone_affine_mask[..., None, :],
            pred_aff.get_trans(),
-            gt_aff.get_trans(),
-            backbone_mask,
+            gt_aff[..., None, :].get_trans(),
+            backbone_affine_mask[..., None, :],
            l1_clamp_distance=None,
            length_scale=loss_unit_distance,
        )

        fape_loss = (
            fape_loss * use_clamped_fape +
-            fape_loss_unclamped * (1 - use_clamped_fape)
+            unclamped_fape_loss * (1 - use_clamped_fape)
        )

    return torch.mean(fape_loss, dim=-1)


def sidechain_loss(
-    sidechain_frames,
-    sidechain_atom_pos,
-    rigidgroups_gt_frames,
-    rigidgroups_alt_gt_frames,
-    rigidgroups_gt_exists,
-    renamed_atom14_gt_positions,
-    renamed_atom14_gt_exists,
-    alt_naming_is_better,
-    clamp_distance=10.,
-    length_scale=10.,
-):
+    sidechain_frames: torch.Tensor,
+    sidechain_atom_pos: torch.Tensor,
+    rigidgroups_gt_frames: torch.Tensor,
+    rigidgroups_alt_gt_frames: torch.Tensor,
+    rigidgroups_gt_exists: torch.Tensor,
+    renamed_atom14_gt_positions: torch.Tensor,
+    renamed_atom14_gt_exists: torch.Tensor,
+    alt_naming_is_better: torch.Tensor,
+    clamp_distance: float = 10.,
+    length_scale: float = 10.,
+    **kwargs,
+) -> torch.Tensor:
    renamed_gt_frames = (
        (1. - alt_naming_is_better[..., None, None, None, None]) *
-        gt_frames +
+        rigidgroups_gt_frames +
        alt_naming_is_better[..., None, None, None, None] *
-        alt_gt_frames
+        rigidgroups_alt_gt_frames
    )

+    sidechain_frames = T.from_4x4(sidechain_frames)
    renamed_gt_frames = T.from_4x4(renamed_gt_frames)

    fape = compute_fape(
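The renamed_gt_frames blend above selects, per residue, either the original or the atom-renamed ground truth using a 0/1 float flag broadcast over the frame dimensions. A toy illustration with one flag per example (the real code broadcasts a per-residue flag with four trailing None axes over [*, N, 8, 4, 4] frames):

import torch

gt = torch.zeros(2, 8, 4, 4)
alt = torch.ones(2, 8, 4, 4)
alt_is_better = torch.tensor([0.0, 1.0])

blended = (
    (1.0 - alt_is_better[..., None, None, None]) * gt
    + alt_is_better[..., None, None, None] * alt
)
assert (blended[0] == gt[0]).all() and (blended[1] == alt[1]).all()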
@@ -192,16 +196,13 @@ def fape_loss(
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    bb_loss = backbone_loss(
-        batch, out["sm"]["frames"][-1], **config.backbone
+        traj=out["sm"]["frames"], **{**batch, **config.backbone},
    )

    sc_loss = sidechain_loss(
        out["sm"]["sidechain_frames"],
        out["sm"]["positions"],
-        {
-            **batch,
-            **config.sidechain,
-        },
+        **{**batch, **config.sidechain}
    )

    return (
...
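The new call sites rely on the **kwargs parameters added to backbone_loss and sidechain_loss: batch tensors and config values are merged into one keyword dict, each loss binds the keys it declares, and **kwargs silently absorbs the rest. A minimal sketch of the convention:

def loss_fn(traj, clamp_distance=10.0, **kwargs):
    # binds traj and clamp_distance; ignores any extra batch keys
    return traj * clamp_distance

batch = {"traj": 2.0, "seq_mask": None}  # extra key is absorbed
config = {"clamp_distance": 5.0}
assert loss_fn(**{**batch, **config}) == 10.0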
@@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
+from scipy.spatial.transform import Rotation


def random_template_feats(n_templ, n, batch_size=None):
@@ -35,6 +36,7 @@ def random_template_feats(n_templ, n, batch_size=None):
    batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
    return batch

+
def random_extra_msa_feats(n_extra, n, batch_size=None):
    b = []
    if(batch_size is not None):
@@ -50,3 +52,34 @@ def random_extra_msa_feats(n_extra, n, batch_size=None):
        np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
    }
    return batch
+
+
+def random_affine_vectors(dim):
+    prod_dim = 1
+    for d in dim:
+        prod_dim *= d
+
+    affines = np.zeros((prod_dim, 7))
+    for i in range(prod_dim):
+        affines[i, :4] = Rotation.random(random_state=42).as_quat()
+        affines[i, 4:] = np.random.rand(3,)
+
+    return affines.reshape(*dim, 7)
+
+
+def random_affine_4x4s(dim):
+    prod_dim = 1
+    for d in dim:
+        prod_dim *= d
+
+    affines = np.zeros((prod_dim, 4, 4))
+    for i in range(prod_dim):
+        affines[i, :3, :3] = Rotation.random(random_state=42).as_matrix()
+        affines[i, :3, 3] = np.random.rand(3,)
+
+    affines[:, 3, 3] = 1
+
+    return affines.reshape(*dim, 4, 4)
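One caveat worth noting about the helpers above (an observation, not part of the commit): Rotation.random(random_state=42) reseeds on every loop iteration, so every sampled rotation is identical. If distinct rotations per element were intended, a single vectorized draw does it:

from scipy.spatial.transform import Rotation

rots = Rotation.random(num=4, random_state=42)  # four distinct rotations
quats = rots.as_quat()                          # shape (4, 4)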