Commit 893fe372 authored by Gustaf Ahdritz

Get FP16 training working

parent dd06b323
@@ -44,7 +44,7 @@ def model_config(name, train=False, low_prec=False):
         raise ValueError("Invalid model name")
 
     if(train):
-        c.globals.model.blocks_per_ckpt = 1
+        c.globals.blocks_per_ckpt = 1
         c.globals.chunk_size = None
     if(low_prec):
@@ -137,7 +137,7 @@ config = mlc.ConfigDict({
            },
            "inf": 1e9,
            "eps": eps,#1e-6,
-           "enabled": True,
+           "enabled": False,#True,
            "embed_angles": True,
        },
        "extra_msa": {
@@ -239,7 +239,7 @@ config = mlc.ConfigDict({
            "max_bin": 21.6875,
            "no_bins": 64,
            "eps": eps,#1e-6,
-           "weight": 0.,#0.3,
+           "weight": 0.3,
        },
        "experimentally_resolved": {
            "eps": eps,#1e-8,
@@ -267,17 +267,17 @@ config = mlc.ConfigDict({
            "cutoff": 15.,
            "no_bins": 50,
            "eps": eps,#1e-10,
-           "weight": 0.,#0.01,
+           "weight": 0.01,
        },
        "masked_msa": {
            "eps": eps,#1e-8,
-           "weight": 0.,#2.0,
+           "weight": 2.0,
        },
        "supervised_chi": {
            "chi_weight": 0.5,
            "angle_norm_weight": 0.01,
            "eps": eps,#1e-6,
-           "weight": 0.,#1.0,
+           "weight": 1.0,
        },
        "violation": {
            "violation_tolerance_factor": 12.0,
...
@@ -389,6 +389,7 @@ class AlphaFold(nn.Module):
         m_1_prev, z_prev, x_prev = None, None, None
 
         is_grad_enabled = torch.is_grad_enabled()
+        self._disable_activation_checkpointing()
 
         # Main recycling loop
         for cycle_no in range(self.config.no_cycles):
@@ -400,8 +401,10 @@ class AlphaFold(nn.Module):
             is_final_iter = (cycle_no == (self.config.no_cycles - 1))
             with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                 # Sidestep AMP bug discussed in pytorch issue #65766
-                if(is_final_iter and torch.is_autocast_enabled()):
-                    torch.clear_autocast_cache()
+                if(is_final_iter):
+                    self._enable_activation_checkpointing()
+                    if(torch.is_autocast_enabled()):
+                        torch.clear_autocast_cache()
 
                 # Run the next iteration of the model
                 outputs, m_1_prev, z_prev, x_prev = self.iteration(
                     feats, m_1_prev, z_prev, x_prev,
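Note: the two hunks above keep activation checkpointing off during the no-grad recycling cycles, re-enable it for the final cycle, and clear the autocast cache before that cycle as a workaround for the caching bug in pytorch issue #65766. A minimal sketch of the recycling pattern (the checkpointing toggles are OpenFold-specific helpers and are omitted; `run_iteration` and `no_cycles` are stand-in names):

```python
import torch

def recycle(run_iteration, feats, no_cycles):
    # Only the last recycling cycle needs gradients; earlier cycles run
    # under torch.set_grad_enabled(False).
    prev = None
    is_grad_enabled = torch.is_grad_enabled()
    for cycle_no in range(no_cycles):
        is_final_iter = cycle_no == (no_cycles - 1)
        with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
            if is_final_iter and torch.is_autocast_enabled():
                # Drop FP16 weight copies cached during the no-grad cycles so
                # the final cycle re-casts them with autograd tracking
                # (pytorch issue #65766).
                torch.clear_autocast_cache()
            outputs, prev = run_iteration(feats, prev)
    return outputs
```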
...
@@ -257,8 +257,6 @@ class Attention(nn.Module):
         a = a + b
 
         a = self.softmax(a)
 
-        #print(torch.any(torch.isnan(a)))
-
         # [*, H, Q, C_hidden]
         o = torch.matmul(
             a,
...
@@ -26,6 +26,10 @@ from openfold.np.residue_constants import (
     restype_atom14_rigid_group_positions,
 )
 from openfold.utils.affine_utils import T, quat_to_rot
+from openfold.utils.feats import (
+    frames_and_literature_positions_to_atom14_pos,
+    torsion_angles_to_frames,
+)
 from openfold.utils.tensor_utils import (
     dict_multimap,
     permute_final_dims,
@@ -305,7 +309,7 @@ class InvariantPointAttention(nn.Module):
         pt_att = pt_att ** 2
 
         # [*, N_res, N_res, H, P_q]
-        pt_att = torch.sum(pt_att, dim=-1)
+        pt_att = sum(torch.unbind(pt_att, dim=-1))
         head_weights = self.softplus(self.head_weights).view(
             *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
         )
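The `sum(torch.unbind(...))` form is numerically equivalent to `torch.sum(..., dim=-1)`; it simply accumulates the unbound slices one by one rather than issuing a single reduction. A quick sanity check of the equivalence (the motivation for the swap in this commit is not stated in the diff, so only the equivalence is shown):

```python
import torch

x = torch.randn(2, 4, 8)

# Summing the unbound slices of the last dimension matches torch.sum
# over that dimension.
a = torch.sum(x, dim=-1)
b = sum(torch.unbind(x, dim=-1))
assert torch.allclose(a, b)
```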
@@ -358,7 +362,7 @@ class InvariantPointAttention(nn.Module):
         )
 
         # [*, N_res, H * P_v, 3]
-        o_pt = o_pt.view(*o_pt.shape[:-3], -1, 3)
+        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
 
         # [*, N_res, H, C_z]
         o_pair = torch.matmul(a.transpose(-2, -3), z)
@@ -409,14 +413,17 @@ class BackboneUpdate(nn.Module):
         quats, trans = params[...,:3], params[...,3:]
 
         # [*]
+        #norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
         norm_denom = torch.sqrt(torch.sum(quats ** 2, dim=-1) + 1)
 
-        # As many ones as there are dimensions in quats
-        ones = s.new_ones((1,) * len(quats.shape))
+        # [*, 3]
+        ones = (
+            s.new_ones((1,) * len(quats.shape)).expand(quats.shape[:-1] + (1,))
+        )
 
         # [*, 4]
-        quats = torch.cat((ones.expand(*quats.shape[:-1], 1), quats), dim=-1)
-        quats = quats / norm_denom.unsqueeze(-1)
+        quats = torch.cat([ones, quats], dim=-1)
+        quats = quats / norm_denom[..., None]
 
         # [*, 3, 3]
         rots = quat_to_rot(quats)
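For reference, `BackboneUpdate` treats the first three outputs of the linear layer as the imaginary quaternion components and the last three as the translation, fixes the real part to 1, and normalizes. A small standalone sketch of that conversion (function and variable names here are illustrative, not part of the module):

```python
import torch

def six_dof_to_quat_and_trans(params: torch.Tensor):
    # params: [*, 6] -> unit quaternion [*, 4] and translation [*, 3]
    quats, trans = params[..., :3], params[..., 3:]

    # Norm of (1, b, c, d); the +1 accounts for the implicit real part.
    norm_denom = torch.sqrt(torch.sum(quats ** 2, dim=-1) + 1)

    # Prepend the implicit real component, then normalize.
    ones = quats.new_ones(quats.shape[:-1] + (1,))
    quats = torch.cat([ones, quats], dim=-1)
    quats = quats / norm_denom[..., None]
    return quats, trans

q, t = six_dof_to_quat_and_trans(torch.randn(5, 6))
assert torch.allclose(torch.linalg.norm(q, dim=-1), torch.ones(5))
```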
@@ -424,105 +431,6 @@ class BackboneUpdate(nn.Module):
         return T(rots, trans)
 
 
-def _torsion_angles_to_frames(t, alpha, f, rrgdf):
-    # [*, N, 8, 4, 4]
-    default_4x4 = rrgdf[f,...]
-
-    # [*, N, 8] transformations, i.e.
-    #   One [*, N, 8, 3, 3] rotation matrix and
-    #   One [*, N, 8, 3]    translation matrix
-    default_t = T.from_4x4(default_4x4)
-
-    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
-    bb_rot[..., 1] = 1
-
-    # [*, N, 8, 2]
-    alpha = torch.cat(
-        [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha],
-        dim=-2
-    )
-
-    # [*, N, 8, 3, 3]
-    # Produces rotation matrices of the form:
-    # [
-    #   [1, 0  , 0  ],
-    #   [0, a_2,-a_1],
-    #   [0, a_1, a_2]
-    # ]
-    # This follows the original code rather than the supplement, which uses
-    # different indices.
-    all_rots = alpha.new_zeros(default_t.rots.shape)
-    all_rots[..., 0, 0] = 1
-    all_rots[..., 1, 1] = alpha[..., 1]
-    all_rots[..., 1, 2] = -alpha[..., 0]
-    all_rots[..., 2, 1:] = alpha
-
-    all_rots = T(all_rots, None)
-
-    all_frames = default_t.compose(all_rots)
-
-    chi2_frame_to_frame = all_frames[..., 5]
-    chi3_frame_to_frame = all_frames[..., 6]
-    chi4_frame_to_frame = all_frames[..., 7]
-
-    chi1_frame_to_bb = all_frames[..., 4]
-    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
-    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
-    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
-
-    all_frames_to_bb = T.concat([
-            all_frames[..., :5],
-            chi2_frame_to_bb.unsqueeze(-1),
-            chi3_frame_to_bb.unsqueeze(-1),
-            chi4_frame_to_bb.unsqueeze(-1),
-        ], dim=-1,
-    )
-
-    all_frames_to_global = t[..., None].compose(all_frames_to_bb)
-
-    return all_frames_to_global
-
-
-def _frames_and_literature_positions_to_atom14_pos(
-    t,
-    f,
-    default_frames,
-    group_idx,
-    atom_mask,
-    lit_positions,
-):
-    # [*, N, 14, 4, 4]
-    default_4x4 = default_frames[f, ...]
-
-    # [*, N, 14]
-    group_mask = group_idx[f, ...]
-
-    # [*, N, 14, 8]
-    group_mask = nn.functional.one_hot(
-        group_mask, num_classes=default_frames.shape[-3],
-    )
-
-    # [*, N, 14, 8]
-    t_atoms_to_global = t[..., None, :] * group_mask
-
-    # [*, N, 14]
-    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
-        lambda x: torch.sum(x, dim=-1)
-    )
-
-    # [*, N, 14, 1]
-    atom_mask = atom_mask[f,...].unsqueeze(-1)
-
-    # [*, N, 14, 3]
-    lit_positions = lit_positions[f, ...]
-    pred_positions = t_atoms_to_global.apply(lit_positions)
-    pred_positions = pred_positions * atom_mask
-
-    return pred_positions
-
 
 class StructureModuleTransitionLayer(nn.Module):
     def __init__(self, c):
         super(StructureModuleTransitionLayer, self).__init__()
@@ -664,6 +572,7 @@ class StructureModule(nn.Module):
             self.no_qk_points,
             self.no_v_points,
             inf=self.inf,
+            eps=self.epsilon,
         )
 
         self.ipa_dropout = nn.Dropout(self.dropout_rate)
@@ -791,7 +700,7 @@ class StructureModule(nn.Module):
         # Lazily initialize the residue constants on the correct device
         self._init_residue_constants(alpha.dtype, alpha.device)
         # Separated purely to make testing less annoying
-        return _torsion_angles_to_frames(t, alpha, f, self.default_frames)
+        return torsion_angles_to_frames(t, alpha, f, self.default_frames)
 
     def frames_and_literature_positions_to_atom14_pos(self,
         t,  # [*, N, 8]
@@ -799,7 +708,7 @@ class StructureModule(nn.Module):
     ):
         # Lazily initialize the residue constants on the correct device
         self._init_residue_constants(t.rots.dtype, t.rots.device)
-        return _frames_and_literature_positions_to_atom14_pos(
+        return frames_and_literature_positions_to_atom14_pos(
             t,
             f,
             self.default_frames,
...
@@ -188,17 +188,29 @@ class T:
     @staticmethod
     def from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
-        e0 = origin - p_neg_x_axis
-        e1 = p_xy_plane - origin
-
-        e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdims=True) + eps)
-        e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdims=True)
-        e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdims=True) + eps)
-        e2 = torch.cross(e0, e1)
-
-        rots = torch.stack([e0, e1, e2], dim=-1)
-
-        return T(rots, origin)
+        p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
+        origin = torch.unbind(origin, dim=-1)
+        p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
+
+        e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
+        e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
+
+        denom = torch.sqrt(sum((c * c for c in e0)) + eps)
+        e0 = [c / denom for c in e0]
+        dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
+        e1 = [c1 - c2 * dot for c1, c2 in zip(e1, e0)]
+        denom = torch.sqrt(sum((c * c for c in e1)) + eps)
+        e1 = [c / denom for c in e1]
+        e2 = [
+            e0[1] * e1[2] - e0[2] * e1[1],
+            e0[2] * e1[0] - e0[0] * e1[2],
+            e0[0] * e1[1] - e0[1] * e1[0],
+        ]
+
+        rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
+        rots = rots.reshape(rots.shape[:-1] + (3, 3))
+
+        return T(rots, torch.stack(origin, dim=-1))
 
     @staticmethod
     def concat(ts, dim):
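The rewritten `from_3_points` performs the same Gram–Schmidt construction as before, just on per-coordinate tensors obtained with `torch.unbind` instead of whole 3-vectors with `keepdims` reductions. A compact sketch of the underlying math on plain 3-vectors (illustrative only, not the class method; `gram_schmidt_frame` is a made-up name):

```python
import torch

def gram_schmidt_frame(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
    # Build an orthonormal frame from three points, AlphaFold-style:
    # e0 points from p_neg_x_axis toward origin, e1 is orthogonalized
    # against e0, and e2 completes a right-handed basis.
    e0 = origin - p_neg_x_axis
    e1 = p_xy_plane - origin
    e0 = e0 / torch.sqrt(torch.sum(e0 ** 2) + eps)
    e1 = e1 - e0 * torch.sum(e0 * e1)
    e1 = e1 / torch.sqrt(torch.sum(e1 ** 2) + eps)
    e2 = torch.cross(e0, e1, dim=-1)
    # Basis vectors become the columns of the rotation; origin is the
    # translation of the resulting rigid transform.
    return torch.stack([e0, e1, e2], dim=-1), origin

rots, trans = gram_schmidt_frame(
    torch.tensor([1., 0., 0.]),
    torch.tensor([0., 0., 0.]),
    torch.tensor([0., 1., 0.]),
)
assert torch.allclose(rots.T @ rots, torch.eye(3), atol=1e-6)
```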
@@ -294,6 +306,9 @@ class T:
         return T(rots, translation)
 
+    def cuda(self):
+        return T(self.rots.cuda(), self.trans.cuda())
+
 
 _quat_elements = ['a', 'b', 'c', 'd']
 _qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
@@ -325,10 +340,11 @@ def quat_to_rot(
     # [*, 4, 4]
     quat = quat[..., None] * quat[..., None, :]
 
+    # [4, 4, 3, 3]
     mat = quat.new_tensor(_qtr_mat)
 
     # [*, 4, 4, 3, 3]
-    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
+    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
     quat = quat[..., None, None] * shaped_qtr_mat
 
     # [*, 3, 3]
...
@@ -70,11 +70,8 @@ def checkpoint_blocks(
     for s in range(0, len(blocks), blocks_per_ckpt):
         e = s + blocks_per_ckpt
-        #print(len(args))
-        #for a in args:
-        #    print(a.requires_grad)
-        args = checkpoint(chunker(s, e), *args)
-        #args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
+        #args = checkpoint(chunker(s, e), *args)
+        args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
         args = wrap(args)
 
     return args
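The hunk above swaps `torch.utils.checkpoint.checkpoint` for DeepSpeed's activation checkpointing while keeping the chunked-blocks structure. A rough sketch of the pattern for a single-tensor argument (hypothetical helper names; assumes `deepspeed` is installed and its checkpointing is configured):

```python
import deepspeed

def checkpoint_block_chunks(blocks, x, blocks_per_ckpt):
    # Run `blocks` sequentially, storing activations only at chunk
    # boundaries and recomputing the rest during the backward pass.
    def chunker(s, e):
        def run(a):
            for block in blocks[s:e]:
                a = block(a)
            return a
        return run

    for s in range(0, len(blocks), blocks_per_ckpt):
        e = s + blocks_per_ckpt
        x = deepspeed.checkpointing.checkpoint(chunker(s, e), x)
    return x
```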
@@ -173,6 +173,11 @@ def atom37_to_torsion_angles(
     **kwargs,
 ) -> Dict[str, torch.Tensor]:
     """
+        Convert coordinates to torsion angles.
+
+        This function is extremely sensitive to floating point imprecisions
+        and should be run with double precision whenever possible.
+
         Args:
             aatype:
                 [*, N_res] residue indices
@@ -228,10 +233,10 @@ def atom37_to_torsion_angles(
     )
     phi_mask = (
         prev_all_atom_mask[..., 2] *
-        torch.prod(all_atom_mask[..., :3], dim=-1)
+        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
     )
     psi_mask = (
-        torch.prod(all_atom_mask[..., :3], dim=-1) *
+        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) *
         all_atom_mask[..., 4]
     )
@@ -256,7 +261,9 @@ def atom37_to_torsion_angles(
         dim=-1,
         no_batch_dims=len(atom_indices.shape[:-2])
     )
-    chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1)
+    chi_angle_atoms_mask = torch.prod(
+        chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
+    )
     chis_mask = chis_mask * chi_angle_atoms_mask
 
     torsions_atom_pos = torch.cat(
@@ -281,6 +288,7 @@ def atom37_to_torsion_angles(
         torsions_atom_pos[..., 1, :],
         torsions_atom_pos[..., 2, :],
         torsions_atom_pos[..., 0, :],
+        eps=eps,
     )
 
     fourth_atom_rel_pos = torsion_frames.invert().apply(
@@ -290,15 +298,19 @@ def atom37_to_torsion_angles(
     torsion_angles_sin_cos = torch.stack(
         [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
     )
 
     denom = torch.sqrt(
         torch.sum(
-            torch.square(torsion_angles_sin_cos), dim=-1, keepdims=True
+            torch.square(torsion_angles_sin_cos),
+            dim=-1,
+            dtype=torsion_angles_sin_cos.dtype,
+            keepdims=True
         ) + eps
     )
     torsion_angles_sin_cos = torsion_angles_sin_cos / denom
 
-    torsion_angles_sin_cos = torsion_angles_sin_cos * torch.tensor(
-        [1., 1., -1., 1., 1., 1., 1.], device=aatype.device,
+    torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
+        [1., 1., -1., 1., 1., 1., 1.],
     )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
 
     chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
@@ -327,6 +339,7 @@ def atom37_to_frames(
     aatype: torch.Tensor,
     all_atom_positions: torch.Tensor,
     all_atom_mask: torch.Tensor,
+    eps: float,
     **kwargs,
 ) -> Dict[str, torch.Tensor]:
     batch_dims = len(aatype.shape[:-1])
@@ -387,6 +400,7 @@ def atom37_to_frames(
         p_neg_x_axis=base_atom_pos[..., 0, :],
         origin=base_atom_pos[..., 1, :],
         p_xy_plane=base_atom_pos[..., 2, :],
+        eps=eps,
     )
 
     group_exists = batched_gather(
@@ -638,3 +652,106 @@ def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
         "atom14_alt_gt_positions": atom14_alt_gt_positions,
         "atom14_alt_gt_exists": atom14_alt_gt_exists,
     }
+
+
+def torsion_angles_to_frames(
+    t: T,
+    alpha: torch.Tensor,
+    aatype: torch.Tensor,
+    rrgdf: torch.Tensor,
+):
+    # [*, N, 8, 4, 4]
+    default_4x4 = rrgdf[aatype, ...]
+
+    # [*, N, 8] transformations, i.e.
+    #   One [*, N, 8, 3, 3] rotation matrix and
+    #   One [*, N, 8, 3]    translation matrix
+    default_t = T.from_4x4(default_4x4)
+
+    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
+    bb_rot[..., 1] = 1
+
+    # [*, N, 8, 2]
+    alpha = torch.cat(
+        [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha],
+        dim=-2
+    )
+
+    # [*, N, 8, 3, 3]
+    # Produces rotation matrices of the form:
+    # [
+    #   [1, 0  , 0  ],
+    #   [0, a_2,-a_1],
+    #   [0, a_1, a_2]
+    # ]
+    # This follows the original code rather than the supplement, which uses
+    # different indices.
+    all_rots = alpha.new_zeros(default_t.rots.shape)
+    all_rots[..., 0, 0] = 1
+    all_rots[..., 1, 1] = alpha[..., 1]
+    all_rots[..., 1, 2] = -alpha[..., 0]
+    all_rots[..., 2, 1:] = alpha
+
+    all_rots = T(all_rots, None)
+
+    all_frames = default_t.compose(all_rots)
+
+    chi2_frame_to_frame = all_frames[..., 5]
+    chi3_frame_to_frame = all_frames[..., 6]
+    chi4_frame_to_frame = all_frames[..., 7]
+
+    chi1_frame_to_bb = all_frames[..., 4]
+    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
+    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
+    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
+
+    all_frames_to_bb = T.concat([
+            all_frames[..., :5],
+            chi2_frame_to_bb.unsqueeze(-1),
+            chi3_frame_to_bb.unsqueeze(-1),
+            chi4_frame_to_bb.unsqueeze(-1),
+        ], dim=-1,
+    )
+
+    all_frames_to_global = t[..., None].compose(all_frames_to_bb)
+
+    return all_frames_to_global
+
+
+def frames_and_literature_positions_to_atom14_pos(
+    t: T,
+    aatype: torch.Tensor,
+    default_frames,
+    group_idx,
+    atom_mask,
+    lit_positions,
+):
+    # [*, N, 14, 4, 4]
+    default_4x4 = default_frames[aatype, ...]
+
+    # [*, N, 14]
+    group_mask = group_idx[aatype, ...]
+
+    # [*, N, 14, 8]
+    group_mask = nn.functional.one_hot(
+        group_mask, num_classes=default_frames.shape[-3],
+    )
+
+    # [*, N, 14, 8]
+    t_atoms_to_global = t[..., None, :] * group_mask
+
+    # [*, N, 14]
+    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
+        lambda x: torch.sum(x, dim=-1)
+    )
+
+    # [*, N, 14, 1]
+    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
+
+    # [*, N, 14, 3]
+    lit_positions = lit_positions[aatype, ...]
+    pred_positions = t_atoms_to_global.apply(lit_positions)
+    pred_positions = pred_positions * atom_mask
+
+    return pred_positions
@@ -417,7 +417,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
     if("_ptm" in version):
         translations["predicted_aligned_error_head"] = {
             "logits":
-                LinearParams(model.aux_heads.tm_score.linear)
+                LinearParams(model.aux_heads.tm.linear)
         }
 
     # Flatten keys and insert missing key prefixes
...
@@ -273,7 +273,6 @@ def supervised_chi_loss(
     shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
     true_chi_shifted = shifted_mask * true_chi
-
     sq_chi_error = torch.sum(
         (true_chi - pred_angles)**2, dim=-1
     )
@@ -498,11 +497,11 @@ def tm_loss(
     )
 
     loss = torch.sum(errors * square_mask, dim=-1)
-    scale = 0.1 # hack to help FP16 training along
+    scale = 0.5 # hack to help FP16 training along
     denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
     loss = loss / denom[..., None]
     loss = torch.sum(loss, dim=-1)
-    loss = loss / scale
+    loss = loss * scale
 
     loss = loss * (
         (resolution >= min_resolution) &
@@ -744,7 +743,7 @@ def between_residue_clash_loss(
     # Backbone C--N bond between subsequent residues is no clash.
     c_one_hot = torch.nn.functional.one_hot(
-        residue_index.new_tensor(2.), num_classes=14
+        residue_index.new_tensor(2), num_classes=14
     )
     c_one_hot = c_one_hot.reshape(
         *((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
@@ -1319,11 +1318,11 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
     # )
     loss = errors * bert_mask
     loss = torch.sum(loss, dim=-1)
-    scale = 0.1
+    scale = 0.5
     denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
     loss = loss / denom[..., None]
     loss = torch.sum(loss, dim=-1)
-    loss = loss / scale
+    loss = loss * scale
 
     return loss
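Both `tm_loss` and `masked_msa_loss` use the same trick, labelled in the code as a hack to help FP16 training along: the mask sum in the denominator is pre-multiplied by `scale`, and the final loss is multiplied by `scale` so the factor cancels exactly (the previous `loss / scale` undid the rescaling in the wrong direction, changing the loss by a factor of `scale**2`). A small check of the algebra, using a stand-in `scaled_masked_mean` helper:

```python
import torch

def scaled_masked_mean(errors, mask, scale, eps=1e-8):
    # Rescale the denominator by `scale`, then multiply the result by
    # `scale` so the factor cancels, as in tm_loss / masked_msa_loss.
    loss = torch.sum(errors * mask, dim=-1)
    denom = eps + torch.sum(scale * mask, dim=(-1, -2))
    loss = loss / denom[..., None]
    loss = torch.sum(loss, dim=-1)
    return loss * scale

errors = torch.rand(2, 8, 8)
mask = torch.ones(2, 8, 8)
ref = torch.sum(errors * mask, dim=(-1, -2)) / torch.sum(mask, dim=(-1, -2))
assert torch.allclose(scaled_masked_mean(errors, mask, scale=0.5), ref, atol=1e-5)
```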
@@ -1352,7 +1351,7 @@ class AlphaFoldLoss(nn.Module):
         ))
 
         if("backbone_affine_tensor" not in batch.keys()):
-            batch.update(feats.atom37_to_frames(**batch))
+            batch.update(feats.atom37_to_frames(eps=self.config.eps, **batch))
 
             # TODO: Verify that this is correct
             batch["backbone_affine_tensor"] = (
@@ -1363,16 +1362,19 @@ class AlphaFoldLoss(nn.Module):
         )
 
         if("chi_angles_sin_cos" not in batch.keys()):
-            batch.update(feats.atom37_to_torsion_angles(
-                **batch,
-                eps=self.config.eps,
-            ))
+            with torch.no_grad():
+                batch.update(feats.atom37_to_torsion_angles(
+                    aatype=batch["aatype"],
+                    all_atom_positions=batch["all_atom_positions"].double(),
+                    all_atom_mask=batch["all_atom_mask"].double(),
+                    eps=self.config.eps,
+                ))
 
             # TODO: Verify that this is correct
             batch["chi_angles_sin_cos"] = (
                 batch["torsion_angles_sin_cos"][..., 3:, :]
-            )
-            batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:]
+            ).to(batch["all_atom_mask"].dtype)
+            batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:].to(batch["all_atom_mask"].dtype)
 
         loss_fns = {
             "distogram":
...
@@ -25,11 +25,11 @@ def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
     return tensor.permute(first_inds + [zero_index + i for i in inds])
 
 
-def flatten_final_dims(tensor: torch.Tensor, no_dims: int):
-    return tensor.reshape(tensor.shape[:-no_dims] + (-1,))
+def flatten_final_dims(t: torch.Tensor, no_dims: int):
+    return t.reshape(t.shape[:-no_dims] + (-1,))
 
 
-def masked_mean(mask, value, dim, eps=1e-10):
+def masked_mean(mask, value, dim, eps=1e-4):
     mask = mask.expand(*value.shape)
     return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
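Raising `masked_mean`'s default `eps` from 1e-10 to 1e-4 matters in half precision: FP16 cannot represent values much below roughly 6e-8, so once the denominator lives in an FP16 tensor a 1e-10 epsilon effectively becomes zero and no longer protects against empty masks. A quick check:

```python
import torch

# 1e-10 underflows to zero in half precision; 1e-4 survives.
assert torch.tensor(1e-10, dtype=torch.float16).item() == 0.0
assert torch.tensor(1e-4, dtype=torch.float16).item() > 0.0
```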
...