Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
bbf42cc5
Commit
bbf42cc5
authored
Jun 29, 2023
by
Geoffrey Yu
Browse files
fixed the fape and backbone loss errors
parent
b22bd4e3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
6 deletions
+12
-6
openfold/utils/loss.py
openfold/utils/loss.py
+12
-6
No files found.
openfold/utils/loss.py
View file @
bbf42cc5
...
@@ -185,7 +185,13 @@ def backbone_loss(
...
@@ -185,7 +185,13 @@ def backbone_loss(
eps
:
float
=
1e-4
,
eps
:
float
=
1e-4
,
**
kwargs
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
### need to check if the traj belongs to 4*4 matrix or a tensor_7
if
traj
.
shape
[
-
1
]
==
7
:
pred_aff
=
Rigid
.
from_tensor_7
(
traj
)
pred_aff
=
Rigid
.
from_tensor_7
(
traj
)
elif
traj
.
shape
[
-
1
]
==
4
:
pred_aff
=
Rigid
.
from_tensor_4x4
(
traj
)
pred_aff
=
Rigid
(
pred_aff
=
Rigid
(
Rotation
(
rot_mats
=
pred_aff
.
get_rots
().
get_rot_mats
(),
quats
=
None
),
Rotation
(
rot_mats
=
pred_aff
.
get_rots
().
get_rot_mats
(),
quats
=
None
),
pred_aff
.
get_trans
(),
pred_aff
.
get_trans
(),
...
@@ -304,10 +310,10 @@ def fape_loss(
...
@@ -304,10 +310,10 @@ def fape_loss(
interface_bb_loss
=
backbone_loss
(
interface_bb_loss
=
backbone_loss
(
traj
=
traj
,
traj
=
traj
,
pair_mask
=
1.
-
intra_chain_mask
,
pair_mask
=
1.
-
intra_chain_mask
,
**
{
**
batch
,
**
config
.
int
erface
_backbone
},
**
{
**
batch
,
**
config
.
int
ra_chain
_backbone
},
)
)
weighted_bb_loss
=
(
intra_chain_bb_loss
*
config
.
intra_chain_backbone
.
weight
weighted_bb_loss
=
(
intra_chain_bb_loss
*
config
.
intra_chain_backbone
.
weight
+
interface_bb_loss
*
config
.
int
erface
_backbone
.
weight
)
+
interface_bb_loss
*
config
.
int
ra_chain
_backbone
.
weight
)
else
:
else
:
bb_loss
=
backbone_loss
(
bb_loss
=
backbone_loss
(
traj
=
traj
,
traj
=
traj
,
...
@@ -1865,7 +1871,6 @@ def greedy_align(
...
@@ -1865,7 +1871,6 @@ def greedy_align(
if
(
rmsd
is
not
None
)
and
(
rmsd
<
best_rmsd
):
if
(
rmsd
is
not
None
)
and
(
rmsd
<
best_rmsd
):
best_rmsd
=
rmsd
best_rmsd
=
rmsd
best_idx
=
j
best_idx
=
j
print
(
f
"now best_idx is
{
best_idx
}
and rmsd is
{
rmsd
}
and j is
{
j
}
"
)
assert
best_idx
is
not
None
assert
best_idx
is
not
None
used
[
best_idx
]
=
True
used
[
best_idx
]
=
True
align
.
append
((
i
,
best_idx
))
align
.
append
((
i
,
best_idx
))
...
@@ -1920,7 +1925,7 @@ class AlphaFoldLoss(nn.Module):
...
@@ -1920,7 +1925,7 @@ class AlphaFoldLoss(nn.Module):
out
[
"violation"
]
=
find_structural_violations
(
out
[
"violation"
]
=
find_structural_violations
(
batch
,
batch
,
out
[
"sm"
][
"positions"
][
-
1
],
out
[
"sm"
][
"positions"
][
-
1
],
**
self
.
config
.
loss
.
violation
,
**
self
.
config
.
violation
,
)
)
if
"renamed_atom14_gt_positions"
not
in
out
.
keys
():
if
"renamed_atom14_gt_positions"
not
in
out
.
keys
():
...
@@ -2110,12 +2115,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2110,12 +2115,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
batch: a pair of input features and its corresponding ground truth structure
batch: a pair of input features and its corresponding ground truth structure
"""
"""
features
,
labels
=
batch
features
,
labels
=
batch
features
[
'resolution'
]
=
labels
[
2
][
'resolution'
]
# firstly update the resolution feature
# first remove the recycling dimention of input features
# first remove the recycling dimention of input features
features
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
features
)
features
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
features
)
# then permutate ground truth chains before calculating the loss
# then permutate ground truth chains before calculating the loss
permutated_labels
=
self
.
multi_chain_perm_align
(
out
,
features
,
labels
)
permutated_labels
=
self
.
multi_chain_perm_align
(
out
,
features
,
labels
)
permutated_labels
.
pop
(
'aatype'
)
permutated_labels
.
pop
(
'aatype'
)
logger
.
info
(
"finished multi-chain permutation"
)
logger
.
info
(
"finished multi-chain permutation
"
)
features
.
update
(
permutated_labels
)
features
.
update
(
permutated_labels
)
move_to_cpu
=
lambda
t
:
(
t
.
to
(
'cpu'
))
move_to_cpu
=
lambda
t
:
(
t
.
to
(
'cpu'
))
features
=
tensor_tree_map
(
move_to_cpu
,
features
)
features
=
tensor_tree_map
(
move_to_cpu
,
features
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment