Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
bbf42cc5
"examples/vscode:/vscode.git/clone" did not exist on "531636933db8dd3a3ba4292d4eb00379fe702a44"
Commit
bbf42cc5
authored
Jun 29, 2023
by
Geoffrey Yu
Browse files
fixed the fape and backbone loss errors
parent
b22bd4e3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
6 deletions
+12
-6
openfold/utils/loss.py
openfold/utils/loss.py
+12
-6
No files found.
openfold/utils/loss.py
View file @
bbf42cc5
...
...
@@ -185,7 +185,13 @@ def backbone_loss(
eps
:
float
=
1e-4
,
**
kwargs
,
)
->
torch
.
Tensor
:
pred_aff
=
Rigid
.
from_tensor_7
(
traj
)
### need to check if the traj belongs to 4*4 matrix or a tensor_7
if
traj
.
shape
[
-
1
]
==
7
:
pred_aff
=
Rigid
.
from_tensor_7
(
traj
)
elif
traj
.
shape
[
-
1
]
==
4
:
pred_aff
=
Rigid
.
from_tensor_4x4
(
traj
)
pred_aff
=
Rigid
(
Rotation
(
rot_mats
=
pred_aff
.
get_rots
().
get_rot_mats
(),
quats
=
None
),
pred_aff
.
get_trans
(),
...
...
@@ -304,10 +310,10 @@ def fape_loss(
interface_bb_loss
=
backbone_loss
(
traj
=
traj
,
pair_mask
=
1.
-
intra_chain_mask
,
**
{
**
batch
,
**
config
.
int
erface
_backbone
},
**
{
**
batch
,
**
config
.
int
ra_chain
_backbone
},
)
weighted_bb_loss
=
(
intra_chain_bb_loss
*
config
.
intra_chain_backbone
.
weight
+
interface_bb_loss
*
config
.
int
erface
_backbone
.
weight
)
+
interface_bb_loss
*
config
.
int
ra_chain
_backbone
.
weight
)
else
:
bb_loss
=
backbone_loss
(
traj
=
traj
,
...
...
@@ -1865,7 +1871,6 @@ def greedy_align(
if
(
rmsd
is
not
None
)
and
(
rmsd
<
best_rmsd
):
best_rmsd
=
rmsd
best_idx
=
j
print
(
f
"now best_idx is
{
best_idx
}
and rmsd is
{
rmsd
}
and j is
{
j
}
"
)
assert
best_idx
is
not
None
used
[
best_idx
]
=
True
align
.
append
((
i
,
best_idx
))
...
...
@@ -1920,7 +1925,7 @@ class AlphaFoldLoss(nn.Module):
out
[
"violation"
]
=
find_structural_violations
(
batch
,
out
[
"sm"
][
"positions"
][
-
1
],
**
self
.
config
.
loss
.
violation
,
**
self
.
config
.
violation
,
)
if
"renamed_atom14_gt_positions"
not
in
out
.
keys
():
...
...
@@ -2110,12 +2115,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
batch: a pair of input features and its corresponding ground truth structure
"""
features
,
labels
=
batch
features
[
'resolution'
]
=
labels
[
2
][
'resolution'
]
# firstly update the resolution feature
# first remove the recycling dimention of input features
features
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
features
)
# then permutate ground truth chains before calculating the loss
permutated_labels
=
self
.
multi_chain_perm_align
(
out
,
features
,
labels
)
permutated_labels
.
pop
(
'aatype'
)
logger
.
info
(
"finished multi-chain permutation"
)
logger
.
info
(
"finished multi-chain permutation
"
)
features
.
update
(
permutated_labels
)
move_to_cpu
=
lambda
t
:
(
t
.
to
(
'cpu'
))
features
=
tensor_tree_map
(
move_to_cpu
,
features
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment