"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "43b3c006a394c9e5f37a72ffb4caf744ea9a53cb"
Commit df6b97f2 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

First draft of loss class

parent 15895ea9
...@@ -129,6 +129,8 @@ class AngleResnet(nn.Module): ...@@ -129,6 +129,8 @@ class AngleResnet(nn.Module):
# [*, no_angles * 2] # [*, no_angles * 2]
s = self.linear_out(s) s = self.linear_out(s)
unnormalized_s = s
# [*, no_angles, 2] # [*, no_angles, 2]
s = s.view(*s.shape[:-1], -1, 2) s = s.view(*s.shape[:-1], -1, 2)
norm_denom = torch.sqrt( norm_denom = torch.sqrt(
...@@ -139,7 +141,7 @@ class AngleResnet(nn.Module): ...@@ -139,7 +141,7 @@ class AngleResnet(nn.Module):
) )
s = s / norm_denom s = s / norm_denom
return s return unnormalized_s, s
class InvariantPointAttention(nn.Module): class InvariantPointAttention(nn.Module):
...@@ -723,7 +725,7 @@ class StructureModule(nn.Module): ...@@ -723,7 +725,7 @@ class StructureModule(nn.Module):
t = t.compose(self.bb_update(s)) t = t.compose(self.bb_update(s))
# [*, N, 7, 2] # [*, N, 7, 2]
a = self.angle_resnet(s, s_initial) unnormalized_a, a = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames( all_frames_to_global = self.torsion_angles_to_frames(
t.scale_translation(self.trans_scale_factor), a, f, t.scale_translation(self.trans_scale_factor), a, f,
...@@ -735,8 +737,10 @@ class StructureModule(nn.Module): ...@@ -735,8 +737,10 @@ class StructureModule(nn.Module):
) )
preds = { preds = {
"transformations": "frames":
t.scale_translation(self.trans_scale_factor).to_4x4(), t.scale_translation(self.trans_scale_factor).to_4x4(),
"sidechain_frames": all_frames_to_global,
"unnormalized_angles": unnormalized_a,
"angles": a, "angles": a,
"positions": pred_xyz, "positions": pred_xyz,
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial
import ml_collections import ml_collections
import numpy as np import numpy as np
import torch import torch
...@@ -37,6 +38,13 @@ def softmax_cross_entropy(logits, labels): ...@@ -37,6 +38,13 @@ def softmax_cross_entropy(logits, labels):
return loss return loss
def sigmoid_cross_entropy(logits, labels):
    """Element-wise binary cross-entropy computed from raw logits.

    Uses logsigmoid(x) and logsigmoid(-x) for log p and log (1 - p),
    which is numerically stable for large-magnitude logits.

    Args:
        logits: unnormalized predictions (any shape)
        labels: targets in [0, 1], broadcastable against logits
    Returns:
        Per-element loss tensor, same shape as the broadcast inputs.
    """
    log_prob_pos = torch.nn.functional.logsigmoid(logits)
    log_prob_neg = torch.nn.functional.logsigmoid(-logits)
    return -(labels * log_prob_pos) - (1 - labels) * log_prob_neg
def torsion_angle_loss( def torsion_angle_loss(
a, # [*, N, 7, 2] a, # [*, N, 7, 2]
a_gt, # [*, N, 7, 2] a_gt, # [*, N, 7, 2]
...@@ -102,12 +110,13 @@ def compute_fape( ...@@ -102,12 +110,13 @@ def compute_fape(
def backbone_loss( def backbone_loss(
batch: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor],
pred_aff: T, pred_aff_tensor: torch.Tensor,
clamp_distance: float = 10., clamp_distance: float = 10.,
loss_unit_distance: float = 10., loss_unit_distance: float = 10.,
) -> torch.Tensor: ) -> torch.Tensor:
gt_aff = T.from_tensor(batch['backbone_affine_tensor']) pred_aff = T.from_tensor(pred_aff_tensor)
backbone_mask = batch['backbone_affine_mask'] gt_aff = T.from_tensor(batch["backbone_affine_tensor"])
backbone_mask = batch["backbone_affine_mask"]
fape_loss = compute_fape( fape_loss = compute_fape(
pred_aff, pred_aff,
...@@ -138,15 +147,15 @@ def backbone_loss( ...@@ -138,15 +147,15 @@ def backbone_loss(
fape_loss_unclamped * (1 - use_clamped_fape) fape_loss_unclamped * (1 - use_clamped_fape)
) )
return torch.mean(fape_loss, dim=backbone_mask.shape[:-1]) return torch.mean(fape_loss, dim=-1)
def sidechain_loss( def sidechain_loss(
sidechain_frames, sidechain_frames,
sidechain_atom_pos, sidechain_atom_pos,
gt_frames, rigidgroups_gt_frames,
alt_gt_frames, rigidgroups_alt_gt_frames,
gt_exists, rigidgroups_gt_exists,
renamed_atom14_gt_positions, renamed_atom14_gt_positions,
renamed_atom14_gt_exists, renamed_atom14_gt_exists,
alt_naming_is_better, alt_naming_is_better,
...@@ -174,7 +183,88 @@ def sidechain_loss( ...@@ -174,7 +183,88 @@ def sidechain_loss(
) )
return fape return fape
def fape_loss(
    out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    """Weighted sum of backbone and sidechain FAPE losses.

    Args:
        out: model output dict; uses out["sm"]["frames"] (per-iteration
            backbone frames), out["sm"]["sidechain_frames"] and
            out["sm"]["positions"]
        batch: ground-truth feature dict
        config: fape sub-config with `backbone` and `sidechain` entries
    Returns:
        Combined FAPE loss (config.backbone.weight * bb +
        config.sidechain.weight * sc).
    """
    # Only the final iteration's backbone frames are supervised here.
    bb_loss = backbone_loss(
        batch, out["sm"]["frames"][-1], **config.backbone
    )

    # Fixed: the merged batch/config dict was previously passed as a single
    # positional argument (it would have bound to `rigidgroups_gt_frames`);
    # it must be splatted so sidechain_loss picks out the keys it needs.
    # NOTE(review): assumes sidechain_loss tolerates extra keys (**kwargs)
    # like the other loss functions in this module — confirm.
    sc_loss = sidechain_loss(
        out["sm"]["sidechain_frames"],
        out["sm"]["positions"],
        **{
            **batch,
            **config.sidechain,
        },
    )

    return (
        config.backbone.weight * bb_loss +
        config.sidechain.weight * sc_loss
    )
def supervised_chi_loss(
    angles_sin_cos: torch.Tensor,
    unnormalized_angles_sin_cos: torch.Tensor,
    aatype: torch.Tensor,
    seq_mask: torch.Tensor,
    chi_mask: torch.Tensor,
    chi_angles: torch.Tensor,
    chi_weight: float,
    angle_norm_weight: float,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """Supervised chi-angle loss (AlphaFold torsion-angle loss).

    Combines (1) the squared sin/cos error of the predicted chi angles
    against the ground truth, taking the minimum over the pi-periodic
    alternative (a 180-degree shift flips the sign of sin and cos), and
    (2) a penalty on the deviation of the unnormalized angle vectors
    from unit norm.

    Args:
        angles_sin_cos: [*, N, 7, 2] predicted (normalized) angles
        unnormalized_angles_sin_cos: [*, N, 7, 2] pre-normalization angles
        aatype: [*, N] residue type indices
        seq_mask: [*, N] sequence mask
        chi_mask: chi-angle mask (applied after unsqueeze(-3))
        chi_angles: ground-truth chi angles in radians
        chi_weight: weight of the chi sin/cos error term
        angle_norm_weight: weight of the angle-norm penalty term
        eps: numerical-stability constant inside the sqrt
        kwargs: ignored; allows callers to splat whole feature/config dicts
    Returns:
        Scalar loss.
    """
    # The last 4 of the 7 torsion angles are the chi angles.
    pred_angles = angles_sin_cos[..., 3:, :]
    residue_type_one_hot = torch.nn.functional.one_hot(
        aatype, residue_constants.restype_num + 1,
    ).unsqueeze(-3)
    # NOTE(review): the "->ik" output spec sums over any leading batch
    # dims of the one-hot tensor — confirm this is intended for batched
    # inputs.
    chi_pi_periodic = torch.einsum(
        "...ij,jk->ik",
        residue_type_one_hot,
        aatype.new_tensor(residue_constants.chi_pi_periodic)
    )

    true_chi = chi_angles.unsqueeze(-3)
    sin_true_chi = torch.sin(true_chi)
    cos_true_chi = torch.cos(true_chi)
    sin_cos_true_chi = torch.stack([sin_true_chi, cos_true_chi], dim=-1)

    # For pi-periodic chis the 180-degree-shifted ground truth is equally
    # valid; it differs only by a sign flip of (sin, cos).
    shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
    sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi

    sq_chi_error = torch.sum(
        (sin_cos_true_chi - pred_angles)**2, dim=-1
    )
    sq_chi_error_shifted = torch.sum(
        (sin_cos_true_chi_shifted - pred_angles)**2, dim=-1
    )
    # Take the more favorable of the two equivalent ground truths.
    sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)

    # NOTE(review): elsewhere in this file masked_mean is called with the
    # mask as the first argument; here the value comes first — confirm
    # masked_mean's parameter order.
    sq_chi_loss = masked_mean(
        sq_chi_error, chi_mask.unsqueeze(-3), dim=(-1, -2, -3)
    )

    loss = 0
    loss += chi_weight * sq_chi_loss

    # Penalize unnormalized angle vectors that stray from unit length.
    angle_norm = torch.sqrt(
        torch.sum(unnormalized_angles_sin_cos**2, dim=-1) + eps
    )
    norm_error = torch.abs(angle_norm - 1.)
    # Fixed: was `sequence_mask`, an undefined name (NameError at runtime);
    # the parameter is `seq_mask`.
    angle_norm_loss = masked_mean(
        norm_error, seq_mask[..., None, :, None], dim=(-1, -2, -3)
    )

    loss += angle_norm_weight * angle_norm_loss

    return loss
def compute_plddt(logits: torch.Tensor) -> torch.Tensor: def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
num_bins = logits.shape[-1] num_bins = logits.shape[-1]
...@@ -192,31 +282,34 @@ def compute_plddt(logits: torch.Tensor) -> torch.Tensor: ...@@ -192,31 +282,34 @@ def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
def lddt_loss( def lddt_loss(
batch: Dict[str, torch.Tensor], logits: torch.Tensor,
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
cutoff: float = 15., cutoff: float = 15.,
num_bins: int = 50, num_bins: int = 50,
min_resolution: float = 0.1, min_resolution: float = 0.1,
max_resolution: float = 3.0, max_resolution: float = 3.0,
eps: float = 1e-10, eps: float = 1e-10,
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
all_atom_pred_pos = batch["sm"]["pred_pos"][-1] all_atom_positions = batch["all_atom_positions"]
all_atom_true_pos = batch["all_atom_positions"]
all_atom_mask = batch["all_atom_mask"] all_atom_mask = batch["all_atom_mask"]
logits = batch["predicted_lddt_logits"]
n = all_atom_mask.shape[-1] n = all_atom_mask.shape[-1]
ca_pos = residue_constants.atom_order['CA'] ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., :, ca_pos, :] all_atom_pred_pos = all_atom_pred_pos[..., :, ca_pos, :]
all_atom_true_pos = all_atom_true_pos[..., :, ca_pos, :] all_atom_positions = all_atom_positions[..., :, ca_pos, :]
all_atom_mask = all_atom_mask[..., :, ca_pos:(ca_pos + 1)] # keep dim all_atom_mask = all_atom_mask[..., :, ca_pos:(ca_pos + 1)] # keep dim
dmat_true = torch.sqrt( dmat_true = torch.sqrt(
eps + eps +
torch.sum( torch.sum(
( (
all_atom_true_pos[..., None] - all_atom_positions[..., None] -
all_atom_true_pos[..., None, :] all_atom_positions[..., None, :]
)**2, )**2,
dim=-1, dim=-1,
) )
...@@ -267,36 +360,44 @@ def lddt_loss( ...@@ -267,36 +360,44 @@ def lddt_loss(
loss = torch.sum(errors * all_atom_mask) / (torch.sum(mask_ca) + eps) loss = torch.sum(errors * all_atom_mask) / (torch.sum(mask_ca) + eps)
loss *= ( loss *= (
(batch["resolution"] >= min_resolution) & (resolution >= min_resolution) &
(batch["resolution"] <= max_resolution) (resolution <= max_resolution)
) )
return loss return loss
def distogram_loss( def distogram_loss(
pred_distr, logits,
gt, pseudo_beta,
mask, pseudo_beta_mask,
min_bin=2.3125, max_bin=21.6875, no_bins=64, eps=1e-6 min_bin=2.3125,
max_bin=21.6875,
no_bins=64,
eps=1e-6,
**kwargs,
): ):
boundaries = torch.linspace( boundaries = torch.linspace(
min_bin, max_bin, no_bins - 1, device=pred_distr.device, min_bin, max_bin, no_bins - 1, device=logits.device,
) )
boundaries = boundaries ** 2 boundaries = boundaries ** 2
dists = torch.sum( dists = torch.sum(
(gt[..., None, :] - gt[..., None, :, :]) ** 2, dim=-1, keepdims=True (
pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]
) ** 2,
dim=-1,
keepdims=True
) )
true_bins = torch.sum(dists > sq_breaks, dim=-1) true_bins = torch.sum(dists > sq_breaks, dim=-1)
errors = softmax_cross_entropy( errors = softmax_cross_entropy(
pred_distr, logits,
torch.nn.functional.one_hot(true_bins, num_bins), torch.nn.functional.one_hot(true_bins, num_bins),
) )
square_mask = mask[..., None] * mask[..., None, :] square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
mean = ( mean = (
torch.sum(errors * square_mask, dim=(-1, -2)) / torch.sum(errors * square_mask, dim=(-1, -2)) /
...@@ -417,7 +518,7 @@ def between_residue_bond_loss( ...@@ -417,7 +518,7 @@ def between_residue_bond_loss(
# The C-N bond to proline has slightly different length because of the ring. # The C-N bond to proline has slightly different length because of the ring.
next_is_proline = ( next_is_proline = (
aatype[..., 1:] == residue_constants.resname_to_idx['PRO'] aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
) )
gt_length = ( gt_length = (
(~next_is_proline) * residue_constants.between_res_bond_length_c_n[0] (~next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
...@@ -609,7 +710,7 @@ def between_residue_clash_loss( ...@@ -609,7 +710,7 @@ def between_residue_clash_loss(
dists_mask *= (1. - c_n_bonds) dists_mask *= (1. - c_n_bonds)
# Disulfide bridge between two cysteines is no clash. # Disulfide bridge between two cysteines is no clash.
cys = residue_constants.restype_name_to_atom14_names['CYS'] cys = residue_constants.restype_name_to_atom14_names["CYS"]
cys_sg_idx = cys.index('SG') cys_sg_idx = cys.index('SG')
cys_sg_idx = residue_index.new_tensor(cys_sg_idx) cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
cys_sg_idx = cys_sg_idx.reshape( cys_sg_idx = cys_sg_idx.reshape(
...@@ -768,18 +869,20 @@ def within_residue_violations( ...@@ -768,18 +869,20 @@ def within_residue_violations(
def find_structural_violations( def find_structural_violations(
batch: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor, atom14_pred_positions: torch.Tensor,
config: ml_collections.ConfigDict violation_tolerance_factor: float,
clash_overlap_tolerance: float,
**kwargs,
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
"""Computes several checks for structural violations.""" """Computes several checks for structural violations."""
# Compute between residue backbone violations of bonds and angles. # Compute between residue backbone violations of bonds and angles.
connection_violations = between_residue_bond_loss( connection_violations = between_residue_bond_loss(
pred_atom_positions=atom14_pred_positions, pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch['atom14_atom_exists'], pred_atom_mask=batch["atom14_atom_exists"],
residue_index=batch['residue_index'], residue_index=batch["residue_index"],
aatype=batch['aatype'], aatype=batch["aatype"],
tolerance_factor_soft=config.violation_tolerance_factor, tolerance_factor_soft=violation_tolerance_factor,
tolerance_factor_hard=config.violation_tolerance_factor tolerance_factor_hard=violation_tolerance_factor
) )
# Compute the Van der Waals radius for every atom # Compute the Van der Waals radius for every atom
...@@ -793,31 +896,31 @@ def find_structural_violations( ...@@ -793,31 +896,31 @@ def find_structural_violations(
atomtype_radius atomtype_radius
) )
atom14_atom_radius = ( atom14_atom_radius = (
batch['atom14_atom_exists'] * batch["atom14_atom_exists"] *
atomtype_radius[batch['residx_atom14_to_atom37']] atomtype_radius[batch["residx_atom14_to_atom37"]]
) )
# Compute the between residue clash loss. # Compute the between residue clash loss.
between_residue_clashes = between_residue_clash_loss( between_residue_clashes = between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions, atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch['atom14_atom_exists'], atom14_atom_exists=batch["atom14_atom_exists"],
atom14_atom_radius=atom14_atom_radius, atom14_atom_radius=atom14_atom_radius,
residue_index=batch['residue_index'], residue_index=batch["residue_index"],
overlap_tolerance_soft=config.clash_overlap_tolerance, overlap_tolerance_soft=clash_overlap_tolerance,
overlap_tolerance_hard=config.clash_overlap_tolerance overlap_tolerance_hard=clash_overlap_tolerance
) )
# Compute all within-residue violations (clashes, # Compute all within-residue violations (clashes,
# bond length and angle violations). # bond length and angle violations).
restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
overlap_tolerance=config.clash_overlap_tolerance, overlap_tolerance=clash_overlap_tolerance,
bond_length_tolerance_factor=config.violation_tolerance_factor bond_length_tolerance_factor=violation_tolerance_factor
) )
atom14_dists_lower_bound = restype_atom14_bounds['lower_bound'][ atom14_dists_lower_bound = restype_atom14_bounds["lower_bound"][
batch['aatype'] batch["aatype"]
] ]
atom14_dists_upper_bound = restype_atom14_bounds['upper_bound'][ atom14_dists_upper_bound = restype_atom14_bounds["upper_bound"][
batch['aatype'] batch["aatype"]
] ]
atom14_dists_lower_bound = atom14_pred_positions.new_tensor( atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
atom14_dists_lower_bound atom14_dists_lower_bound
...@@ -827,7 +930,7 @@ def find_structural_violations( ...@@ -827,7 +930,7 @@ def find_structural_violations(
) )
residue_violations = within_residue_violations( residue_violations = within_residue_violations(
atom14_pred_positions=atom14_pred_positions, atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch['atom14_atom_exists'], atom14_atom_exists=batch["atom14_atom_exists"],
atom14_dists_lower_bound=atom14_dists_lower_bound, atom14_dists_lower_bound=atom14_dists_lower_bound,
atom14_dists_upper_bound=atom14_dists_upper_bound, atom14_dists_upper_bound=atom14_dists_upper_bound,
tighten_bounds_for_loss=0.0 tighten_bounds_for_loss=0.0
...@@ -837,12 +940,12 @@ def find_structural_violations( ...@@ -837,12 +940,12 @@ def find_structural_violations(
per_residue_violations_mask = torch.max( per_residue_violations_mask = torch.max(
torch.stack( torch.stack(
[ [
connection_violations['per_residue_violation_mask'], connection_violations["per_residue_violation_mask"],
torch.max( torch.max(
between_residue_clashes['per_atom_clash_mask'], dim=-1 between_residue_clashes["per_atom_clash_mask"], dim=-1
)[0], )[0],
torch.max( torch.max(
residue_violations['per_atom_violations'], dim=-1 residue_violations["per_atom_violations"], dim=-1
)[0], )[0],
], ],
dim=-1, dim=-1,
...@@ -853,27 +956,27 @@ def find_structural_violations( ...@@ -853,27 +956,27 @@ def find_structural_violations(
return { return {
'between_residues': { 'between_residues': {
'bonds_c_n_loss_mean': 'bonds_c_n_loss_mean':
connection_violations['c_n_loss_mean'], # () connection_violations["c_n_loss_mean"], # ()
'angles_ca_c_n_loss_mean': 'angles_ca_c_n_loss_mean':
connection_violations['ca_c_n_loss_mean'], # () connection_violations["ca_c_n_loss_mean"], # ()
'angles_c_n_ca_loss_mean': 'angles_c_n_ca_loss_mean':
connection_violations['c_n_ca_loss_mean'], # () connection_violations["c_n_ca_loss_mean"], # ()
'connections_per_residue_loss_sum': 'connections_per_residue_loss_sum':
connection_violations['per_residue_loss_sum'], # (N) connection_violations["per_residue_loss_sum"], # (N)
'connections_per_residue_violation_mask': 'connections_per_residue_violation_mask':
connection_violations['per_residue_violation_mask'], # (N) connection_violations["per_residue_violation_mask"], # (N)
'clashes_mean_loss': 'clashes_mean_loss':
between_residue_clashes['mean_loss'], # () between_residue_clashes["mean_loss"], # ()
'clashes_per_atom_loss_sum': 'clashes_per_atom_loss_sum':
between_residue_clashes['per_atom_loss_sum'], # (N, 14) between_residue_clashes["per_atom_loss_sum"], # (N, 14)
'clashes_per_atom_clash_mask': 'clashes_per_atom_clash_mask':
between_residue_clashes['per_atom_clash_mask'], # (N, 14) between_residue_clashes["per_atom_clash_mask"], # (N, 14)
}, },
'within_residues': { 'within_residues': {
'per_atom_loss_sum': 'per_atom_loss_sum':
residue_violations['per_atom_loss_sum'], # (N, 14) residue_violations["per_atom_loss_sum"], # (N, 14)
'per_atom_violations': 'per_atom_violations':
residue_violations['per_atom_violations'], # (N, 14), residue_violations["per_atom_violations"], # (N, 14),
}, },
'total_per_residue_violations_mask': 'total_per_residue_violations_mask':
per_residue_violations_mask, # (N) per_residue_violations_mask, # (N)
...@@ -943,35 +1046,35 @@ def compute_violation_metrics( ...@@ -943,35 +1046,35 @@ def compute_violation_metrics(
ret = {} ret = {}
extreme_ca_ca_violations = extreme_ca_ca_distance_violations( extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
pred_atom_positions=atom14_pred_positions, pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch['atom14_atom_exists'], pred_atom_mask=batch["atom14_atom_exists"],
residue_index=batch['residue_index'] residue_index=batch["residue_index"]
) )
ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
ret['violations_between_residue_bond'] = masked_mean( ret["violations_between_residue_bond"] = masked_mean(
batch['seq_mask'], batch["seq_mask"],
violations['between_residues'][ violations["between_residues"][
'connections_per_residue_violation_mask' 'connections_per_residue_violation_mask'
], ],
dim=-1, dim=-1,
) )
ret['violations_between_residue_clash'] = masked_mean( ret["violations_between_residue_clash"] = masked_mean(
mask=batch['seq_mask'], mask=batch["seq_mask"],
value=torch.max( value=torch.max(
violations['between_residues']['clashes_per_atom_clash_mask'], violations["between_residues"]["clashes_per_atom_clash_mask"],
dim=-1 dim=-1
)[0], )[0],
dim=-1, dim=-1,
) )
ret['violations_within_residue'] = masked_mean( ret["violations_within_residue"] = masked_mean(
mask=batch['seq_mask'], mask=batch["seq_mask"],
value=torch.max( value=torch.max(
violations['within_residues']['per_atom_violations'], dim=-1 violations["within_residues"]["per_atom_violations"], dim=-1
)[0], )[0],
dim=-1, dim=-1,
) )
ret['violations_per_residue'] = masked_mean( ret["violations_per_residue"] = masked_mean(
mask=batch['seq_mask'], mask=batch["seq_mask"],
value=violations['total_per_residue_violations_mask'], value=violations["total_per_residue_violations_mask"],
dim=-1, dim=-1,
) )
return ret return ret
...@@ -994,6 +1097,27 @@ def compute_violation_metrics_np( ...@@ -994,6 +1097,27 @@ def compute_violation_metrics_np(
return tree_map(to_np, out, torch.Tensor) return tree_map(to_np, out, torch.Tensor)
def violation_loss(
    violations: Dict[str, torch.Tensor],
    atom14_atom_exists: torch.Tensor,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """Aggregate structural-violation loss.

    Sums the mean bond-length and bond-angle violation terms with a clash
    term: the total per-atom clash/violation loss normalized by the number
    of existing atoms.

    Args:
        violations: output of find_structural_violations (nested dict with
            "between_residues" and "within_residues" entries)
        atom14_atom_exists: [*, N, 14] mask of existing atoms
        eps: guard against division by zero when no atoms exist
        kwargs: ignored; allows callers to splat the whole batch dict
            (forward() calls this with **batch)
    Returns:
        Scalar loss.
    """
    between = violations["between_residues"]
    within = violations["within_residues"]

    num_atoms = torch.sum(atom14_atom_exists)
    l_clash = torch.sum(
        between["clashes_per_atom_loss_sum"] +
        within["per_atom_loss_sum"]
    )
    l_clash = l_clash / (eps + num_atoms)

    loss = (
        between["bonds_c_n_loss_mean"] +
        between["angles_ca_c_n_loss_mean"] +
        between["angles_c_n_ca_loss_mean"] +
        l_clash
    )

    return loss
def compute_renamed_ground_truth( def compute_renamed_ground_truth(
batch: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor, atom14_pred_positions: torch.Tensor,
...@@ -1038,7 +1162,7 @@ def compute_renamed_ground_truth( ...@@ -1038,7 +1162,7 @@ def compute_renamed_ground_truth(
) )
) )
atom14_gt_positions = batch['atom14_gt_positions'] atom14_gt_positions = batch["atom14_gt_positions"]
gt_dists = torch.sqrt( gt_dists = torch.sqrt(
eps + eps +
torch.sum( torch.sum(
...@@ -1050,7 +1174,7 @@ def compute_renamed_ground_truth( ...@@ -1050,7 +1174,7 @@ def compute_renamed_ground_truth(
) )
) )
atom14_alt_gt_positions = batch['atom14_alt_gt_positions'] atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
alt_gt_dists = torch.sqrt( alt_gt_dists = torch.sqrt(
eps + eps +
torch.sum( torch.sum(
...@@ -1065,8 +1189,8 @@ def compute_renamed_ground_truth( ...@@ -1065,8 +1189,8 @@ def compute_renamed_ground_truth(
lddt = torch.sqrt(eps + (pred_dists - gt_dists)**2) lddt = torch.sqrt(eps + (pred_dists - gt_dists)**2)
alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists)**2) alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists)**2)
atom14_gt_exists = batch['atom14_gt_exists'] atom14_gt_exists = batch["atom14_gt_exists"]
atom14_atom_is_ambiguous = batch['atom14_atom_is_ambiguous'] atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
mask = ( mask = (
atom14_gt_exists[..., None, :, None] * atom14_gt_exists[..., None, :, None] *
atom14_atom_is_ambiguous[..., None, :, None] * atom14_atom_is_ambiguous[..., None, :, None] *
...@@ -1089,13 +1213,13 @@ def compute_renamed_ground_truth( ...@@ -1089,13 +1213,13 @@ def compute_renamed_ground_truth(
renamed_atom14_gt_mask = ( renamed_atom14_gt_mask = (
(1. - alt_naming_is_better[..., None]) * atom14_gt_exists + (1. - alt_naming_is_better[..., None]) * atom14_gt_exists +
alt_naming_is_better[..., None] * batch['atom14_alt_gt_exists'] alt_naming_is_better[..., None] * batch["atom14_alt_gt_exists"]
) )
return { return {
'alt_naming_is_better': alt_naming_is_better, "alt_naming_is_better": alt_naming_is_better,
'renamed_atom14_gt_positions': renamed_atom14_gt_positions, "renamed_atom14_gt_positions": renamed_atom14_gt_positions,
'renamed_atom14_gt_exists': renamed_atom14_gt_mask, "renamed_atom14_gt_exists": renamed_atom14_gt_mask,
} }
...@@ -1103,9 +1227,105 @@ def experimentally_resolved_loss( ...@@ -1103,9 +1227,105 @@ def experimentally_resolved_loss(
logits: torch.Tensor, logits: torch.Tensor,
atom37_atom_exists: torch.Tensor, atom37_atom_exists: torch.Tensor,
all_atom_mask: torch.Tensor, all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
min_resolution: float,
max_resolution: float,
eps: float = 1e-8, eps: float = 1e-8,
) -> torch.Tensor: ) -> torch.Tensor:
errors = sigmoid_cross_entropy(logits, all_atom_mask) errors = sigmoid_cross_entropy(logits, all_atom_mask)
loss_num = torch.sum(errors * atom37_atom_exists, dim=(-1, -2)) loss_num = torch.sum(errors * atom37_atom_exists, dim=(-1, -2))
loss = loss_num / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2))) loss = loss_num / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
loss *= (
(resolution >= min_resolution) &
(resolution <= max_resolution)
)
return loss return loss
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
    """BERT-style masked-MSA reconstruction loss.

    Args:
        logits: [*, N_seq, N_res, 23] predicted token distributions
            (23 classes per the one-hot below — TODO confirm vocabulary)
        true_msa: [*, N_seq, N_res] ground-truth MSA token indices
        bert_mask: [*, N_seq, N_res] 1 where a position was masked out
        eps: guard against division by zero when nothing was masked
        kwargs: ignored; allows callers to splat whole feature/config dicts
    Returns:
        Mean cross-entropy over the masked positions.
    """
    # Fixed: the original had unbalanced parentheses here (SyntaxError).
    errors = softmax_cross_entropy(
        logits,
        torch.nn.functional.one_hot(true_msa, num_classes=23),
    )

    loss = (
        torch.sum(errors * bert_mask, dim=(-1, -2)) /
        (eps + torch.sum(bert_mask, dim=(-1, -2)))
    )

    return loss
class AlphaFoldLoss(nn.Module):
    """ Aggregation of the various losses described in the supplement """
    def __init__(self, config):
        super(AlphaFoldLoss, self).__init__()
        # Loss sub-configs, one entry per loss term; each carries a
        # `weight` plus the term's own keyword arguments.
        self.config = config

    def forward(self, out, batch):
        """Compute the weighted sum of all enabled loss terms.

        Args:
            out: model-output dict (structure-module outputs under "sm",
                plus the various per-head logits)
            batch: ground-truth feature dict; updated in place with the
                renamed ground truth if not already present
        Returns:
            Scalar cumulative loss.
        """
        cum_loss = 0

        # Violations are expensive; compute them only if needed and not
        # already provided by the caller.
        if("violation" not in out.keys() and self.config.violation.weight):
            out["violation"] = find_structural_violations(
                batch,
                out["sm"]["positions"][-1],
                **self.config.violation,
            )

        if("renamed_atom14_gt_positions" not in out.keys()):
            batch.update(compute_renamed_ground_truth(
                batch,
                out["sm"]["positions"][-1],
            ))

        # Each entry is a thunk so that a zero-weight loss is never
        # evaluated. Merged batch/config dicts are splatted as keyword
        # arguments (the original passed them positionally after keyword
        # arguments, a SyntaxError); each loss fn picks the keys it needs
        # and absorbs the rest via **kwargs.
        loss_fns = {
            "distogram":
                lambda: distogram_loss(
                    logits=out["distogram_logits"],
                    **{**batch, **self.config.distogram},
                ),
            "experimentally_resolved":
                lambda: experimentally_resolved_loss(
                    logits=out["experimentally_resolved"],
                    **{**batch, **self.config.experimentally_resolved},
                ),
            "fape":
                lambda: fape_loss(
                    out,
                    batch,
                    self.config.fape,
                ),
            "lddt":
                lambda: lddt_loss(
                    logits=out["lddt_logits"],
                    # Fixed: missing comma after this argument.
                    all_atom_pred_pos=out["final_atom_positions"],
                    **{**batch, **self.config.lddt},
                ),
            "masked_msa":
                lambda: masked_msa_loss(
                    logits=out["masked_msa_logits"],
                    **{**batch, **self.config.masked_msa},
                ),
            "supervised_chi":
                lambda: supervised_chi_loss(
                    out["sm"]["angles"],
                    out["sm"]["unnormalized_angles"],
                    **{**batch, **self.config.supervised_chi},
                ),
            "violation":
                lambda: violation_loss(
                    out["violation"],
                    **batch,
                ),
        }

        for k,loss_fn in loss_fns.items():
            weight = self.config[k].weight
            if(weight):
                cum_loss += weight * loss_fn()

        return cum_loss
...@@ -180,4 +180,54 @@ config = mlc.ConfigDict({ ...@@ -180,4 +180,54 @@ config = mlc.ConfigDict({
"max_outer_iterations": 20, "max_outer_iterations": 20,
"exclude_residues": [], "exclude_residues": [],
}, },
"loss": {
"distogram": {
"min_bin": 2.3125,
"max_bin": 21.6875,
"no_bins": 64,
"eps": 1e-6,
"weight": 0.3,
},
"experimentally_resolved": {
"eps": 1e-8,
"min_resolution": 0.1,
"max_resolution": 3.0,
"weight": 0.,
},
"fape": {
"backbone": {
"clamp_distance": 10.,
"loss_unit_distance": 10.,
"weight": 0.5,
}
"sidechain": {
"clamp_distance": 10.,
"length_scale": 10.,
"weight": 0.5,
}
"weight": 1.0,
},
"lddt": {
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.,
"num_bins": 50,
"eps": 1e-10,
"weight": 0.01,
},
"masked_msa": {
"eps": 1e-8,
"weight": 2.0,
},
"supervised_chi": {
"chi_weight": 0.5,
"angle_norm_weight": 0.01,
"eps": 1e-6,
"weight": 1.0,
},
"violation": {
"eps": 1e-6,
"weight": 0.,
},
},
}) })
...@@ -205,7 +205,7 @@ class TestAngleResnet(unittest.TestCase): ...@@ -205,7 +205,7 @@ class TestAngleResnet(unittest.TestCase):
a = torch.rand((batch_size, n, c_s)) a = torch.rand((batch_size, n, c_s))
a_initial = torch.rand((batch_size, n, c_s)) a_initial = torch.rand((batch_size, n, c_s))
a = ar(a, a_initial) _, a = ar(a, a_initial)
self.assertTrue(a.shape == (batch_size, n, no_angles, 2)) self.assertTrue(a.shape == (batch_size, n, no_angles, 2))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment