Commit fd56fb0a authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix batched loss

parent 793362fb
......@@ -136,11 +136,11 @@ def backbone_loss(
fape_loss = compute_fape(
pred_aff,
gt_aff[..., None, :],
backbone_affine_mask[..., None, :],
gt_aff[None],
backbone_affine_mask[None],
pred_aff.get_trans(),
gt_aff[..., None, :].get_trans(),
backbone_affine_mask[..., None, :],
gt_aff[None].get_trans(),
backbone_affine_mask[None],
l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance,
eps=eps,
......@@ -148,11 +148,11 @@ def backbone_loss(
if use_clamped_fape is not None:
unclamped_fape_loss = compute_fape(
pred_aff,
gt_aff[..., None, :],
backbone_affine_mask[..., None, :],
gt_aff[None],
backbone_affine_mask[None],
pred_aff.get_trans(),
gt_aff[..., None, :].get_trans(),
backbone_affine_mask[..., None, :],
gt_aff[None].get_trans(),
backbone_affine_mask[None],
l1_clamp_distance=None,
length_scale=loss_unit_distance,
eps=eps,
......@@ -265,7 +265,7 @@ def supervised_chi_loss(
angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
)
true_chi = chi_angles_sin_cos.unsqueeze(-4)
true_chi = chi_angles_sin_cos[None]
shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
true_chi_shifted = shifted_mask * true_chi
......@@ -282,8 +282,7 @@ def supervised_chi_loss(
chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
)
loss = 0
loss = loss + chi_weight * sq_chi_loss
loss = chi_weight * sq_chi_loss
angle_norm = torch.sqrt(
torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment