Commit 1a341511 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add DRMSD computation

parent 9f6b67f3
...@@ -1327,6 +1327,26 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs): ...@@ -1327,6 +1327,26 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
return loss return loss
def compute_drmsd(structure_1, structure_2):
d1 = structure_1[..., :, None, :] - structure_1[..., None, :, :]
d2 = structure_2[..., :, None, :] - structure_2[..., None, :, :]
d1 = d1 ** 2
d2 = d2 ** 2
d1 = torch.sqrt(torch.sum(d1, dim=-1))
d2 = torch.sqrt(torch.sum(d2, dim=-1))
drmsd = d1 - d2
drmsd = drmsd ** 2
drmsd = torch.sum(drmsd, dim=(-1, -2))
n = structure_1.shape[-1]
drmsd = drmsd * (1 / (n * (n - 1)))
drmsd = torch.sqrt(drmsd)
return drmsd
class AlphaFoldLoss(nn.Module): class AlphaFoldLoss(nn.Module):
""" Aggregation of the various losses described in the supplement """ """ Aggregation of the various losses described in the supplement """
def __init__(self, config): def __init__(self, config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment