Commit 2e3404a1 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix bug in DRMSD function

parent 86a1f756
...@@ -1531,7 +1531,7 @@ def compute_drmsd(structure_1, structure_2, mask=None): ...@@ -1531,7 +1531,7 @@ def compute_drmsd(structure_1, structure_2, mask=None):
drmsd = drmsd ** 2 drmsd = drmsd ** 2
drmsd = torch.sum(drmsd, dim=(-1, -2)) drmsd = torch.sum(drmsd, dim=(-1, -2))
n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1) n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else 0. drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
drmsd = torch.sqrt(drmsd) drmsd = torch.sqrt(drmsd)
return drmsd return drmsd
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment