Commit ab0c6977 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Continue debugging loss functions, remove in-place ops

parent 85c0a9a9
......@@ -113,6 +113,7 @@ class InputEmbedder(nn.Module):
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb += self.relpos(ri)
#pair_emb = pair_emb + self.relpos(ri)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
......
......@@ -94,7 +94,7 @@ class MSATransition(nn.Module):
m = self.layer_norm(m)
inp = {"m": m, "mask": mask}
if(not self.training and self.chunk_size is not None):
if(self.chunk_size is not None):
m = chunk_layer(
self._transition,
inp,
......@@ -132,6 +132,7 @@ class EvoformerBlock(nn.Module):
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
chunk_size=chunk_size,
inf=inf,
)
......
......@@ -108,7 +108,7 @@ class AlphaFold(nn.Module):
def embed_templates(self, batch, z, pair_mask, templ_dim):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
......@@ -162,7 +162,7 @@ class AlphaFold(nn.Module):
z,
template_mask=batch["template_mask"]
)
t *= torch.sum(batch["template_mask"]) > 0
t = t * torch.sum(batch["template_mask"]) > 0
return {
"template_angle_embedding": a,
......@@ -297,7 +297,7 @@ class AlphaFold(nn.Module):
m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z]
z += z_prev_emb
z = z + z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if(self.config.template.enabled):
......@@ -312,7 +312,7 @@ class AlphaFold(nn.Module):
)
# [*, N, N, C_z]
z += template_embeds["template_pair_embedding"]
z = z + template_embeds["template_pair_embedding"]
if(self.config.template.embed_angles):
# [*, S = S_c + S_t, N, C_m]
......
......@@ -125,7 +125,7 @@ class MSAAttention(nn.Module):
"v_x": m,
"biases": biases
}
if(not self.training and self.chunk_size is not None):
if(self.chunk_size is not None):
m = chunk_layer(
self.mha,
mha_inputs,
......@@ -142,7 +142,7 @@ class MSARowAttentionWithPairBias(MSAAttention):
"""
Implements Algorithm 7.
"""
def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
def __init__(self, c_m, c_z, c_hidden, no_heads, chunk_size, inf=1e9):
"""
Args:
c_m:
......@@ -162,6 +162,7 @@ class MSARowAttentionWithPairBias(MSAAttention):
no_heads,
pair_bias=True,
c_z=c_z,
chunk_size=chunk_size,
inf=inf,
)
......@@ -259,7 +260,7 @@ class MSAColumnGlobalAttention(nn.Module):
# [*, N_res, H * C_hidden]
q = self.linear_q(q)
q *= self.c_hidden ** (-0.5)
q = q * self.c_hidden ** (-0.5)
# [*, N_res, H, C_hidden]
q = q.view(*q.shape[:-1], self.no_heads, -1)
......@@ -274,7 +275,7 @@ class MSAColumnGlobalAttention(nn.Module):
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a = a + bias
a = self.softmax(a)
# [*, N_res, H, C_hidden]
......@@ -318,7 +319,7 @@ class MSAColumnGlobalAttention(nn.Module):
"m": m,
"mask": mask,
}
if(not self.training and self.chunk_size is not None):
if(self.chunk_size is not None):
m = chunk_layer(
self.global_attention,
mha_input,
......
......@@ -83,7 +83,7 @@ class OuterProductMean(nn.Module):
a = a.transpose(-2, -3)
b = b.transpose(-2, -3)
if(not self.training and self.chunk_size is not None):
if(self.chunk_size is not None):
# Since the "batch dim" in this case is not a true batch dimension
# (in that the shape of the output depends on it), we need to
# iterate over it ourselves
......@@ -107,7 +107,7 @@ class OuterProductMean(nn.Module):
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
# [*, N_res, N_res, C_z]
outer /= self.eps + norm
outer = outer / self.eps + norm
return outer
......
......@@ -73,7 +73,7 @@ class PairTransition(nn.Module):
z = self.layer_norm(z)
inp = {"z": z, "mask": mask}
if(not self.training and self.chunk_size is not None):
if(self.chunk_size is not None):
z = chunk_layer(
self._transition,
inp,
......
......@@ -251,10 +251,10 @@ class Attention(nn.Module):
permute_final_dims(k, 1, 2, 0), # [*, H, C_hidden, K]
)
norm = 1 / math.sqrt(self.c_hidden) # [1]
a *= norm
a = a * norm
if(biases is not None):
for b in biases:
a += b
a = a + b
a = self.softmax(a)
# [*, H, Q, C_hidden]
......
......@@ -129,10 +129,11 @@ class AngleResnet(nn.Module):
# [*, no_angles * 2]
s = self.linear_out(s)
unnormalized_s = s
# [*, no_angles, 2]
s = s.view(*s.shape[:-1], -1, 2)
unnormalized_s = s
norm_denom = torch.sqrt(
torch.clamp(
torch.sum(s ** 2, dim=-1, keepdims=True),
......@@ -295,8 +296,8 @@ class InvariantPointAttention(nn.Module):
permute_final_dims(q, 1, 0, 2), # [*, H, N_res, C_hidden]
permute_final_dims(k, 1, 2, 0), # [*, H, C_hidden, N_res]
)
a *= math.sqrt(1. / (3 * self.c_hidden))
a += math.sqrt(1. / 3) * permute_final_dims(b, 2, 0, 1)
a = a + math.sqrt(1. / (3 * self.c_hidden))
a = a + math.sqrt(1. / 3) * permute_final_dims(b, 2, 0, 1)
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
......@@ -307,7 +308,9 @@ class InvariantPointAttention(nn.Module):
head_weights = self.softplus(self.head_weights).view(
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
)
head_weights *= math.sqrt(1. / (3 * (self.no_qk_points * 9. / 2)))
head_weights = (
head_weights * math.sqrt(1. / (3 * (self.no_qk_points * 9. / 2)))
)
pt_att = pt_att * head_weights
# [*, N_res, N_res, H]
......@@ -319,8 +322,8 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
pt_att = permute_final_dims(pt_att, 2, 0, 1)
a += pt_att
a += square_mask.unsqueeze(-3)
a = a + pt_att
a = a + square_mask.unsqueeze(-3)
a = self.softmax(a)
################
......@@ -510,7 +513,7 @@ def _frames_and_literature_positions_to_atom14_pos(
# [*, N, 14, 3]
lit_positions = lit_positions[f, ...]
pred_positions = t_atoms_to_global.apply(lit_positions)
pred_positions *= atom_mask
pred_positions = pred_positions * atom_mask
return pred_positions
......
......@@ -108,7 +108,7 @@ class TemplatePointwiseAttention(nn.Module):
"v_x": t,
"biases": [bias],
}
if(not self.training and self.chunk_size is not None):
if(self.chunk_size is not None):
z = chunk_layer(
self.mha,
mha_inputs,
......
......@@ -102,7 +102,7 @@ class TriangleAttention(nn.Module):
"v_x": x,
"biases": [mask_bias, triangle_bias],
}
if(not self.training and self.chunk_size is not None):
if(self.chunk_size is not None):
x = chunk_layer(
self.mha,
mha_inputs,
......
......@@ -70,5 +70,6 @@ def checkpoint_blocks(
for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt
args = deepspeed.checkpointing.checkpoint(chunker(s, e), args)
args = wrap(args)
return args
......@@ -158,15 +158,16 @@ def atom14_to_atom37(atom14, batch):
def atom37_to_torsion_angles(
aatype: torch.Tensor,
all_atom_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
eps: float = 1e-8,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""
Args:
aatype:
[*, N_res] residue indices
all_atom_pos:
all_atom_positions:
[*, N_res, 37, 3] atom positions (in atom37
format)
all_atom_mask:
......@@ -183,28 +184,32 @@ def atom37_to_torsion_angles(
"""
aatype = torch.clamp(aatype, max=20)
pad = all_atom_pos.new_zeros([*all_atom_pos.shape[:-3], 1, 37, 3])
prev_all_atom_pos = torch.cat([pad, all_atom_pos[..., :-1, :, :]], dim=-3)
pad = all_atom_positions.new_zeros(
[*all_atom_positions.shape[:-3], 1, 37, 3]
)
prev_all_atom_positions = torch.cat(
[pad, all_atom_positions[..., :-1, :, :]], dim=-3
)
pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
pre_omega_atom_pos = torch.cat(
[
prev_all_atom_pos[..., 1:3, :],
all_atom_pos[..., :2, :]
prev_all_atom_positions[..., 1:3, :],
all_atom_positions[..., :2, :]
], dim=-2
)
phi_atom_pos = torch.cat(
[
prev_all_atom_pos[..., 2:3, :],
all_atom_pos[..., :3, :]
prev_all_atom_positions[..., 2:3, :],
all_atom_positions[..., :3, :]
], dim=-2
)
psi_atom_pos = torch.cat(
[
all_atom_pos[..., :3, :],
all_atom_pos[..., 4:5, :]
all_atom_positions[..., :3, :],
all_atom_positions[..., 4:5, :]
], dim=-2
)
......@@ -227,7 +232,7 @@ def atom37_to_torsion_angles(
atom_indices = chi_atom_indices[..., aatype, :, :]
chis_atom_pos = batched_gather(
all_atom_pos, atom_indices, -2, len(atom_indices.shape[:-2])
all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
)
chi_angles_mask = list(rc.chi_angles_mask)
......@@ -335,9 +340,9 @@ def atom37_to_frames(
device=aatype.device,
requires_grad=False
)
restype_rigidgroup_mask[:, 0] = 1
restype_rigidgroup_mask[:, 3] = 1
restype_rigidgroup_mask[:20, 4:] = (
restype_rigidgroup_mask[..., 0] = 1
restype_rigidgroup_mask[..., 3] = 1
restype_rigidgroup_mask[..., :20, 4:] = (
all_atom_mask.new_tensor(rc.chi_angles_mask)
)
......
......@@ -22,6 +22,7 @@ from typing import Dict, Optional
from openfold.np import residue_constants
from openfold.model.primitives import Linear
from openfold.utils import feats
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import (
tree_map,
......@@ -150,7 +151,9 @@ def backbone_loss(
unclamped_fape_loss * (1 - use_clamped_fape)
)
return torch.mean(fape_loss, dim=-1)
# Take the mean over the layer dimension
fape_loss = torch.mean(fape_loss, dim=0)
return fape_loss
def sidechain_loss(
......@@ -173,10 +176,9 @@ def sidechain_loss(
rigidgroups_alt_gt_frames
)
batch_dims = sidechain_frames.shape[:-5]
# Steamroll the inputs
sidechain_frames = sidechain_frames[-1]
batch_dims = sidechain_frames.shape[:-4]
sidechain_frames = sidechain_frames.view(
*batch_dims, -1, 4, 4
)
......@@ -240,7 +242,7 @@ def supervised_chi_loss(
aatype: torch.Tensor,
seq_mask: torch.Tensor,
chi_mask: torch.Tensor,
chi_angles: torch.Tensor,
chi_angles_sin_cos: torch.Tensor,
chi_weight: float,
angle_norm_weight: float,
eps=1e-6,
......@@ -256,24 +258,24 @@ def supervised_chi_loss(
angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
)
true_chi = chi_angles
sin_true_chi = torch.sin(true_chi)
cos_true_chi = torch.cos(true_chi)
sin_cos_true_chi = torch.stack([sin_true_chi, cos_true_chi], dim=-1)
true_chi = chi_angles_sin_cos.unsqueeze(-4)
shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi
true_chi_shifted = shifted_mask * true_chi
sq_chi_error = torch.sum(
(sin_cos_true_chi - pred_angles)**2, dim=-1
(true_chi - pred_angles)**2, dim=-1
)
sq_chi_error_shifted = torch.sum(
(sin_cos_true_chi_shifted - pred_angles)**2, dim=-1
(true_chi_shifted - pred_angles)**2, dim=-1
)
sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
# The ol' switcheroo
sq_chi_error = sq_chi_error.permute(
*range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
)
sq_chi_loss = masked_mean(
chi_mask, sq_chi_error, dim=(-1, -2)
chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
)
loss = 0
......@@ -283,8 +285,11 @@ def supervised_chi_loss(
torch.sum(unnormalized_angles_sin_cos**2, dim=-1) + eps
)
norm_error = torch.abs(angle_norm - 1.)
norm_error = norm_error.permute(
*range(len(norm_error.shape))[1:-2], 0, -2, -1
)
angle_norm_loss = masked_mean(
seq_mask[..., None], norm_error, dim=(-1, -2)
seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
)
loss += angle_norm_weight * angle_norm_loss
......@@ -377,10 +382,10 @@ def lddt_loss(
)
errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
all_atom_mask = all_atom_mask.squeeze(-1)
loss = (
torch.sum(errors * all_atom_mask) / (torch.sum(all_atom_mask) + 1e-8)
torch.sum(errors * all_atom_mask, dim=-1) /
(eps + torch.sum(all_atom_mask, dim=-1))
)
loss *= (
......@@ -483,10 +488,10 @@ def tm_score(
def between_residue_bond_loss(
pred_atom_positions: torch.Tensor, # (N, 37(14), 3)
pred_atom_mask: torch.Tensor, # (N, 37(14))
residue_index: torch.Tensor, # (N)
aatype: torch.Tensor, # (N)
pred_atom_positions: torch.Tensor, # (*, N, 37/14, 3)
pred_atom_mask: torch.Tensor, # (*, N, 37/14)
residue_index: torch.Tensor, # (*, N)
aatype: torch.Tensor, # (*, N)
tolerance_factor_soft=12.0,
tolerance_factor_hard=12.0,
eps=1e-6,
......@@ -561,7 +566,10 @@ def between_residue_bond_loss(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev
)
mask = this_c_mask * next_n_mask * has_no_gap_mask
c_n_loss = torch.sum(mask * c_n_loss_per_residue) / (torch.sum(mask) + eps)
c_n_loss = (
torch.sum(mask * c_n_loss_per_residue, dim=-1) /
(torch.sum(mask, dim=-1) + eps)
)
c_n_violation_mask = mask * (
c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
)
......@@ -589,7 +597,8 @@ def between_residue_bond_loss(
)
mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
ca_c_n_loss = (
torch.sum(mask * ca_c_n_loss_per_residue) / (torch.sum(mask) + eps)
torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) /
(torch.sum(mask, dim=-1) + eps)
)
ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
(tolerance_factor_hard * gt_stddev))
......@@ -604,7 +613,8 @@ def between_residue_bond_loss(
)
mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
c_n_ca_loss = (
torch.sum(mask * c_n_ca_loss_per_residue) / (torch.sum(mask) + eps)
torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) /
(torch.sum(mask, dim=-1) + eps)
)
c_n_ca_violation_mask = mask * (
c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
......@@ -627,7 +637,8 @@ def between_residue_bond_loss(
c_n_violation_mask,
ca_c_n_violation_mask,
c_n_ca_violation_mask
]
],
dim=-2,
),
dim=-2
)[0]
......@@ -708,7 +719,7 @@ def between_residue_clash_loss(
# Backbone C--N bond between subsequent residues is no clash.
c_one_hot = torch.nn.functional.one_hot(
residue_index.new_tensor(2), num_classes=14
residue_index.new_tensor(2.), num_classes=14
)
c_one_hot = c_one_hot.reshape(
*((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
......@@ -1255,6 +1266,7 @@ def experimentally_resolved_loss(
min_resolution: float,
max_resolution: float,
eps: float = 1e-8,
**kwargs,
) -> torch.Tensor:
errors = sigmoid_cross_entropy(logits, all_atom_mask)
loss_num = torch.sum(errors * atom37_atom_exists, dim=(-1, -2))
......@@ -1268,7 +1280,7 @@ def experimentally_resolved_loss(
return loss
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8):
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
errors = softmax_cross_entropy(
logits,
torch.nn.functional.one_hot(true_msa, num_classes=23)
......@@ -1296,24 +1308,48 @@ class AlphaFoldLoss(nn.Module):
**self.config.violation,
)
if("atom14_atom_is_ambiguous" not in batch.keys()):
batch.update(feats.build_ambiguity_feats(batch))
if("renamed_atom14_gt_positions" not in out.keys()):
batch.update(compute_renamed_ground_truth(
batch,
out["sm"]["positions"][-1],
))
if("backbone_affine_tensor" not in batch.keys()):
batch.update(feats.atom37_to_frames(**batch))
# TODO: Verify that this is correct
batch["backbone_affine_tensor"] = (
batch["rigidgroups_gt_frames"][..., 0, :, :]
)
batch["backbone_affine_mask"] = (
batch["rigidgroups_gt_exists"][..., 0]
)
if("chi_angles_sin_cos" not in batch.keys()):
batch.update(feats.atom37_to_torsion_angles(
**batch,
))
# TODO: Verify that this is correct
batch["chi_angles_sin_cos"] = (
batch["torsion_angles_sin_cos"][..., 3:, :]
)
batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:]
loss_fns = {
"distogram":
lambda: distogram_loss(
out["distogram_logits"],
logits=out["distogram_logits"],
**{**batch,
**self.config.distogram},
),
"experimentally_resolved":
lambda: experimentally_resolved_loss(
out["experimentally_resolved"],
**{**batch,
**self.config.experimentally_resolved},
logits=out["experimentally_resolved_logits"],
**{**batch, **self.config.experimentally_resolved},
),
"fape":
lambda: fape_loss(
......@@ -1323,14 +1359,13 @@ class AlphaFoldLoss(nn.Module):
),
"lddt":
lambda: lddt_loss(
out["lddt_logits"],
all_atom_pred_pos=out["final_atom_positions"]
**{**batch,
**self.config.lddt},
logits=out["lddt_logits"],
all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.lddt},
),
"masked_msa":
lambda: masked_msa_loss(
out["masked_msa_logits"],
logits=out["masked_msa_logits"],
**{**batch,
**self.config.masked_msa},
),
......@@ -1338,8 +1373,7 @@ class AlphaFoldLoss(nn.Module):
lambda: supervised_chi_loss(
out["sm"]["angles"],
out["sm"]["unnormalized_angles"],
**{**batch,
**self.config.supervised_chi},
**{**batch, **self.config.supervised_chi},
),
"violation":
lambda: violation_loss(
......@@ -1351,6 +1385,9 @@ class AlphaFoldLoss(nn.Module):
for k,loss_fn in loss_fns.items():
weight = self.config[k].weight
if(weight):
cum_loss += weight * loss_fn()
print(k)
loss = loss_fn()
print(loss)
cum_loss += weight * loss
return cum_loss
......@@ -143,7 +143,6 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
Returns:
The reassembled output of the layer on the inputs.
"""
if(not (len(inputs) > 0)):
raise ValueError("Must provide at least one input")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment