Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
33b0a9df
Commit
33b0a9df
authored
Nov 25, 2021
by
Gustaf Ahdritz
Browse files
Add numpy version of DRMSD function
parent
6448d57c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
2 deletions
+15
-2
openfold/utils/loss.py
openfold/utils/loss.py
+15
-2
No files found.
openfold/utils/loss.py
View file @
33b0a9df
...
@@ -1409,7 +1409,11 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
...
@@ -1409,7 +1409,11 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
return
loss
return
loss
def
compute_drmsd
(
structure_1
,
structure_2
):
def
compute_drmsd
(
structure_1
,
structure_2
,
mask
=
None
):
if
(
mask
is
not
None
):
structure_1
=
structure_1
*
mask
[...,
None
]
structure_2
=
structure_2
*
mask
[...,
None
]
d1
=
structure_1
[...,
:,
None
,
:]
-
structure_1
[...,
None
,
:,
:]
d1
=
structure_1
[...,
:,
None
,
:]
-
structure_1
[...,
None
,
:,
:]
d2
=
structure_2
[...,
:,
None
,
:]
-
structure_2
[...,
None
,
:,
:]
d2
=
structure_2
[...,
:,
None
,
:]
-
structure_2
[...,
None
,
:,
:]
...
@@ -1422,13 +1426,22 @@ def compute_drmsd(structure_1, structure_2):
...
@@ -1422,13 +1426,22 @@ def compute_drmsd(structure_1, structure_2):
drmsd
=
d1
-
d2
drmsd
=
d1
-
d2
drmsd
=
drmsd
**
2
drmsd
=
drmsd
**
2
drmsd
=
torch
.
sum
(
drmsd
,
dim
=
(
-
1
,
-
2
))
drmsd
=
torch
.
sum
(
drmsd
,
dim
=
(
-
1
,
-
2
))
n
=
d1
.
shape
[
-
1
]
n
=
d1
.
shape
[
-
1
]
if
mask
is
None
else
torch
.
sum
(
mask
,
dim
=-
1
)
drmsd
=
drmsd
*
(
1
/
(
n
*
(
n
-
1
)))
drmsd
=
drmsd
*
(
1
/
(
n
*
(
n
-
1
)))
drmsd
=
torch
.
sqrt
(
drmsd
)
drmsd
=
torch
.
sqrt
(
drmsd
)
return
drmsd
return
drmsd
def
compute_drmsd_np
(
structure_1
,
structure_2
,
mask
=
None
):
structure_1
=
torch
.
tensor
(
structure_1
)
structure_2
=
torch
.
tensor
(
structure_2
)
if
(
mask
is
not
None
):
mask
=
torch
.
tensor
(
mask
)
return
compute_drmsd
(
structure_1
,
structure_2
,
mask
)
class
AlphaFoldLoss
(
nn
.
Module
):
class
AlphaFoldLoss
(
nn
.
Module
):
"""Aggregation of the various losses described in the supplement"""
"""Aggregation of the various losses described in the supplement"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment