".github/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "ee4ef06bac4ee3f6bed53a3b77cb95c5ba5d824e"
Commit 4ca64437 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merge branch 'multimer' of https://github.com/aqlaboratory/openfold into multimer

parents fdcb72e8 8332aa0e
...@@ -691,9 +691,13 @@ def compute_tm( ...@@ -691,9 +691,13 @@ def compute_tm(
n = residue_weights.shape[-1] n = residue_weights.shape[-1]
pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32) pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
if interface: if interface and (asym_id is not None):
if len(asym_id.shape)>1:
assert len(asym_id.shape)<=2
batch_size = asym_id.shape[0]
pair_mask = residue_weights.new_ones((batch_size,n, n), dtype=torch.int32)
pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype) pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)
predicted_tm_term *= pair_mask predicted_tm_term *= pair_mask
pair_residue_weights = pair_mask * ( pair_residue_weights = pair_mask * (
...@@ -1440,7 +1444,10 @@ def violation_loss( ...@@ -1440,7 +1444,10 @@ def violation_loss(
+ l_clash + l_clash
) )
return loss # Average over the batch dimension
mean = torch.mean(loss)
return mean
def compute_renamed_ground_truth( def compute_renamed_ground_truth(
...@@ -1563,7 +1570,7 @@ def experimentally_resolved_loss( ...@@ -1563,7 +1570,7 @@ def experimentally_resolved_loss(
) -> torch.Tensor: ) -> torch.Tensor:
errors = sigmoid_cross_entropy(logits, all_atom_mask) errors = sigmoid_cross_entropy(logits, all_atom_mask)
loss = torch.sum(errors * atom37_atom_exists, dim=-1) loss = torch.sum(errors * atom37_atom_exists, dim=-1)
loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2))) loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)).unsqueeze(-1))
loss = torch.sum(loss, dim=-1) loss = torch.sum(loss, dim=-1)
loss = loss * ( loss = loss * (
......
...@@ -29,7 +29,7 @@ def drmsd(structure_1, structure_2, mask=None): ...@@ -29,7 +29,7 @@ def drmsd(structure_1, structure_2, mask=None):
if(mask is not None): if(mask is not None):
drmsd = drmsd * (mask[..., None] * mask[..., None, :]) drmsd = drmsd * (mask[..., None] * mask[..., None, :])
drmsd = torch.sum(drmsd, dim=(-1, -2)) drmsd = torch.sum(drmsd, dim=(-1, -2))
n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1) n = d1.shape[-1] if mask is None else torch.min(torch.sum(mask, dim=-1))
drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.) drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
drmsd = torch.sqrt(drmsd) drmsd = torch.sqrt(drmsd)
......
...@@ -88,9 +88,10 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -88,9 +88,10 @@ class OpenFoldWrapper(pl.LightningModule):
) )
for k,v in other_metrics.items(): for k,v in other_metrics.items():
assert(len(v.shape) == 1)
self.log( self.log(
f"{phase}/{k}", f"{phase}/{k}",
v, torch.mean(v),
on_step=False, on_epoch=True, logger=True on_step=False, on_epoch=True, logger=True
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment