"lib/vscode:/vscode.git/clone" did not exist on "404a78e99cfe3b6b9a062db24d7832fb08b08765"
Commit fd56fb0a authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix batched loss

parent 793362fb
...@@ -136,11 +136,11 @@ def backbone_loss( ...@@ -136,11 +136,11 @@ def backbone_loss(
fape_loss = compute_fape( fape_loss = compute_fape(
pred_aff, pred_aff,
gt_aff[..., None, :], gt_aff[None],
backbone_affine_mask[..., None, :], backbone_affine_mask[None],
pred_aff.get_trans(), pred_aff.get_trans(),
gt_aff[..., None, :].get_trans(), gt_aff[None].get_trans(),
backbone_affine_mask[..., None, :], backbone_affine_mask[None],
l1_clamp_distance=clamp_distance, l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance, length_scale=loss_unit_distance,
eps=eps, eps=eps,
...@@ -148,11 +148,11 @@ def backbone_loss( ...@@ -148,11 +148,11 @@ def backbone_loss(
if use_clamped_fape is not None: if use_clamped_fape is not None:
unclamped_fape_loss = compute_fape( unclamped_fape_loss = compute_fape(
pred_aff, pred_aff,
gt_aff[..., None, :], gt_aff[None],
backbone_affine_mask[..., None, :], backbone_affine_mask[None],
pred_aff.get_trans(), pred_aff.get_trans(),
gt_aff[..., None, :].get_trans(), gt_aff[None].get_trans(),
backbone_affine_mask[..., None, :], backbone_affine_mask[None],
l1_clamp_distance=None, l1_clamp_distance=None,
length_scale=loss_unit_distance, length_scale=loss_unit_distance,
eps=eps, eps=eps,
...@@ -265,7 +265,7 @@ def supervised_chi_loss( ...@@ -265,7 +265,7 @@ def supervised_chi_loss(
angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic), angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
) )
true_chi = chi_angles_sin_cos.unsqueeze(-4) true_chi = chi_angles_sin_cos[None]
shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1) shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
true_chi_shifted = shifted_mask * true_chi true_chi_shifted = shifted_mask * true_chi
...@@ -282,8 +282,7 @@ def supervised_chi_loss( ...@@ -282,8 +282,7 @@ def supervised_chi_loss(
chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3) chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
) )
loss = 0 loss = chi_weight * sq_chi_loss
loss = loss + chi_weight * sq_chi_loss
angle_norm = torch.sqrt( angle_norm = torch.sqrt(
torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment