Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
3ab9da6e
".github/vscode:/vscode.git/clone" did not exist on "d1d2d7e3d2a056ef0f57cad0cc31c4c4227fac5e"
Commit
3ab9da6e
authored
Jun 28, 2023
by
Geoffrey Yu
Browse files
move some tensors back to gpu
parent
a420160f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
10 deletions
+13
-10
openfold/utils/loss.py
openfold/utils/loss.py
+13
-10
No files found.
openfold/utils/loss.py
View file @
3ab9da6e
...
...
@@ -1704,6 +1704,7 @@ def kabsch_rotation(P, Q):
# Will continue trying SVD until the optimal rotaion is calculated
# #
try
:
# first need to load P and Q to cpu otherwise cannot extract the numpy matrices
rotation
=
procrustes
.
rotational
(
P
.
to
(
'cpu'
).
numpy
(),
Q
.
to
(
'cpu'
).
numpy
(),
translate
=
True
)
finished_rotation
=
True
...
...
@@ -1736,12 +1737,12 @@ def get_optimal_transform(
tgt_atoms
=
src_atoms
else
:
src_atoms
=
src_atoms
[
mask
,
:]
tgt_atoms
=
tgt_atoms
[
mask
,
:]
tgt_atoms
=
tgt_atoms
.
to
(
'cuda:0'
)
[
mask
,
:]
src_center
=
src_atoms
.
mean
(
-
2
,
keepdim
=
True
)
tgt_center
=
tgt_atoms
.
mean
(
-
2
,
keepdim
=
True
)
r
=
kabsch_rotation
(
src_atoms
,
tgt_atoms
)
tgt_center
,
src_center
=
tgt_center
.
to
(
'c
p
u'
),
src_center
.
to
(
'c
pu'
)
# load to cpu memory just in case
x
=
tgt_center
-
src_center
@
r
tgt_center
,
src_center
=
tgt_center
.
to
(
'cu
da:0
'
),
src_center
.
to
(
'c
uda:0'
)
x
=
tgt_center
-
src_center
@
r
.
to
(
'cuda:0'
)
return
r
,
x
...
...
@@ -1752,9 +1753,11 @@ def compute_rmsd(
eps
:
float
=
1e-6
,
)
->
torch
.
Tensor
:
# shape check
true_atom_pos
=
true_atom_pos
.
to
(
'cuda:0'
)
pred_atom_pos
=
pred_atom_pos
.
to
(
'cuda:0'
)
sq_diff
=
torch
.
square
(
true_atom_pos
-
pred_atom_pos
).
sum
(
dim
=-
1
,
keepdim
=
False
)
if
atom_mask
is
not
None
:
sq_diff
=
sq_diff
[
atom_mask
]
sq_diff
=
sq_diff
.
to
(
'cpu'
)[
atom_mask
.
to
(
'cpu'
)]
# somehow it causes overflow on cuda so moved to cpu
msd
=
torch
.
mean
(
sq_diff
)
msd
=
torch
.
nan_to_num
(
msd
,
nan
=
1e8
)
return
torch
.
sqrt
(
msd
+
eps
)
...
...
@@ -1842,7 +1845,7 @@ def greedy_align(
cropped_pos
=
true_ca_poses
[
j
]
mask
=
true_ca_masks
[
j
][
cur_residue_index
]
rmsd
=
compute_rmsd
(
cropped_pos
,
cur_pred_pos
,
(
cur_pred_mask
.
to
(
'c
p
u'
)
*
mask
.
to
(
'c
p
u'
)).
bool
()
cropped_pos
,
cur_pred_pos
,
(
cur_pred_mask
.
to
(
'cu
da:0
'
)
*
mask
.
to
(
'cu
da:0
'
)).
bool
()
)
if
rmsd
<
best_rmsd
:
best_rmsd
=
rmsd
...
...
@@ -1901,7 +1904,7 @@ class AlphaFoldLoss(nn.Module):
out
[
"violation"
]
=
find_structural_violations
(
batch
,
out
[
"sm"
][
"positions"
][
-
1
],
**
self
.
config
.
violation
,
**
self
.
config
.
loss
.
violation
,
)
if
"renamed_atom14_gt_positions"
not
in
out
.
keys
():
...
...
@@ -2047,10 +2050,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
r
,
x
=
get_optimal_transform
(
anchor_true_pos
,
anchor_pred_pos
,
(
anchor_true_mask
.
to
(
'c
p
u'
)
*
anchor_pred_mask
.
to
(
'c
p
u'
)).
bool
(),
(
anchor_true_mask
.
to
(
'cu
da:0
'
)
*
anchor_pred_mask
.
to
(
'cu
da:0
'
)).
bool
(),
)
aligned_true_ca_poses
=
[
ca
.
to
(
'c
p
u'
)
@
r
.
to
(
'c
p
u'
)
+
x
.
to
(
'c
p
u'
)
for
ca
in
true_ca_poses
]
# apply transforms
aligned_true_ca_poses
=
[
ca
.
to
(
'cu
da:0
'
)
@
r
.
to
(
'cu
da:0
'
)
+
x
.
to
(
'cu
da:0
'
)
for
ca
in
true_ca_poses
]
# apply transforms
align
=
greedy_align
(
batch
,
per_asym_residue_index
,
...
...
@@ -2087,8 +2090,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
# then permutate ground truth chains before calculating the loss
permutated_labels
=
self
.
multi_chain_perm_align
(
out
,
features
,
labels
)
logger
.
info
(
"finished multi-chain permutation"
)
#
features.update(permutated_labels)
#
self.loss(out,features)
features
.
update
(
permutated_labels
)
self
.
loss
(
out
,
features
)
return
permutated_labels
## TODO next need to check how the ground truth label is used
# in loss calculation.
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment