Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
6eb1afe7
Commit
6eb1afe7
authored
Jun 30, 2023
by
Geoffrey Yu
Browse files
finished constructing AlphaFoldMultimerLoss and other necessary changes in the losses calculations
parent
30d50a18
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
11 deletions
+22
-11
openfold/utils/loss.py
openfold/utils/loss.py
+22
-11
No files found.
openfold/utils/loss.py
View file @
6eb1afe7
...
@@ -737,7 +737,11 @@ def tm_loss(
...
@@ -737,7 +737,11 @@ def tm_loss(
eps
=
1e-8
,
eps
=
1e-8
,
**
kwargs
,
**
kwargs
,
):
):
pred_affine
=
Rigid
.
from_tensor_7
(
final_affine_tensor
)
# first check whether this is a tensor_7 or tensor_4*4
if
final_affine_tensor
.
shape
[
-
1
]
==
7
:
pred_affine
=
Rigid
.
from_tensor_7
(
final_affine_tensor
)
elif
final_affine_tensor
.
shape
[
-
1
]
==
4
:
pred_affine
=
Rigid
.
from_tensor_4x4
(
final_affine_tensor
)
backbone_rigid
=
Rigid
.
from_tensor_4x4
(
backbone_rigid_tensor
)
backbone_rigid
=
Rigid
.
from_tensor_4x4
(
backbone_rigid_tensor
)
def
_points
(
affine
):
def
_points
(
affine
):
...
@@ -1635,7 +1639,7 @@ def chain_center_of_mass_loss(
...
@@ -1635,7 +1639,7 @@ def chain_center_of_mass_loss(
asym_id
:
torch
.
Tensor
,
asym_id
:
torch
.
Tensor
,
clamp_distance
:
float
=
-
4.0
,
clamp_distance
:
float
=
-
4.0
,
weight
:
float
=
0.05
,
weight
:
float
=
0.05
,
eps
:
float
=
1e-10
eps
:
float
=
1e-10
,
**
kwargs
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.
Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.
...
@@ -1662,9 +1666,9 @@ def chain_center_of_mass_loss(
...
@@ -1662,9 +1666,9 @@ def chain_center_of_mass_loss(
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
chains
,
_
=
asym_id
.
unique
(
return_counts
=
True
)
chains
,
_
=
asym_id
.
unique
(
return_counts
=
True
)
one_hot
=
torch
.
nn
.
functional
.
one_hot
(
asym_id
,
num_classes
=
chains
.
shape
[
0
]).
to
(
dtype
=
all_atom_mask
.
dtype
)
one_hot
=
torch
.
nn
.
functional
.
one_hot
(
asym_id
.
to
(
torch
.
int64
),
num_classes
=
chains
.
shape
[
0
]).
to
(
dtype
=
all_atom_mask
.
dtype
)
# make sure asym_id dtype is int
one_hot
=
one_hot
*
all_atom_mask
one_hot
=
one_hot
*
all_atom_mask
chain_pos_mask
=
one_hot
.
transpose
(
-
2
,
-
1
)
chain_pos_mask
=
one_hot
.
transpose
(
-
2
,
-
1
)
chain_exists
=
torch
.
any
(
chain_pos_mask
,
dim
=-
1
).
float
()
chain_exists
=
torch
.
any
(
chain_pos_mask
,
dim
=-
1
).
float
()
...
@@ -2012,8 +2016,12 @@ class AlphaFoldLoss(nn.Module):
...
@@ -2012,8 +2016,12 @@ class AlphaFoldLoss(nn.Module):
return
cum_loss
,
losses
return
cum_loss
,
losses
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
):
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
):
cum_loss
,
losses
=
self
.
loss
(
out
,
batch
,
_return_breakdown
)
if
(
not
_return_breakdown
):
return
cum_loss
,
losses
cum_loss
=
self
.
loss
(
out
,
batch
,
_return_breakdown
)
return
cum_loss
else
:
cum_loss
,
losses
=
self
.
loss
(
out
,
batch
,
_return_breakdown
)
return
cum_loss
,
losses
class
AlphaFoldMultimerLoss
(
AlphaFoldLoss
):
class
AlphaFoldMultimerLoss
(
AlphaFoldLoss
):
"""
"""
...
@@ -2120,11 +2128,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2120,11 +2128,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
# then permutate ground truth chains before calculating the loss
# then permutate ground truth chains before calculating the loss
permutated_labels
=
self
.
multi_chain_perm_align
(
out
,
features
,
labels
)
permutated_labels
=
self
.
multi_chain_perm_align
(
out
,
features
,
labels
)
permutated_labels
.
pop
(
'aatype'
)
permutated_labels
.
pop
(
'aatype'
)
logger
.
info
(
"finished multi-chain permutation "
)
features
.
update
(
permutated_labels
)
features
.
update
(
permutated_labels
)
move_to_cpu
=
lambda
t
:
(
t
.
to
(
'cpu'
))
move_to_cpu
=
lambda
t
:
(
t
.
to
(
'cpu'
))
features
=
tensor_tree_map
(
move_to_cpu
,
features
)
features
=
tensor_tree_map
(
move_to_cpu
,
features
)
self
.
loss
(
out
,
features
)
if
(
not
_return_breakdown
):
return
permutated_labels
cum_loss
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
## TODO next need to check how the ground truth label is used
print
(
f
"cum_loss:
{
cum_loss
}
"
)
# in loss calculation.
return
cum_loss
\ No newline at end of file
else
:
cum_loss
,
losses
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
print
(
f
"cum_loss:
{
cum_loss
}
losses:
{
losses
}
"
)
return
cum_loss
,
losses
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment