Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
39830684
Commit
39830684
authored
Jun 17, 2023
by
Geoffrey Yu
Browse files
start working on multimer loss
parent
56d5e39c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
2 deletions
+27
-2
openfold/utils/loss.py
openfold/utils/loss.py
+27
-2
No files found.
openfold/utils/loss.py
View file @
39830684
...
@@ -34,7 +34,8 @@ from openfold.utils.tensor_utils import (
...
@@ -34,7 +34,8 @@ from openfold.utils.tensor_utils import (
permute_final_dims
,
permute_final_dims
,
batched_gather
,
batched_gather
,
)
)
import
logging
logger
=
logging
.
getLogger
(
__name__
)
def
softmax_cross_entropy
(
logits
,
labels
):
def
softmax_cross_entropy
(
logits
,
labels
):
loss
=
-
1
*
torch
.
sum
(
loss
=
-
1
*
torch
.
sum
(
...
@@ -1675,7 +1676,11 @@ class AlphaFoldLoss(nn.Module):
...
@@ -1675,7 +1676,11 @@ class AlphaFoldLoss(nn.Module):
super
(
AlphaFoldLoss
,
self
).
__init__
()
super
(
AlphaFoldLoss
,
self
).
__init__
()
self
.
config
=
config
self
.
config
=
config
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
):
def
loss
(
self
,
out
,
batch
,
_return_breakdown
=
False
):
"""
Rename previous forward() as loss
so that can be reused in the subclass
"""
if
"violation"
not
in
out
.
keys
():
if
"violation"
not
in
out
.
keys
():
out
[
"violation"
]
=
find_structural_violations
(
out
[
"violation"
]
=
find_structural_violations
(
batch
,
batch
,
...
@@ -1766,3 +1771,23 @@ class AlphaFoldLoss(nn.Module):
...
@@ -1766,3 +1771,23 @@ class AlphaFoldLoss(nn.Module):
return
cum_loss
return
cum_loss
return
cum_loss
,
losses
return
cum_loss
,
losses
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
):
cum_loss
,
losses
=
self
.
loss
(
out
,
batch
,
_return_breakdown
)
return
cum_loss
,
losses
class
AlphaFoldMultimerLoss
(
AlphaFoldLoss
):
"""
Add multi-chain permutation on top of
AlphaFoldLoss
"""
def
__init__
(
self
,
config
):
super
(
AlphaFoldMultimerLoss
,
self
).
__init__
()
self
.
config
=
config
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
):
"""
Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation
"""
logger
.
info
(
f
"out is
{
type
(
out
)
}
and batch is
{
type
(
batch
)
}
"
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment