"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "f7a60cba6a078f283f575784e0110b25dc397b7f"
Unverified Commit 56b86074 authored by Dingquan Yu's avatar Dingquan Yu Committed by GitHub
Browse files

Merge pull request #1 from aqlaboratory/main

Fix batched finetuning bugs
parents 103d0370 55003f16
...@@ -1353,7 +1353,10 @@ def violation_loss( ...@@ -1353,7 +1353,10 @@ def violation_loss(
+ l_clash + l_clash
) )
return loss # Average over the batch dimension
mean = torch.mean(loss)
return mean
def compute_renamed_ground_truth( def compute_renamed_ground_truth(
...@@ -1476,7 +1479,7 @@ def experimentally_resolved_loss( ...@@ -1476,7 +1479,7 @@ def experimentally_resolved_loss(
) -> torch.Tensor: ) -> torch.Tensor:
errors = sigmoid_cross_entropy(logits, all_atom_mask) errors = sigmoid_cross_entropy(logits, all_atom_mask)
loss = torch.sum(errors * atom37_atom_exists, dim=-1) loss = torch.sum(errors * atom37_atom_exists, dim=-1)
loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2))) loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)).unsqueeze(-1))
loss = torch.sum(loss, dim=-1) loss = torch.sum(loss, dim=-1)
loss = loss * ( loss = loss * (
......
...@@ -29,7 +29,7 @@ def drmsd(structure_1, structure_2, mask=None): ...@@ -29,7 +29,7 @@ def drmsd(structure_1, structure_2, mask=None):
if(mask is not None): if(mask is not None):
drmsd = drmsd * (mask[..., None] * mask[..., None, :]) drmsd = drmsd * (mask[..., None] * mask[..., None, :])
drmsd = torch.sum(drmsd, dim=(-1, -2)) drmsd = torch.sum(drmsd, dim=(-1, -2))
n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1) n = d1.shape[-1] if mask is None else torch.min(torch.sum(mask, dim=-1))
drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.) drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
drmsd = torch.sqrt(drmsd) drmsd = torch.sqrt(drmsd)
......
...@@ -88,9 +88,10 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -88,9 +88,10 @@ class OpenFoldWrapper(pl.LightningModule):
) )
for k,v in other_metrics.items(): for k,v in other_metrics.items():
assert(len(v.shape) == 1)
self.log( self.log(
f"{phase}/{k}", f"{phase}/{k}",
v, torch.mean(v),
on_step=False, on_epoch=True, logger=True on_step=False, on_epoch=True, logger=True
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment