"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "9fdb4435ac5375adba530b8813ef608dffb681c0"
Commit fe98aa32 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Improve loss documentation

parent 286b7f05
...@@ -46,7 +46,6 @@ class MSATransition(nn.Module): ...@@ -46,7 +46,6 @@ class MSATransition(nn.Module):
Implements Algorithm 9 Implements Algorithm 9
""" """
def __init__(self, c_m, n): def __init__(self, c_m, n):
""" """
Args: Args:
......
...@@ -84,6 +84,31 @@ def compute_fape( ...@@ -84,6 +84,31 @@ def compute_fape(
l1_clamp_distance: Optional[float] = None, l1_clamp_distance: Optional[float] = None,
eps=1e-8, eps=1e-8,
) -> torch.Tensor: ) -> torch.Tensor:
"""
Computes FAPE loss.
Args:
pred_frames:
[*, N_frames] Rigid object of predicted frames
target_frames:
[*, N_frames] Rigid object of ground truth frames
frames_mask:
[*, N_frames] binary mask for the frames
pred_positions:
[*, N_pts, 3] predicted atom positions
target_positions:
[*, N_pts, 3] ground truth positions
positions_mask:
[*, N_pts] positions mask
length_scale:
Length scale by which the loss is divided
l1_clamp_distance:
Cutoff above which distance errors are disregarded
eps:
Small value used to regularize denominators
Returns:
[*] loss tensor
"""
# [*, N_frames, N_pts, 3] # [*, N_frames, N_pts, 3]
local_pred_pos = pred_frames.invert()[..., None].apply( local_pred_pos = pred_frames.invert()[..., None].apply(
pred_positions[..., None, :, :], pred_positions[..., None, :, :],
...@@ -266,6 +291,29 @@ def supervised_chi_loss( ...@@ -266,6 +291,29 @@ def supervised_chi_loss(
eps=1e-6, eps=1e-6,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
"""
Implements Algorithm 27 (torsionAngleLoss)
Args:
angles_sin_cos:
[*, N, 7, 2] predicted angles
unnormalized_angles_sin_cos:
The same angles, but unnormalized
aatype:
[*, N] residue indices
seq_mask:
[*, N] sequence mask
chi_mask:
[*, N, 7] angle mask
chi_angles_sin_cos:
[*, N, 7, 2] ground truth angles
chi_weight:
Weight for the angle component of the loss
angle_norm_weight:
Weight for the normalization component of the loss
Returns:
[*] loss tensor
"""
pred_angles = angles_sin_cos[..., 3:, :] pred_angles = angles_sin_cos[..., 3:, :]
residue_type_one_hot = torch.nn.functional.one_hot( residue_type_one_hot = torch.nn.functional.one_hot(
aatype, aatype,
...@@ -1500,7 +1548,6 @@ def compute_drmsd_np(structure_1, structure_2, mask=None): ...@@ -1500,7 +1548,6 @@ def compute_drmsd_np(structure_1, structure_2, mask=None):
class AlphaFoldLoss(nn.Module): class AlphaFoldLoss(nn.Module):
"""Aggregation of the various losses described in the supplement""" """Aggregation of the various losses described in the supplement"""
def __init__(self, config): def __init__(self, config):
super(AlphaFoldLoss, self).__init__() super(AlphaFoldLoss, self).__init__()
self.config = config self.config = config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment