Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
80f0d617
Commit
80f0d617
authored
Jun 28, 2023
by
Geoffrey Yu
Browse files
Temporarily work around a CUDA error by moving the two tensors to the CPU
parent
3ab9da6e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
8 deletions
+35
-8
openfold/utils/loss.py
openfold/utils/loss.py
+35
-8
No files found.
openfold/utils/loss.py
View file @
80f0d617
...
@@ -38,6 +38,8 @@ import random
...
@@ -38,6 +38,8 @@ import random
from
openfold.np
import
residue_constants
as
rc
from
openfold.np
import
residue_constants
as
rc
import
logging
import
logging
import
procrustes
import
procrustes
from
openfold.utils.tensor_utils
import
tensor_tree_map
import
gc
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
def
softmax_cross_entropy
(
logits
,
labels
):
def
softmax_cross_entropy
(
logits
,
labels
):
...
@@ -842,6 +844,7 @@ def between_residue_bond_loss(
...
@@ -842,6 +844,7 @@ def between_residue_bond_loss(
]
+
next_is_proline
*
residue_constants
.
between_res_bond_length_stddev_c_n
[
]
+
next_is_proline
*
residue_constants
.
between_res_bond_length_stddev_c_n
[
1
1
]
]
c_n_bond_length_error
=
torch
.
sqrt
(
eps
+
(
c_n_bond_length
-
gt_length
)
**
2
)
c_n_bond_length_error
=
torch
.
sqrt
(
eps
+
(
c_n_bond_length
-
gt_length
)
**
2
)
c_n_loss_per_residue
=
torch
.
nn
.
functional
.
relu
(
c_n_loss_per_residue
=
torch
.
nn
.
functional
.
relu
(
c_n_bond_length_error
-
tolerance_factor_soft
*
gt_stddev
c_n_bond_length_error
-
tolerance_factor_soft
*
gt_stddev
...
@@ -1741,9 +1744,16 @@ def get_optimal_transform(
...
@@ -1741,9 +1744,16 @@ def get_optimal_transform(
src_center
=
src_atoms
.
mean
(
-
2
,
keepdim
=
True
)
src_center
=
src_atoms
.
mean
(
-
2
,
keepdim
=
True
)
tgt_center
=
tgt_atoms
.
mean
(
-
2
,
keepdim
=
True
)
tgt_center
=
tgt_atoms
.
mean
(
-
2
,
keepdim
=
True
)
r
=
kabsch_rotation
(
src_atoms
,
tgt_atoms
)
r
=
kabsch_rotation
(
src_atoms
,
tgt_atoms
)
del
src_atoms
,
tgt_atoms
,
gc
.
collect
()
tgt_center
,
src_center
=
tgt_center
.
to
(
'cuda:0'
),
src_center
.
to
(
'cuda:0'
)
tgt_center
,
src_center
=
tgt_center
.
to
(
'cuda:0'
),
src_center
.
to
(
'cuda:0'
)
x
=
tgt_center
-
src_center
@
r
.
to
(
'cuda:0'
)
x
=
tgt_center
.
to
(
'cpu'
)
-
src_center
.
to
(
'cpu'
)
@
r
.
to
(
'cpu'
)
return
r
,
x
del
tgt_center
,
src_center
,
mask
gc
.
collect
()
return
r
,
x
.
to
(
'cuda'
)
def
compute_rmsd
(
def
compute_rmsd
(
...
@@ -1756,6 +1766,9 @@ def compute_rmsd(
...
@@ -1756,6 +1766,9 @@ def compute_rmsd(
true_atom_pos
=
true_atom_pos
.
to
(
'cuda:0'
)
true_atom_pos
=
true_atom_pos
.
to
(
'cuda:0'
)
pred_atom_pos
=
pred_atom_pos
.
to
(
'cuda:0'
)
pred_atom_pos
=
pred_atom_pos
.
to
(
'cuda:0'
)
sq_diff
=
torch
.
square
(
true_atom_pos
-
pred_atom_pos
).
sum
(
dim
=-
1
,
keepdim
=
False
)
sq_diff
=
torch
.
square
(
true_atom_pos
-
pred_atom_pos
).
sum
(
dim
=-
1
,
keepdim
=
False
)
del
true_atom_pos
del
pred_atom_pos
gc
.
collect
()
if
atom_mask
is
not
None
:
if
atom_mask
is
not
None
:
sq_diff
=
sq_diff
.
to
(
'cpu'
)[
atom_mask
.
to
(
'cpu'
)]
# Mask-indexing sq_diff on CUDA overflows for an unknown reason, so the indexing is done on the CPU
sq_diff
=
sq_diff
.
to
(
'cpu'
)[
atom_mask
.
to
(
'cpu'
)]
# Mask-indexing sq_diff on CUDA overflows for an unknown reason, so the indexing is done on the CPU
msd
=
torch
.
mean
(
sq_diff
)
msd
=
torch
.
mean
(
sq_diff
)
...
@@ -1830,7 +1843,7 @@ def greedy_align(
...
@@ -1830,7 +1843,7 @@ def greedy_align(
used
[
i
]
=
True
used
[
i
]
=
True
continue
continue
cur_entity_ids
=
batch
[
"entity_id"
][
asym_mask
][
0
]
cur_entity_ids
=
batch
[
"entity_id"
][
asym_mask
][
0
]
best_rmsd
=
1e20
best_rmsd
=
torch
.
inf
best_idx
=
None
best_idx
=
None
cur_asym_list
=
entity_2_asym_list
[
int
(
cur_entity_ids
)]
cur_asym_list
=
entity_2_asym_list
[
int
(
cur_entity_ids
)]
cur_residue_index
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
cur_residue_index
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
...
@@ -1847,6 +1860,7 @@ def greedy_align(
...
@@ -1847,6 +1860,7 @@ def greedy_align(
rmsd
=
compute_rmsd
(
rmsd
=
compute_rmsd
(
cropped_pos
,
cur_pred_pos
,
(
cur_pred_mask
.
to
(
'cuda:0'
)
*
mask
.
to
(
'cuda:0'
)).
bool
()
cropped_pos
,
cur_pred_pos
,
(
cur_pred_mask
.
to
(
'cuda:0'
)
*
mask
.
to
(
'cuda:0'
)).
bool
()
)
)
print
(
f
"rmsd is
{
rmsd
}
"
)
if
rmsd
<
best_rmsd
:
if
rmsd
<
best_rmsd
:
best_rmsd
=
rmsd
best_rmsd
=
rmsd
best_idx
=
j
best_idx
=
j
...
@@ -2047,13 +2061,17 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2047,13 +2061,17 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
anchor_pred_pos
=
pred_ca_pos
[
asym_mask
]
anchor_pred_pos
=
pred_ca_pos
[
asym_mask
]
anchor_true_mask
=
true_ca_masks
[
anchor_gt_idx
][
anchor_residue_idx
]
anchor_true_mask
=
true_ca_masks
[
anchor_gt_idx
][
anchor_residue_idx
]
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
input_mask
=
(
anchor_true_mask
.
to
(
'cuda:0'
)
*
anchor_pred_mask
.
to
(
'cuda:0'
)).
bool
()
r
,
x
=
get_optimal_transform
(
r
,
x
=
get_optimal_transform
(
anchor_true_pos
,
anchor_true_pos
,
anchor_pred_pos
,
anchor_pred_pos
,
mask
=
input_mask
(
anchor_true_mask
.
to
(
'cuda:0'
)
*
anchor_pred_mask
.
to
(
'cuda:0'
)).
bool
(),
)
)
del
input_mask
# just to save memory
del
anchor_pred_mask
del
anchor_true_mask
gc
.
collect
()
aligned_true_ca_poses
=
[
ca
.
to
(
'cu
da:0
'
)
@
r
.
to
(
'cu
da:0
'
)
+
x
.
to
(
'cu
da:0
'
)
for
ca
in
true_ca_poses
]
# apply transforms
aligned_true_ca_poses
=
[
ca
.
to
(
'c
p
u'
)
@
r
.
to
(
'c
p
u'
)
+
x
.
to
(
'c
p
u'
)
for
ca
in
true_ca_poses
]
# apply transforms
align
=
greedy_align
(
align
=
greedy_align
(
batch
,
batch
,
per_asym_residue_index
,
per_asym_residue_index
,
...
@@ -2064,6 +2082,11 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2064,6 +2082,11 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
aligned_true_ca_poses
,
aligned_true_ca_poses
,
true_ca_masks
,
true_ca_masks
,
)
)
del
aligned_true_ca_poses
del
r
,
x
gc
.
collect
()
merged_labels
=
merge_labels
(
merged_labels
=
merge_labels
(
batch
,
batch
,
per_asym_residue_index
,
per_asym_residue_index
,
...
@@ -2091,6 +2114,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2091,6 +2114,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
permutated_labels
=
self
.
multi_chain_perm_align
(
out
,
features
,
labels
)
permutated_labels
=
self
.
multi_chain_perm_align
(
out
,
features
,
labels
)
logger
.
info
(
"finished multi-chain permutation"
)
logger
.
info
(
"finished multi-chain permutation"
)
features
.
update
(
permutated_labels
)
features
.
update
(
permutated_labels
)
move_to_gpu
=
lambda
t
:
(
t
.
to
(
'cuda:0'
))
features
=
tensor_tree_map
(
move_to_gpu
,
features
)
print
(
f
"after moving features:"
,
torch
.
cuda
.
memory_allocated
(
0
))
# out = tensor_tree_map(move_to_gpu,out)
self
.
loss
(
out
,
features
)
self
.
loss
(
out
,
features
)
return
permutated_labels
return
permutated_labels
## TODO: next, check how the ground-truth label is used
## TODO: next, check how the ground-truth label is used
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment