Commit 33b0a9df authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add numpy version of DRMSD function

parent 6448d57c
...@@ -1409,7 +1409,11 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs): ...@@ -1409,7 +1409,11 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
return loss return loss
def compute_drmsd(structure_1, structure_2): def compute_drmsd(structure_1, structure_2, mask=None):
if(mask is not None):
structure_1 = structure_1 * mask[..., None]
structure_2 = structure_2 * mask[..., None]
d1 = structure_1[..., :, None, :] - structure_1[..., None, :, :] d1 = structure_1[..., :, None, :] - structure_1[..., None, :, :]
d2 = structure_2[..., :, None, :] - structure_2[..., None, :, :] d2 = structure_2[..., :, None, :] - structure_2[..., None, :, :]
...@@ -1422,13 +1426,22 @@ def compute_drmsd(structure_1, structure_2): ...@@ -1422,13 +1426,22 @@ def compute_drmsd(structure_1, structure_2):
drmsd = d1 - d2 drmsd = d1 - d2
drmsd = drmsd ** 2 drmsd = drmsd ** 2
drmsd = torch.sum(drmsd, dim=(-1, -2)) drmsd = torch.sum(drmsd, dim=(-1, -2))
n = d1.shape[-1] n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
drmsd = drmsd * (1 / (n * (n - 1))) drmsd = drmsd * (1 / (n * (n - 1)))
drmsd = torch.sqrt(drmsd) drmsd = torch.sqrt(drmsd)
return drmsd return drmsd
def compute_drmsd_np(structure_1, structure_2, mask=None):
structure_1 = torch.tensor(structure_1)
structure_2 = torch.tensor(structure_2)
if(mask is not None):
mask = torch.tensor(mask)
return compute_drmsd(structure_1, structure_2, mask)
class AlphaFoldLoss(nn.Module): class AlphaFoldLoss(nn.Module):
"""Aggregation of the various losses described in the supplement""" """Aggregation of the various losses described in the supplement"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment