"lib/bindings/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "e97493eb0065285c2775bfb5fcee7cd821f08842"
Commit ab0c6977 authored by Gustaf Ahdritz

Continue debugging loss functions, remove in-place ops

parent 85c0a9a9
@@ -113,6 +113,7 @@ class InputEmbedder(nn.Module):
         # [*, N_res, N_res, c_z]
         pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
         pair_emb += self.relpos(ri)
+        #pair_emb = pair_emb + self.relpos(ri)
         # [*, N_clust, N_res, c_m]
         n_clust = msa.shape[-3]
...
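The commit title's "remove in-place ops" refers to a PyTorch autograd constraint: an in-place update can clobber a tensor that the backward pass still needs. A minimal, self-contained repro of the failure mode (illustrative, not OpenFold code):

```python
import torch

x = torch.ones(3, requires_grad=True)

# torch.exp saves its *output* for the backward pass, so mutating that
# output in place invalidates the graph.
y = torch.exp(x)
y += 1  # in-place: bumps y's version counter
try:
    y.sum().backward()
except RuntimeError as e:
    print(e)  # "... has been modified by an inplace operation"

# The out-of-place form allocates a new tensor and always backprops cleanly.
y = torch.exp(x)
y = y + 1
y.sum().backward()
print(x.grad)
```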
@@ -94,7 +94,7 @@ class MSATransition(nn.Module):
         m = self.layer_norm(m)
         inp = {"m": m, "mask": mask}
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             m = chunk_layer(
                 self._transition,
                 inp,
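For context, `chunk_layer` (from openfold.utils.tensor_utils) splits the batch-like dimensions into slices of `chunk_size`, applies the layer to each slice, and reassembles the result; dropping the `not self.training` guard here (and in the analogous hunks below) means chunking now caps activation memory during training as well as inference. A simplified sketch of the idea, assuming a single leading batch dimension, not the actual implementation:

```python
import torch

def chunk_layer_sketch(layer, inputs, chunk_size):
    # All inputs are assumed to share one leading batch dimension;
    # the real chunk_layer handles arbitrary no_batch_dims.
    n = next(iter(inputs.values())).shape[0]
    out = []
    for i in range(0, n, chunk_size):
        chunk = {k: v[i:i + chunk_size] for k, v in inputs.items()}
        out.append(layer(**chunk))
    return torch.cat(out, dim=0)
```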
@@ -132,6 +132,7 @@ class EvoformerBlock(nn.Module):
             c_z=c_z,
             c_hidden=c_hidden_msa_att,
             no_heads=no_heads_msa,
+            chunk_size=chunk_size,
             inf=inf,
         )
...
@@ -108,7 +108,7 @@ class AlphaFold(nn.Module):
     def embed_templates(self, batch, z, pair_mask, templ_dim):
         # Embed the templates one at a time (with a poor man's vmap)
         template_embeds = []
-        n_templ = batch["template_aatype"].shape[-2]
+        n_templ = batch["template_aatype"].shape[templ_dim]
         for i in range(n_templ):
             idx = batch["template_aatype"].new_tensor(i)
             single_template_feats = tensor_tree_map(
@@ -155,14 +155,14 @@ class AlphaFold(nn.Module):
             partial(torch.cat, dim=templ_dim),
             template_embeds,
         )
         # [*, N, N, C_z]
         t = self.template_pointwise_att(
             template_embeds["pair"],
             z,
             template_mask=batch["template_mask"]
         )
-        t *= torch.sum(batch["template_mask"]) > 0
+        t = t * (torch.sum(batch["template_mask"]) > 0)
         return {
             "template_angle_embedding": a,
@@ -297,7 +297,7 @@ class AlphaFold(nn.Module):
         m[..., 0, :, :] += m_1_prev_emb
         # [*, N, N, C_z]
-        z += z_prev_emb
+        z = z + z_prev_emb
         # Embed the templates + merge with MSA/pair embeddings
         if(self.config.template.enabled):
@@ -312,7 +312,7 @@ class AlphaFold(nn.Module):
         )
         # [*, N, N, C_z]
-        z += template_embeds["template_pair_embedding"]
+        z = z + template_embeds["template_pair_embedding"]
         if(self.config.template.embed_angles):
             # [*, S = S_c + S_t, N, C_m]
...
@@ -125,7 +125,7 @@ class MSAAttention(nn.Module):
             "v_x": m,
             "biases": biases
         }
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             m = chunk_layer(
                 self.mha,
                 mha_inputs,
@@ -142,7 +142,7 @@ class MSARowAttentionWithPairBias(MSAAttention):
     """
     Implements Algorithm 7.
     """
-    def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
+    def __init__(self, c_m, c_z, c_hidden, no_heads, chunk_size, inf=1e9):
         """
         Args:
             c_m:
@@ -161,7 +161,8 @@ class MSARowAttentionWithPairBias(MSAAttention):
             c_hidden,
             no_heads,
             pair_bias=True,
             c_z=c_z,
+            chunk_size=chunk_size,
             inf=inf,
         )
@@ -259,7 +260,7 @@ class MSAColumnGlobalAttention(nn.Module):
         # [*, N_res, H * C_hidden]
         q = self.linear_q(q)
-        q *= self.c_hidden ** (-0.5)
+        q = q * self.c_hidden ** (-0.5)
         # [*, N_res, H, C_hidden]
         q = q.view(*q.shape[:-1], self.no_heads, -1)
@@ -274,7 +275,7 @@ class MSAColumnGlobalAttention(nn.Module):
             k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
         )
         bias = (self.inf * (mask - 1))[..., :, None, :]
-        a += bias
+        a = a + bias
         a = self.softmax(a)
         # [*, N_res, H, C_hidden]
@@ -318,7 +319,7 @@ class MSAColumnGlobalAttention(nn.Module):
             "m": m,
             "mask": mask,
         }
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             m = chunk_layer(
                 self.global_attention,
                 mha_input,
...
@@ -83,7 +83,7 @@ class OuterProductMean(nn.Module):
         a = a.transpose(-2, -3)
         b = b.transpose(-2, -3)
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             # Since the "batch dim" in this case is not a true batch dimension
             # (in that the shape of the output depends on it), we need to
             # iterate over it ourselves
@@ -107,7 +107,7 @@ class OuterProductMean(nn.Module):
         norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
         # [*, N_res, N_res, C_z]
-        outer /= self.eps + norm
+        outer = outer / (self.eps + norm)
         return outer
...
@@ -73,7 +73,7 @@ class PairTransition(nn.Module):
         z = self.layer_norm(z)
         inp = {"z": z, "mask": mask}
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             z = chunk_layer(
                 self._transition,
                 inp,
...
@@ -251,10 +251,10 @@ class Attention(nn.Module):
             permute_final_dims(k, 1, 2, 0),  # [*, H, C_hidden, K]
         )
         norm = 1 / math.sqrt(self.c_hidden)  # [1]
-        a *= norm
+        a = a * norm
         if(biases is not None):
             for b in biases:
-                a += b
+                a = a + b
         a = self.softmax(a)
         # [*, H, Q, C_hidden]
...
@@ -129,10 +129,11 @@ class AngleResnet(nn.Module):
         # [*, no_angles * 2]
         s = self.linear_out(s)
-        unnormalized_s = s
         # [*, no_angles, 2]
         s = s.view(*s.shape[:-1], -1, 2)
+        unnormalized_s = s
         norm_denom = torch.sqrt(
             torch.clamp(
                 torch.sum(s ** 2, dim=-1, keepdims=True),
@@ -295,8 +296,8 @@ class InvariantPointAttention(nn.Module):
             permute_final_dims(q, 1, 0, 2),  # [*, H, N_res, C_hidden]
             permute_final_dims(k, 1, 2, 0),  # [*, H, C_hidden, N_res]
         )
-        a *= math.sqrt(1. / (3 * self.c_hidden))
-        a += math.sqrt(1. / 3) * permute_final_dims(b, 2, 0, 1)
+        a = a * math.sqrt(1. / (3 * self.c_hidden))
+        a = a + math.sqrt(1. / 3) * permute_final_dims(b, 2, 0, 1)
         # [*, N_res, N_res, H, P_q, 3]
         pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
@@ -307,7 +308,9 @@ class InvariantPointAttention(nn.Module):
         head_weights = self.softplus(self.head_weights).view(
             *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
         )
-        head_weights *= math.sqrt(1. / (3 * (self.no_qk_points * 9. / 2)))
+        head_weights = (
+            head_weights * math.sqrt(1. / (3 * (self.no_qk_points * 9. / 2)))
+        )
         pt_att = pt_att * head_weights
         # [*, N_res, N_res, H]
@@ -319,8 +322,8 @@ class InvariantPointAttention(nn.Module):
         # [*, H, N_res, N_res]
         pt_att = permute_final_dims(pt_att, 2, 0, 1)
-        a += pt_att
-        a += square_mask.unsqueeze(-3)
+        a = a + pt_att
+        a = a + square_mask.unsqueeze(-3)
         a = self.softmax(a)
         ################
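The square-root factors in these hunks come from AlphaFold's Invariant Point Attention (Supplementary Algorithm 22): each of the three logit terms (scalar QK, pair bias, point distances) is scaled so that their sum has roughly unit variance before the softmax. With illustrative values for the hyperparameters:

```python
import math

c_hidden, no_qk_points = 16, 4  # illustrative values, not the config defaults
scalar_scale = math.sqrt(1. / (3 * c_hidden))                # q.k term
bias_scale = math.sqrt(1. / 3)                               # pair-bias term
point_scale = math.sqrt(1. / (3 * (no_qk_points * 9. / 2)))  # point term
print(scalar_scale, bias_scale, point_scale)
```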
@@ -510,7 +513,7 @@ def _frames_and_literature_positions_to_atom14_pos(
     # [*, N, 14, 3]
     lit_positions = lit_positions[f, ...]
     pred_positions = t_atoms_to_global.apply(lit_positions)
-    pred_positions *= atom_mask
+    pred_positions = pred_positions * atom_mask
     return pred_positions
...
@@ -108,7 +108,7 @@ class TemplatePointwiseAttention(nn.Module):
             "v_x": t,
             "biases": [bias],
         }
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             z = chunk_layer(
                 self.mha,
                 mha_inputs,
...
@@ -102,7 +102,7 @@ class TriangleAttention(nn.Module):
             "v_x": x,
             "biases": [mask_bias, triangle_bias],
         }
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             x = chunk_layer(
                 self.mha,
                 mha_inputs,
...
@@ -70,5 +70,6 @@ def checkpoint_blocks(
     for s in range(0, len(blocks), blocks_per_ckpt):
         e = s + blocks_per_ckpt
         args = deepspeed.checkpointing.checkpoint(chunker(s, e), args)
+        args = wrap(args)
     return args
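`wrap` is referenced but not defined in this hunk. The likely motivation: `deepspeed.checkpointing.checkpoint` can return a bare tensor when a block returns a single value, so the loop re-normalizes `args` to a tuple before feeding the next chunk. A plausible definition consistent with that use, offered only as a guess:

```python
def wrap(args):
    # Hypothetical helper: normalize a checkpointed block's output so the
    # next block always receives a tuple of positional args.
    return args if isinstance(args, tuple) else (args,)
```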
@@ -158,15 +158,16 @@ def atom14_to_atom37(atom14, batch):
 def atom37_to_torsion_angles(
     aatype: torch.Tensor,
-    all_atom_pos: torch.Tensor,
+    all_atom_positions: torch.Tensor,
     all_atom_mask: torch.Tensor,
     eps: float = 1e-8,
+    **kwargs,
 ) -> Dict[str, torch.Tensor]:
     """
     Args:
         aatype:
             [*, N_res] residue indices
-        all_atom_pos:
+        all_atom_positions:
             [*, N_res, 37, 3] atom positions (in atom37
             format)
         all_atom_mask:
@@ -183,28 +184,32 @@ def atom37_to_torsion_angles(
     """
     aatype = torch.clamp(aatype, max=20)
-    pad = all_atom_pos.new_zeros([*all_atom_pos.shape[:-3], 1, 37, 3])
-    prev_all_atom_pos = torch.cat([pad, all_atom_pos[..., :-1, :, :]], dim=-3)
+    pad = all_atom_positions.new_zeros(
+        [*all_atom_positions.shape[:-3], 1, 37, 3]
+    )
+    prev_all_atom_positions = torch.cat(
+        [pad, all_atom_positions[..., :-1, :, :]], dim=-3
+    )
     pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
     prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
     pre_omega_atom_pos = torch.cat(
         [
-            prev_all_atom_pos[..., 1:3, :],
-            all_atom_pos[..., :2, :]
+            prev_all_atom_positions[..., 1:3, :],
+            all_atom_positions[..., :2, :]
         ], dim=-2
     )
     phi_atom_pos = torch.cat(
         [
-            prev_all_atom_pos[..., 2:3, :],
-            all_atom_pos[..., :3, :]
+            prev_all_atom_positions[..., 2:3, :],
+            all_atom_positions[..., :3, :]
         ], dim=-2
     )
     psi_atom_pos = torch.cat(
         [
-            all_atom_pos[..., :3, :],
-            all_atom_pos[..., 4:5, :]
+            all_atom_positions[..., :3, :],
+            all_atom_positions[..., 4:5, :]
         ], dim=-2
     )
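For readers decoding the slices: atom37 ordering begins N=0, CA=1, C=2, CB=3, O=4, so the concatenations above assemble the standard four-atom dihedral frames:

```python
# pre-omega: prev (CA, C)    + this (N, CA)    -> [..., 1:3, :] and [..., :2, :]
# phi:       prev (C,)       + this (N, CA, C) -> [..., 2:3, :] and [..., :3, :]
# psi:       this (N, CA, C) + this (O,)       -> [..., :3, :]  and [..., 4:5, :]
```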
@@ -227,7 +232,7 @@ def atom37_to_torsion_angles(
     atom_indices = chi_atom_indices[..., aatype, :, :]
     chis_atom_pos = batched_gather(
-        all_atom_pos, atom_indices, -2, len(atom_indices.shape[:-2])
+        all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
     )
     chi_angles_mask = list(rc.chi_angles_mask)
@@ -335,9 +340,9 @@ def atom37_to_frames(
         device=aatype.device,
         requires_grad=False
     )
-    restype_rigidgroup_mask[:, 0] = 1
-    restype_rigidgroup_mask[:, 3] = 1
-    restype_rigidgroup_mask[:20, 4:] = (
+    restype_rigidgroup_mask[..., 0] = 1
+    restype_rigidgroup_mask[..., 3] = 1
+    restype_rigidgroup_mask[..., :20, 4:] = (
         all_atom_mask.new_tensor(rc.chi_angles_mask)
     )
...
...@@ -22,6 +22,7 @@ from typing import Dict, Optional ...@@ -22,6 +22,7 @@ from typing import Dict, Optional
from openfold.np import residue_constants from openfold.np import residue_constants
from openfold.model.primitives import Linear from openfold.model.primitives import Linear
from openfold.utils import feats
from openfold.utils.affine_utils import T from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
...@@ -150,7 +151,9 @@ def backbone_loss( ...@@ -150,7 +151,9 @@ def backbone_loss(
unclamped_fape_loss * (1 - use_clamped_fape) unclamped_fape_loss * (1 - use_clamped_fape)
) )
return torch.mean(fape_loss, dim=-1) # Take the mean over the layer dimension
fape_loss = torch.mean(fape_loss, dim=0)
return fape_loss
def sidechain_loss( def sidechain_loss(
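FAPE is computed for every structure-module iteration, so the loss arrives with a leading layer dimension; averaging over dim 0 collapses the layers while preserving any batch dimensions, whereas the old `dim=-1` would have averaged over a batch axis once inputs became batched. A sketch of the assumed shapes:

```python
import torch

fape_loss = torch.rand(8, 4)         # [no_layers, batch] (assumed shapes)
per_example = fape_loss.mean(dim=0)  # [batch]: mean over layers, not batch
```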
@@ -172,11 +175,10 @@ def sidechain_loss(
         alt_naming_is_better[..., None, None, None] *
         rigidgroups_alt_gt_frames
     )
-    batch_dims = sidechain_frames.shape[:-5]
     # Steamroll the inputs
     sidechain_frames = sidechain_frames[-1]
+    batch_dims = sidechain_frames.shape[:-4]
     sidechain_frames = sidechain_frames.view(
         *batch_dims, -1, 4, 4
     )
@@ -198,7 +200,7 @@ def sidechain_loss(
     renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(
         *batch_dims, -1
     )
     fape = compute_fape(
         sidechain_frames,
         renamed_gt_frames,
@@ -240,7 +242,7 @@ def supervised_chi_loss(
     aatype: torch.Tensor,
     seq_mask: torch.Tensor,
     chi_mask: torch.Tensor,
-    chi_angles: torch.Tensor,
+    chi_angles_sin_cos: torch.Tensor,
     chi_weight: float,
     angle_norm_weight: float,
     eps=1e-6,
@@ -256,24 +258,24 @@ def supervised_chi_loss(
         angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
     )
-    true_chi = chi_angles
-    sin_true_chi = torch.sin(true_chi)
-    cos_true_chi = torch.cos(true_chi)
-    sin_cos_true_chi = torch.stack([sin_true_chi, cos_true_chi], dim=-1)
+    true_chi = chi_angles_sin_cos.unsqueeze(-4)
     shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
-    sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi
+    true_chi_shifted = shifted_mask * true_chi
     sq_chi_error = torch.sum(
-        (sin_cos_true_chi - pred_angles)**2, dim=-1
+        (true_chi - pred_angles)**2, dim=-1
     )
     sq_chi_error_shifted = torch.sum(
-        (sin_cos_true_chi_shifted - pred_angles)**2, dim=-1
+        (true_chi_shifted - pred_angles)**2, dim=-1
     )
     sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
+    # The ol' switcheroo
+    sq_chi_error = sq_chi_error.permute(
+        *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
+    )
     sq_chi_loss = masked_mean(
-        chi_mask, sq_chi_error, dim=(-1, -2)
+        chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
     )
     loss = 0
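The "switcheroo" permute exists because `sq_chi_error` carries a leading layer dimension from the structure module; moving it next to the residue and angle dims lets a single masked mean over (-1, -2, -3) average layers, residues, and chi angles together. With assumed shapes:

```python
import torch

sq_chi_error = torch.rand(8, 2, 50, 4)  # [layers, batch, N_res, 4] (assumed)
switched = sq_chi_error.permute(
    *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
)
print(switched.shape)  # torch.Size([2, 8, 50, 4]): [batch, layers, N_res, 4]
```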
@@ -283,8 +285,11 @@ def supervised_chi_loss(
         torch.sum(unnormalized_angles_sin_cos**2, dim=-1) + eps
     )
     norm_error = torch.abs(angle_norm - 1.)
+    norm_error = norm_error.permute(
+        *range(len(norm_error.shape))[1:-2], 0, -2, -1
+    )
     angle_norm_loss = masked_mean(
-        seq_mask[..., None], norm_error, dim=(-1, -2)
+        seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
     )
     loss += angle_norm_weight * angle_norm_loss
@@ -377,10 +382,10 @@ def lddt_loss(
     )
     errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
     all_atom_mask = all_atom_mask.squeeze(-1)
     loss = (
-        torch.sum(errors * all_atom_mask) / (torch.sum(all_atom_mask) + 1e-8)
+        torch.sum(errors * all_atom_mask, dim=-1) /
+        (eps + torch.sum(all_atom_mask, dim=-1))
     )
     loss *= (
@@ -483,10 +488,10 @@ def tm_score(
 def between_residue_bond_loss(
-    pred_atom_positions: torch.Tensor,  # (N, 37(14), 3)
-    pred_atom_mask: torch.Tensor,  # (N, 37(14))
-    residue_index: torch.Tensor,  # (N)
-    aatype: torch.Tensor,  # (N)
+    pred_atom_positions: torch.Tensor,  # (*, N, 37/14, 3)
+    pred_atom_mask: torch.Tensor,  # (*, N, 37/14)
+    residue_index: torch.Tensor,  # (*, N)
+    aatype: torch.Tensor,  # (*, N)
     tolerance_factor_soft=12.0,
     tolerance_factor_hard=12.0,
     eps=1e-6,
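With the shape comments now advertising leading batch dimensions `(*, N, ...)`, every mask-weighted average in the hunks below reduces with `dim=-1` so that per-example losses survive instead of being collapsed into a single scalar. The pattern, in isolation:

```python
import torch

mask = torch.ones(2, 50)        # [batch, N]
per_residue = torch.rand(2, 50)
eps = 1e-6
loss = (
    torch.sum(mask * per_residue, dim=-1) /
    (torch.sum(mask, dim=-1) + eps)
)                               # [batch]: one loss per example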
@@ -561,7 +566,10 @@ def between_residue_bond_loss(
         c_n_bond_length_error - tolerance_factor_soft * gt_stddev
     )
     mask = this_c_mask * next_n_mask * has_no_gap_mask
-    c_n_loss = torch.sum(mask * c_n_loss_per_residue) / (torch.sum(mask) + eps)
+    c_n_loss = (
+        torch.sum(mask * c_n_loss_per_residue, dim=-1) /
+        (torch.sum(mask, dim=-1) + eps)
+    )
     c_n_violation_mask = mask * (
         c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
     )
@@ -589,7 +597,8 @@ def between_residue_bond_loss(
     )
     mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
     ca_c_n_loss = (
-        torch.sum(mask * ca_c_n_loss_per_residue) / (torch.sum(mask) + eps)
+        torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) /
+        (torch.sum(mask, dim=-1) + eps)
     )
     ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
         (tolerance_factor_hard * gt_stddev))
@@ -604,7 +613,8 @@ def between_residue_bond_loss(
     )
     mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
     c_n_ca_loss = (
-        torch.sum(mask * c_n_ca_loss_per_residue) / (torch.sum(mask) + eps)
+        torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) /
+        (torch.sum(mask, dim=-1) + eps)
     )
     c_n_ca_violation_mask = mask * (
         c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
@@ -619,7 +629,7 @@ def between_residue_bond_loss(
         torch.nn.functional.pad(per_residue_loss_sum, (0, 1)) +
         torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
     )
     # Compute hard violations.
     violation_mask = torch.max(
         torch.stack(
@@ -627,7 +637,8 @@ def between_residue_bond_loss(
                 c_n_violation_mask,
                 ca_c_n_violation_mask,
                 c_n_ca_violation_mask
-            ]
+            ],
+            dim=-2,
         ),
         dim=-2
     )[0]
@@ -635,7 +646,7 @@ def between_residue_bond_loss(
         torch.nn.functional.pad(violation_mask, (0, 1)),
         torch.nn.functional.pad(violation_mask, (1, 0))
     )
     return {
         'c_n_loss_mean': c_n_loss,
         'ca_c_n_loss_mean': ca_c_n_loss,
@@ -708,7 +719,7 @@ def between_residue_clash_loss(
     # Backbone C--N bond between subsequent residues is no clash.
     c_one_hot = torch.nn.functional.one_hot(
-        residue_index.new_tensor(2), num_classes=14
+        residue_index.new_tensor(2.), num_classes=14
     )
     c_one_hot = c_one_hot.reshape(
         *((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
@@ -958,7 +969,7 @@ def find_structural_violations(
         atom14_dists_upper_bound=atom14_dists_upper_bound,
         tighten_bounds_for_loss=0.0
     )
     # Combine them to a single per-residue violation mask (used later for LDDT).
     per_residue_violations_mask = torch.max(
         torch.stack(
@@ -1255,6 +1266,7 @@ def experimentally_resolved_loss(
     min_resolution: float,
     max_resolution: float,
     eps: float = 1e-8,
+    **kwargs,
 ) -> torch.Tensor:
     errors = sigmoid_cross_entropy(logits, all_atom_mask)
     loss_num = torch.sum(errors * atom37_atom_exists, dim=(-1, -2))
@@ -1268,7 +1280,7 @@ def experimentally_resolved_loss(
     return loss
-def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8):
+def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
     errors = softmax_cross_entropy(
         logits,
         torch.nn.functional.one_hot(true_msa, num_classes=23)
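The added `**kwargs` catch-alls let AlphaFoldLoss call each loss with the whole feature dict splatted in (`**{**batch, **config}`, as in the next hunk) without tripping on keys the function doesn't use. A minimal illustration of the pattern:

```python
def loss_sketch(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
    # Unused batch/config keys land harmlessly in kwargs.
    ...

batch = {"true_msa": None, "bert_mask": None, "aatype": None, "seq_mask": None}
loss_sketch(logits=None, **batch)  # aatype/seq_mask absorbed by **kwargs
```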
@@ -1296,24 +1308,48 @@ class AlphaFoldLoss(nn.Module):
             **self.config.violation,
         )
+        if("atom14_atom_is_ambiguous" not in batch.keys()):
+            batch.update(feats.build_ambiguity_feats(batch))
         if("renamed_atom14_gt_positions" not in out.keys()):
             batch.update(compute_renamed_ground_truth(
                 batch,
                 out["sm"]["positions"][-1],
             ))
+        if("backbone_affine_tensor" not in batch.keys()):
+            batch.update(feats.atom37_to_frames(**batch))
+            # TODO: Verify that this is correct
+            batch["backbone_affine_tensor"] = (
+                batch["rigidgroups_gt_frames"][..., 0, :, :]
+            )
+            batch["backbone_affine_mask"] = (
+                batch["rigidgroups_gt_exists"][..., 0]
+            )
+        if("chi_angles_sin_cos" not in batch.keys()):
+            batch.update(feats.atom37_to_torsion_angles(
+                **batch,
+            ))
+            # TODO: Verify that this is correct
+            batch["chi_angles_sin_cos"] = (
+                batch["torsion_angles_sin_cos"][..., 3:, :]
+            )
+            batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:]
         loss_fns = {
             "distogram":
                 lambda: distogram_loss(
-                    out["distogram_logits"],
+                    logits=out["distogram_logits"],
                     **{**batch,
                     **self.config.distogram},
                 ),
             "experimentally_resolved":
                 lambda: experimentally_resolved_loss(
-                    out["experimentally_resolved"],
-                    **{**batch,
-                    **self.config.experimentally_resolved},
+                    logits=out["experimentally_resolved_logits"],
+                    **{**batch, **self.config.experimentally_resolved},
                 ),
             "fape":
                 lambda: fape_loss(
@@ -1323,14 +1359,13 @@ class AlphaFoldLoss(nn.Module):
                 ),
             "lddt":
                 lambda: lddt_loss(
-                    out["lddt_logits"],
-                    all_atom_pred_pos=out["final_atom_positions"]
-                    **{**batch,
-                    **self.config.lddt},
+                    logits=out["lddt_logits"],
+                    all_atom_pred_pos=out["final_atom_positions"],
+                    **{**batch, **self.config.lddt},
                 ),
             "masked_msa":
                 lambda: masked_msa_loss(
-                    out["masked_msa_logits"],
+                    logits=out["masked_msa_logits"],
                     **{**batch,
                     **self.config.masked_msa},
                 ),
@@ -1338,8 +1373,7 @@ class AlphaFoldLoss(nn.Module):
             lambda: supervised_chi_loss(
                 out["sm"]["angles"],
                 out["sm"]["unnormalized_angles"],
-                **{**batch,
-                **self.config.supervised_chi},
+                **{**batch, **self.config.supervised_chi},
             ),
         "violation":
             lambda: violation_loss(
@@ -1351,6 +1385,9 @@ class AlphaFoldLoss(nn.Module):
         for k,loss_fn in loss_fns.items():
             weight = self.config[k].weight
             if(weight):
-                cum_loss += weight * loss_fn()
+                print(k)
+                loss = loss_fn()
+                print(loss)
+                cum_loss += weight * loss
         return cum_loss
@@ -142,8 +142,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
         be considered batch dimensions.
     Returns:
         The reassembled output of the layer on the inputs.
     """
     if(not (len(inputs) > 0)):
         raise ValueError("Must provide at least one input")
...