Unverified Commit 56b86074 authored by Dingquan Yu's avatar Dingquan Yu Committed by GitHub
Browse files

Merge pull request #1 from aqlaboratory/main

Fix batched finetuning bugs
parents 103d0370 55003f16
......@@ -1353,7 +1353,10 @@ def violation_loss(
+ l_clash
)
return loss
# Average over the batch dimension
mean = torch.mean(loss)
return mean
def compute_renamed_ground_truth(
......@@ -1476,7 +1479,7 @@ def experimentally_resolved_loss(
) -> torch.Tensor:
errors = sigmoid_cross_entropy(logits, all_atom_mask)
loss = torch.sum(errors * atom37_atom_exists, dim=-1)
loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)).unsqueeze(-1))
loss = torch.sum(loss, dim=-1)
loss = loss * (
......
......@@ -29,7 +29,7 @@ def drmsd(structure_1, structure_2, mask=None):
if(mask is not None):
drmsd = drmsd * (mask[..., None] * mask[..., None, :])
drmsd = torch.sum(drmsd, dim=(-1, -2))
n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
n = d1.shape[-1] if mask is None else torch.min(torch.sum(mask, dim=-1))
drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
drmsd = torch.sqrt(drmsd)
......
......@@ -88,9 +88,10 @@ class OpenFoldWrapper(pl.LightningModule):
)
for k,v in other_metrics.items():
assert(len(v.shape) == 1)
self.log(
f"{phase}/{k}",
v,
f"{phase}/{k}",
torch.mean(v),
on_step=False, on_epoch=True, logger=True
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment