Commit df6b97f2 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

First draft of loss class

parent 15895ea9
......@@ -129,6 +129,8 @@ class AngleResnet(nn.Module):
# [*, no_angles * 2]
s = self.linear_out(s)
unnormalized_s = s
# [*, no_angles, 2]
s = s.view(*s.shape[:-1], -1, 2)
norm_denom = torch.sqrt(
......@@ -139,7 +141,7 @@ class AngleResnet(nn.Module):
)
s = s / norm_denom
return s
return unnormalized_s, s
class InvariantPointAttention(nn.Module):
......@@ -723,7 +725,7 @@ class StructureModule(nn.Module):
t = t.compose(self.bb_update(s))
# [*, N, 7, 2]
a = self.angle_resnet(s, s_initial)
unnormalized_a, a = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames(
t.scale_translation(self.trans_scale_factor), a, f,
......@@ -735,8 +737,10 @@ class StructureModule(nn.Module):
)
preds = {
"transformations":
"frames":
t.scale_translation(self.trans_scale_factor).to_4x4(),
"sidechain_frames": all_frames_to_global,
"unnormalized_angles": unnormalized_a,
"angles": a,
"positions": pred_xyz,
}
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import ml_collections
import numpy as np
import torch
......@@ -37,6 +38,13 @@ def softmax_cross_entropy(logits, labels):
return loss
def sigmoid_cross_entropy(logits, labels):
    """Element-wise binary cross-entropy computed directly on logits.

    Uses logsigmoid on both +logits and -logits for numerical stability
    instead of forming probabilities and taking logs.

    Args:
        logits: raw (unnormalized) predictions, any shape.
        labels: targets in [0, 1], broadcastable against logits.
    Returns:
        Per-element loss tensor of the broadcast shape.
    """
    positive_term = labels * torch.nn.functional.logsigmoid(logits)
    negative_term = (1 - labels) * torch.nn.functional.logsigmoid(-logits)
    return -(positive_term + negative_term)
def torsion_angle_loss(
a, # [*, N, 7, 2]
a_gt, # [*, N, 7, 2]
......@@ -102,12 +110,13 @@ def compute_fape(
def backbone_loss(
batch: Dict[str, torch.Tensor],
pred_aff: T,
pred_aff_tensor: torch.Tensor,
clamp_distance: float = 10.,
loss_unit_distance: float = 10.,
) -> torch.Tensor:
gt_aff = T.from_tensor(batch['backbone_affine_tensor'])
backbone_mask = batch['backbone_affine_mask']
pred_aff = T.from_tensor(pred_aff_tensor)
gt_aff = T.from_tensor(batch["backbone_affine_tensor"])
backbone_mask = batch["backbone_affine_mask"]
fape_loss = compute_fape(
pred_aff,
......@@ -138,15 +147,15 @@ def backbone_loss(
fape_loss_unclamped * (1 - use_clamped_fape)
)
return torch.mean(fape_loss, dim=backbone_mask.shape[:-1])
return torch.mean(fape_loss, dim=-1)
def sidechain_loss(
sidechain_frames,
sidechain_atom_pos,
gt_frames,
alt_gt_frames,
gt_exists,
rigidgroups_gt_frames,
rigidgroups_alt_gt_frames,
rigidgroups_gt_exists,
renamed_atom14_gt_positions,
renamed_atom14_gt_exists,
alt_naming_is_better,
......@@ -176,6 +185,87 @@ def sidechain_loss(
return fape
def fape_loss(
    out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    """Weighted sum of backbone and sidechain FAPE losses.

    Args:
        out: model output dict; reads out["sm"]["frames"],
            out["sm"]["sidechain_frames"] and out["sm"]["positions"].
        batch: ground-truth feature dict, forwarded to sidechain_loss.
        config: ConfigDict with "backbone" and "sidechain" sub-configs,
            each carrying a "weight" entry.
    Returns:
        Scalar (per-batch) combined FAPE loss.
    """
    # Backbone FAPE is scored on the final structure-module iteration only.
    bb_loss = backbone_loss(
        batch, out["sm"]["frames"][-1], **config.backbone
    )

    # BUGFIX: the merged batch/config dict must be splatted into keyword
    # arguments — sidechain_loss takes individual tensors, not one dict.
    # NOTE(review): assumes backbone_loss/sidechain_loss tolerate extra
    # keys such as "weight" (e.g. via **kwargs) — confirm.
    sc_loss = sidechain_loss(
        out["sm"]["sidechain_frames"],
        out["sm"]["positions"],
        **{
            **batch,
            **config.sidechain,
        },
    )

    return (
        config.backbone.weight * bb_loss +
        config.sidechain.weight * sc_loss
    )
def supervised_chi_loss(
    angles_sin_cos: torch.Tensor,
    unnormalized_angles_sin_cos: torch.Tensor,
    aatype: torch.Tensor,
    seq_mask: torch.Tensor,
    chi_mask: torch.Tensor,
    chi_angles: torch.Tensor,
    chi_weight: float,
    angle_norm_weight: float,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """Supervised chi-angle loss (AlphaFold supplement, torsion term).

    Combines a squared sin/cos error on the four side-chain chi angles —
    scoring pi-periodic chis against both the true value and its pi-shift —
    with a penalty keeping the unnormalized angle vectors near unit norm.

    Args:
        angles_sin_cos: predicted normalized angles. Presumably
            [*, N, 7, 2] — TODO confirm against AngleResnet output.
        unnormalized_angles_sin_cos: same shape, before normalization.
        aatype: [*, N] residue-type indices.
        seq_mask: [*, N] residue mask.
        chi_mask: mask of valid chi angles per residue.
        chi_angles: ground-truth chi angles in radians.
        chi_weight: weight of the chi-error term.
        angle_norm_weight: weight of the norm-deviation term.
        eps: numerical-stability constant inside the sqrt.
        **kwargs: ignored; lets callers splat a full feature/config dict,
            matching the other loss functions in this module.
    Returns:
        Scalar loss.
    """
    # The last 4 of the 7 predicted torsions are the side-chain chis.
    pred_angles = angles_sin_cos[..., 3:, :]
    residue_type_one_hot = torch.nn.functional.one_hot(
        aatype, residue_constants.restype_num + 1,
    ).unsqueeze(-3)
    # BUGFIX: keep the leading batch dims in the output ("...ik"); the
    # previous "ik" spec summed them away.
    chi_pi_periodic = torch.einsum(
        "...ij,jk->...ik",
        residue_type_one_hot,
        aatype.new_tensor(residue_constants.chi_pi_periodic)
    )

    true_chi = chi_angles.unsqueeze(-3)
    sin_true_chi = torch.sin(true_chi)
    cos_true_chi = torch.cos(true_chi)
    sin_cos_true_chi = torch.stack([sin_true_chi, cos_true_chi], dim=-1)

    # A pi shift flips the sign of both sin and cos.
    shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
    sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi

    sq_chi_error = torch.sum(
        (sin_cos_true_chi - pred_angles)**2, dim=-1
    )
    sq_chi_error_shifted = torch.sum(
        (sin_cos_true_chi_shifted - pred_angles)**2, dim=-1
    )
    # Pi-periodic chis are scored against the closer of the two targets.
    sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)

    sq_chi_loss = masked_mean(
        sq_chi_error, chi_mask.unsqueeze(-3), dim=(-1, -2, -3)
    )

    loss = chi_weight * sq_chi_loss

    # Penalize unnormalized sin/cos vectors drifting from unit length.
    angle_norm = torch.sqrt(
        torch.sum(unnormalized_angles_sin_cos**2, dim=-1) + eps
    )
    norm_error = torch.abs(angle_norm - 1.)
    # BUGFIX: the parameter is named seq_mask; "sequence_mask" was an
    # undefined name here.
    angle_norm_loss = masked_mean(
        norm_error, seq_mask[..., None, :, None], dim=(-1, -2, -3)
    )

    loss = loss + angle_norm_weight * angle_norm_loss

    return loss
def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
num_bins = logits.shape[-1]
bin_width = 1. / num_bins
......@@ -192,31 +282,34 @@ def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
def lddt_loss(
batch: Dict[str, torch.Tensor],
logits: torch.Tensor,
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
cutoff: float = 15.,
num_bins: int = 50,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps: float = 1e-10,
**kwargs,
) -> torch.Tensor:
all_atom_pred_pos = batch["sm"]["pred_pos"][-1]
all_atom_true_pos = batch["all_atom_positions"]
all_atom_positions = batch["all_atom_positions"]
all_atom_mask = batch["all_atom_mask"]
logits = batch["predicted_lddt_logits"]
n = all_atom_mask.shape[-1]
ca_pos = residue_constants.atom_order['CA']
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., :, ca_pos, :]
all_atom_true_pos = all_atom_true_pos[..., :, ca_pos, :]
all_atom_positions = all_atom_positions[..., :, ca_pos, :]
all_atom_mask = all_atom_mask[..., :, ca_pos:(ca_pos + 1)] # keep dim
dmat_true = torch.sqrt(
eps +
torch.sum(
(
all_atom_true_pos[..., None] -
all_atom_true_pos[..., None, :]
all_atom_positions[..., None] -
all_atom_positions[..., None, :]
)**2,
dim=-1,
)
......@@ -267,36 +360,44 @@ def lddt_loss(
loss = torch.sum(errors * all_atom_mask) / (torch.sum(mask_ca) + eps)
loss *= (
(batch["resolution"] >= min_resolution) &
(batch["resolution"] <= max_resolution)
(resolution >= min_resolution) &
(resolution <= max_resolution)
)
return loss
def distogram_loss(
pred_distr,
gt,
mask,
min_bin=2.3125, max_bin=21.6875, no_bins=64, eps=1e-6
logits,
pseudo_beta,
pseudo_beta_mask,
min_bin=2.3125,
max_bin=21.6875,
no_bins=64,
eps=1e-6,
**kwargs,
):
boundaries = torch.linspace(
min_bin, max_bin, no_bins - 1, device=pred_distr.device,
min_bin, max_bin, no_bins - 1, device=logits.device,
)
boundaries = boundaries ** 2
dists = torch.sum(
(gt[..., None, :] - gt[..., None, :, :]) ** 2, dim=-1, keepdims=True
(
pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]
) ** 2,
dim=-1,
keepdims=True
)
true_bins = torch.sum(dists > sq_breaks, dim=-1)
errors = softmax_cross_entropy(
pred_distr,
logits,
torch.nn.functional.one_hot(true_bins, num_bins),
)
square_mask = mask[..., None] * mask[..., None, :]
square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
mean = (
torch.sum(errors * square_mask, dim=(-1, -2)) /
......@@ -417,7 +518,7 @@ def between_residue_bond_loss(
# The C-N bond to proline has slightly different length because of the ring.
next_is_proline = (
aatype[..., 1:] == residue_constants.resname_to_idx['PRO']
aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
)
gt_length = (
(~next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
......@@ -609,7 +710,7 @@ def between_residue_clash_loss(
dists_mask *= (1. - c_n_bonds)
# Disulfide bridge between two cysteines is no clash.
cys = residue_constants.restype_name_to_atom14_names['CYS']
cys = residue_constants.restype_name_to_atom14_names["CYS"]
cys_sg_idx = cys.index('SG')
cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
cys_sg_idx = cys_sg_idx.reshape(
......@@ -768,18 +869,20 @@ def within_residue_violations(
def find_structural_violations(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor,
config: ml_collections.ConfigDict
violation_tolerance_factor: float,
clash_overlap_tolerance: float,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Computes several checks for structural violations."""
# Compute between residue backbone violations of bonds and angles.
connection_violations = between_residue_bond_loss(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch['atom14_atom_exists'],
residue_index=batch['residue_index'],
aatype=batch['aatype'],
tolerance_factor_soft=config.violation_tolerance_factor,
tolerance_factor_hard=config.violation_tolerance_factor
pred_atom_mask=batch["atom14_atom_exists"],
residue_index=batch["residue_index"],
aatype=batch["aatype"],
tolerance_factor_soft=violation_tolerance_factor,
tolerance_factor_hard=violation_tolerance_factor
)
# Compute the Van der Waals radius for every atom
......@@ -793,31 +896,31 @@ def find_structural_violations(
atomtype_radius
)
atom14_atom_radius = (
batch['atom14_atom_exists'] *
atomtype_radius[batch['residx_atom14_to_atom37']]
batch["atom14_atom_exists"] *
atomtype_radius[batch["residx_atom14_to_atom37"]]
)
# Compute the between residue clash loss.
between_residue_clashes = between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch['atom14_atom_exists'],
atom14_atom_exists=batch["atom14_atom_exists"],
atom14_atom_radius=atom14_atom_radius,
residue_index=batch['residue_index'],
overlap_tolerance_soft=config.clash_overlap_tolerance,
overlap_tolerance_hard=config.clash_overlap_tolerance
residue_index=batch["residue_index"],
overlap_tolerance_soft=clash_overlap_tolerance,
overlap_tolerance_hard=clash_overlap_tolerance
)
# Compute all within-residue violations (clashes,
# bond length and angle violations).
restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
overlap_tolerance=config.clash_overlap_tolerance,
bond_length_tolerance_factor=config.violation_tolerance_factor
overlap_tolerance=clash_overlap_tolerance,
bond_length_tolerance_factor=violation_tolerance_factor
)
atom14_dists_lower_bound = restype_atom14_bounds['lower_bound'][
batch['aatype']
atom14_dists_lower_bound = restype_atom14_bounds["lower_bound"][
batch["aatype"]
]
atom14_dists_upper_bound = restype_atom14_bounds['upper_bound'][
batch['aatype']
atom14_dists_upper_bound = restype_atom14_bounds["upper_bound"][
batch["aatype"]
]
atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
atom14_dists_lower_bound
......@@ -827,7 +930,7 @@ def find_structural_violations(
)
residue_violations = within_residue_violations(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch['atom14_atom_exists'],
atom14_atom_exists=batch["atom14_atom_exists"],
atom14_dists_lower_bound=atom14_dists_lower_bound,
atom14_dists_upper_bound=atom14_dists_upper_bound,
tighten_bounds_for_loss=0.0
......@@ -837,12 +940,12 @@ def find_structural_violations(
per_residue_violations_mask = torch.max(
torch.stack(
[
connection_violations['per_residue_violation_mask'],
connection_violations["per_residue_violation_mask"],
torch.max(
between_residue_clashes['per_atom_clash_mask'], dim=-1
between_residue_clashes["per_atom_clash_mask"], dim=-1
)[0],
torch.max(
residue_violations['per_atom_violations'], dim=-1
residue_violations["per_atom_violations"], dim=-1
)[0],
],
dim=-1,
......@@ -853,27 +956,27 @@ def find_structural_violations(
return {
'between_residues': {
'bonds_c_n_loss_mean':
connection_violations['c_n_loss_mean'], # ()
connection_violations["c_n_loss_mean"], # ()
'angles_ca_c_n_loss_mean':
connection_violations['ca_c_n_loss_mean'], # ()
connection_violations["ca_c_n_loss_mean"], # ()
'angles_c_n_ca_loss_mean':
connection_violations['c_n_ca_loss_mean'], # ()
connection_violations["c_n_ca_loss_mean"], # ()
'connections_per_residue_loss_sum':
connection_violations['per_residue_loss_sum'], # (N)
connection_violations["per_residue_loss_sum"], # (N)
'connections_per_residue_violation_mask':
connection_violations['per_residue_violation_mask'], # (N)
connection_violations["per_residue_violation_mask"], # (N)
'clashes_mean_loss':
between_residue_clashes['mean_loss'], # ()
between_residue_clashes["mean_loss"], # ()
'clashes_per_atom_loss_sum':
between_residue_clashes['per_atom_loss_sum'], # (N, 14)
between_residue_clashes["per_atom_loss_sum"], # (N, 14)
'clashes_per_atom_clash_mask':
between_residue_clashes['per_atom_clash_mask'], # (N, 14)
between_residue_clashes["per_atom_clash_mask"], # (N, 14)
},
'within_residues': {
'per_atom_loss_sum':
residue_violations['per_atom_loss_sum'], # (N, 14)
residue_violations["per_atom_loss_sum"], # (N, 14)
'per_atom_violations':
residue_violations['per_atom_violations'], # (N, 14),
residue_violations["per_atom_violations"], # (N, 14),
},
'total_per_residue_violations_mask':
per_residue_violations_mask, # (N)
......@@ -943,35 +1046,35 @@ def compute_violation_metrics(
ret = {}
extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch['atom14_atom_exists'],
residue_index=batch['residue_index']
pred_atom_mask=batch["atom14_atom_exists"],
residue_index=batch["residue_index"]
)
ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations
ret['violations_between_residue_bond'] = masked_mean(
batch['seq_mask'],
violations['between_residues'][
ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
ret["violations_between_residue_bond"] = masked_mean(
batch["seq_mask"],
violations["between_residues"][
'connections_per_residue_violation_mask'
],
dim=-1,
)
ret['violations_between_residue_clash'] = masked_mean(
mask=batch['seq_mask'],
ret["violations_between_residue_clash"] = masked_mean(
mask=batch["seq_mask"],
value=torch.max(
violations['between_residues']['clashes_per_atom_clash_mask'],
violations["between_residues"]["clashes_per_atom_clash_mask"],
dim=-1
)[0],
dim=-1,
)
ret['violations_within_residue'] = masked_mean(
mask=batch['seq_mask'],
ret["violations_within_residue"] = masked_mean(
mask=batch["seq_mask"],
value=torch.max(
violations['within_residues']['per_atom_violations'], dim=-1
violations["within_residues"]["per_atom_violations"], dim=-1
)[0],
dim=-1,
)
ret['violations_per_residue'] = masked_mean(
mask=batch['seq_mask'],
value=violations['total_per_residue_violations_mask'],
ret["violations_per_residue"] = masked_mean(
mask=batch["seq_mask"],
value=violations["total_per_residue_violations_mask"],
dim=-1,
)
return ret
......@@ -994,6 +1097,27 @@ def compute_violation_metrics_np(
return tree_map(to_np, out, torch.Tensor)
def violation_loss(
    violations: Dict[str, torch.Tensor],
    atom14_atom_exists: torch.Tensor,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """Aggregate structural-violation penalty (bond, angle, clash terms).

    Args:
        violations: output dict of find_structural_violations.
        atom14_atom_exists: [*, N, 14] mask of atoms present per residue.
        eps: guard against division by zero when no atoms exist.
        **kwargs: ignored; lets callers splat a full feature dict, matching
            the convention of the other loss functions in this module
            (e.g. lddt_loss, distogram_loss).
    Returns:
        Scalar loss.
    """
    num_atoms = torch.sum(atom14_atom_exists)

    # Total clash penalty (between- plus within-residue), averaged over
    # the number of existing atoms.
    l_clash = torch.sum(
        violations["between_residues"]["clashes_per_atom_loss_sum"] +
        violations["within_residues"]["per_atom_loss_sum"]
    )
    l_clash = l_clash / (eps + num_atoms)

    loss = (
        violations["between_residues"]["bonds_c_n_loss_mean"] +
        violations["between_residues"]["angles_ca_c_n_loss_mean"] +
        violations["between_residues"]["angles_c_n_ca_loss_mean"] +
        l_clash
    )

    return loss
def compute_renamed_ground_truth(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor,
......@@ -1038,7 +1162,7 @@ def compute_renamed_ground_truth(
)
)
atom14_gt_positions = batch['atom14_gt_positions']
atom14_gt_positions = batch["atom14_gt_positions"]
gt_dists = torch.sqrt(
eps +
torch.sum(
......@@ -1050,7 +1174,7 @@ def compute_renamed_ground_truth(
)
)
atom14_alt_gt_positions = batch['atom14_alt_gt_positions']
atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
alt_gt_dists = torch.sqrt(
eps +
torch.sum(
......@@ -1065,8 +1189,8 @@ def compute_renamed_ground_truth(
lddt = torch.sqrt(eps + (pred_dists - gt_dists)**2)
alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists)**2)
atom14_gt_exists = batch['atom14_gt_exists']
atom14_atom_is_ambiguous = batch['atom14_atom_is_ambiguous']
atom14_gt_exists = batch["atom14_gt_exists"]
atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
mask = (
atom14_gt_exists[..., None, :, None] *
atom14_atom_is_ambiguous[..., None, :, None] *
......@@ -1089,13 +1213,13 @@ def compute_renamed_ground_truth(
renamed_atom14_gt_mask = (
(1. - alt_naming_is_better[..., None]) * atom14_gt_exists +
alt_naming_is_better[..., None] * batch['atom14_alt_gt_exists']
alt_naming_is_better[..., None] * batch["atom14_alt_gt_exists"]
)
return {
'alt_naming_is_better': alt_naming_is_better,
'renamed_atom14_gt_positions': renamed_atom14_gt_positions,
'renamed_atom14_gt_exists': renamed_atom14_gt_mask,
"alt_naming_is_better": alt_naming_is_better,
"renamed_atom14_gt_positions": renamed_atom14_gt_positions,
"renamed_atom14_gt_exists": renamed_atom14_gt_mask,
}
......@@ -1103,9 +1227,105 @@ def experimentally_resolved_loss(
logits: torch.Tensor,
atom37_atom_exists: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
min_resolution: float,
max_resolution: float,
eps: float = 1e-8,
) -> torch.Tensor:
errors = sigmoid_cross_entropy(logits, all_atom_mask)
loss_num = torch.sum(errors * atom37_atom_exists, dim=(-1, -2))
loss = loss_num / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
loss *= (
(resolution >= min_resolution) &
(resolution <= max_resolution)
)
return loss
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
    """Cross-entropy on BERT-masked MSA positions.

    Args:
        logits: predicted residue-type logits over 23 classes; presumably
            [*, N_seq, N_res, 23] — TODO confirm against the MSA head.
        true_msa: ground-truth residue indices for the MSA.
        bert_mask: 1.0 where a position was masked out for prediction.
        eps: guards against division by zero for an all-zero mask.
        **kwargs: ignored; lets callers splat a full feature/config dict,
            matching the other loss functions in this module.
    Returns:
        Loss averaged over the masked positions.
    """
    # BUGFIX: the softmax_cross_entropy call was missing its closing
    # parenthesis (syntax error in the original).
    errors = softmax_cross_entropy(
        logits,
        torch.nn.functional.one_hot(true_msa, num_classes=23),
    )

    # Mean error over masked positions only.
    loss = (
        torch.sum(errors * bert_mask, dim=(-1, -2)) /
        (eps + torch.sum(bert_mask, dim=(-1, -2)))
    )

    return loss
class AlphaFoldLoss(nn.Module):
    """ Aggregation of the various losses described in the supplement """
    def __init__(self, config):
        super(AlphaFoldLoss, self).__init__()
        # Full loss config; one sub-config (with a "weight") per loss term.
        self.config = config

    def forward(self, out, batch):
        """Compute the weighted sum of all enabled loss terms.

        Args:
            out: model output dict (structure-module outputs, head logits).
            batch: ground-truth/feature dict; updated in place with the
                renamed ground truth if it is not already present.
        Returns:
            Scalar cumulative loss.
        """
        cum_loss = 0.

        # Violations are expensive; compute them lazily and only when the
        # violation term is actually enabled.
        if "violation" not in out and self.config.violation.weight:
            out["violation"] = find_structural_violations(
                batch,
                out["sm"]["positions"][-1],
                **self.config.violation,
            )

        # BUGFIX: the renamed ground truth is stored in `batch`, so the
        # presence check must look in `batch` (the original checked `out`,
        # which never holds this key, so it recomputed every call).
        if "renamed_atom14_gt_positions" not in batch:
            batch.update(compute_renamed_ground_truth(
                batch,
                out["sm"]["positions"][-1],
            ))

        # BUGFIX: the original passed a positional dict after keyword
        # arguments (a SyntaxError); each merged batch/config dict must be
        # splatted into keyword arguments instead.
        # NOTE(review): this assumes every loss function tolerates extra
        # keys (e.g. via **kwargs) — confirm for each callee.
        loss_fns = {
            "distogram":
                lambda: distogram_loss(
                    logits=out["distogram_logits"],
                    **{**batch, **self.config.distogram},
                ),
            "experimentally_resolved":
                lambda: experimentally_resolved_loss(
                    logits=out["experimentally_resolved"],
                    **{**batch, **self.config.experimentally_resolved},
                ),
            "fape":
                lambda: fape_loss(
                    out,
                    batch,
                    self.config.fape,
                ),
            "lddt":
                lambda: lddt_loss(
                    logits=out["lddt_logits"],
                    # BUGFIX: a comma was missing after this argument.
                    all_atom_pred_pos=out["final_atom_positions"],
                    **{**batch, **self.config.lddt},
                ),
            "masked_msa":
                lambda: masked_msa_loss(
                    logits=out["masked_msa_logits"],
                    **{**batch, **self.config.masked_msa},
                ),
            "supervised_chi":
                lambda: supervised_chi_loss(
                    out["sm"]["angles"],
                    out["sm"]["unnormalized_angles"],
                    **{**batch, **self.config.supervised_chi},
                ),
            "violation":
                # Pass only what violation_loss declares rather than
                # splatting the whole batch (whose extra keys it rejects).
                lambda: violation_loss(
                    out["violation"],
                    atom14_atom_exists=batch["atom14_atom_exists"],
                ),
        }

        for loss_name, loss_fn in loss_fns.items():
            weight = self.config[loss_name].weight
            # Skip disabled terms entirely (weight of 0 or 0.).
            if weight:
                cum_loss = cum_loss + weight * loss_fn()

        return cum_loss
......@@ -180,4 +180,54 @@ config = mlc.ConfigDict({
"max_outer_iterations": 20,
"exclude_residues": [],
},
"loss": {
"distogram": {
"min_bin": 2.3125,
"max_bin": 21.6875,
"no_bins": 64,
"eps": 1e-6,
"weight": 0.3,
},
"experimentally_resolved": {
"eps": 1e-8,
"min_resolution": 0.1,
"max_resolution": 3.0,
"weight": 0.,
},
"fape": {
"backbone": {
"clamp_distance": 10.,
"loss_unit_distance": 10.,
"weight": 0.5,
}
"sidechain": {
"clamp_distance": 10.,
"length_scale": 10.,
"weight": 0.5,
}
"weight": 1.0,
},
"lddt": {
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.,
"num_bins": 50,
"eps": 1e-10,
"weight": 0.01,
},
"masked_msa": {
"eps": 1e-8,
"weight": 2.0,
},
"supervised_chi": {
"chi_weight": 0.5,
"angle_norm_weight": 0.01,
"eps": 1e-6,
"weight": 1.0,
},
"violation": {
"eps": 1e-6,
"weight": 0.,
},
},
})
......@@ -205,7 +205,7 @@ class TestAngleResnet(unittest.TestCase):
a = torch.rand((batch_size, n, c_s))
a_initial = torch.rand((batch_size, n, c_s))
a = ar(a, a_initial)
_, a = ar(a, a_initial)
self.assertTrue(a.shape == (batch_size, n, no_angles, 2))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment