Unverified Commit 959b3f25 authored by Christina Floristean's avatar Christina Floristean Committed by GitHub
Browse files

Merge pull request #337 from dingquanyu/fix-multimer-boolean-tensor-error

Fix multimer boolean tensor error
parents 82bda2d6 e84df271
...@@ -192,7 +192,7 @@ class AlphaFold(nn.Module): ...@@ -192,7 +192,7 @@ class AlphaFold(nn.Module):
sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2 sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
mask = mask[..., None] * mask[..., None, :] mask = mask[..., None] * mask[..., None, :]
sq_diff = masked_mean(mask=mask, value=sq_diff, dim=list(range(len(mask.shape)))) sq_diff = masked_mean(mask=mask, value=sq_diff, dim=list(range(len(mask.shape))))
diff = torch.sqrt(sq_diff + eps) diff = torch.sqrt(sq_diff + eps).item()
return diff <= self.config.recycle_early_stop_tolerance return diff <= self.config.recycle_early_stop_tolerance
def iteration(self, feats, prevs, _recycle=True): def iteration(self, feats, prevs, _recycle=True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment