Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
458a62f7
".github/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "5944dbed0b6e08c6eeba9d8ade9bcbbc432da2f9"
Commit
458a62f7
authored
Jun 26, 2023
by
Geoffrey Yu
Browse files
switch optimal alignment method to procustes package
parent
82895ec3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
6 deletions
+30
-6
openfold/utils/loss.py
openfold/utils/loss.py
+30
-6
No files found.
openfold/utils/loss.py
View file @
458a62f7
...
@@ -37,7 +37,7 @@ from openfold.utils.tensor_utils import (
...
@@ -37,7 +37,7 @@ from openfold.utils.tensor_utils import (
import
random
import
random
from
openfold.np
import
residue_constants
as
rc
from
openfold.np
import
residue_constants
as
rc
import
logging
import
logging
from
scipy
import
spatial
as
sp_spatial
import
procrustes
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
def
softmax_cross_entropy
(
logits
,
labels
):
def
softmax_cross_entropy
(
logits
,
labels
):
...
@@ -1680,6 +1680,10 @@ def kabsch_rotation(P, Q):
...
@@ -1680,6 +1680,10 @@ def kabsch_rotation(P, Q):
Use scipy.spatial package to calculate best rotation that minimises
Use scipy.spatial package to calculate best rotation that minimises
the RMSD betwee P and Q
the RMSD betwee P and Q
The optimal rotation matrix was calculated using
the rotational() function from procrustes package. Details can be found here:
https://procrustes.qcdevs.org/api/rotational.html#rotational
Args:
Args:
P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
Q: [N * 3] the same dimension as P
Q: [N * 3] the same dimension as P
...
@@ -1687,10 +1691,24 @@ def kabsch_rotation(P, Q):
...
@@ -1687,10 +1691,24 @@ def kabsch_rotation(P, Q):
return:
return:
A 3*3 rotation matrix
A 3*3 rotation matrix
"""
"""
assert
P
.
shape
==
torch
.
size
([
Q
.
shape
[
0
],
Q
.
shape
[
1
]])
rotation
,
_
=
sp_spatial
.
transform
.
Rotation
.
align_vectors
(
P
.
numpy
(),
Q
.
numpy
())
assert
P
.
shape
==
torch
.
Size
([
Q
.
shape
[
0
],
Q
.
shape
[
1
]])
rotation
=
torch
.
tensor
(
rotation
,
dtype
=
torch
.
float64
)
finished_rotation
=
False
assert
rotation
.
shape
==
torch
.
size
([
3
,
3
])
while
not
finished_rotation
:
#
# Add a try-except block cuz sometimes SVD fails to converge and crashes the programme
# Will continue trying SVD until the optimal rotaion is calculated
# #
try
:
rotation
=
procrustes
.
rotational
(
P
.
to
(
'cpu'
).
numpy
(),
Q
.
to
(
'cpu'
).
numpy
(),
translate
=
True
)
finished_rotation
=
True
except
:
print
(
f
"svd failed."
)
import
sys
sys
.
exit
()
rotation
=
torch
.
tensor
(
rotation
.
t
,
dtype
=
torch
.
float
)
assert
rotation
.
shape
==
torch
.
Size
([
3
,
3
])
return
rotation
return
rotation
def
get_optimal_transform
(
def
get_optimal_transform
(
...
@@ -1700,6 +1718,12 @@ def get_optimal_transform(
...
@@ -1700,6 +1718,12 @@ def get_optimal_transform(
):
):
assert
src_atoms
.
shape
==
tgt_atoms
.
shape
,
(
src_atoms
.
shape
,
tgt_atoms
.
shape
)
assert
src_atoms
.
shape
==
tgt_atoms
.
shape
,
(
src_atoms
.
shape
,
tgt_atoms
.
shape
)
assert
src_atoms
.
shape
[
-
1
]
==
3
assert
src_atoms
.
shape
[
-
1
]
==
3
if
torch
.
isnan
(
src_atoms
).
any
():
#
# sometimes using fake test inputs generates NaN in the predicted atom positions
# #
src_atoms
=
torch
.
nan_to_num
(
src_atoms
,
nan
=
0.0
)
if
mask
is
not
None
:
if
mask
is
not
None
:
assert
mask
.
dtype
==
torch
.
bool
assert
mask
.
dtype
==
torch
.
bool
assert
mask
.
shape
[
-
1
]
==
src_atoms
.
shape
[
-
2
]
assert
mask
.
shape
[
-
1
]
==
src_atoms
.
shape
[
-
2
]
...
@@ -1711,7 +1735,7 @@ def get_optimal_transform(
...
@@ -1711,7 +1735,7 @@ def get_optimal_transform(
tgt_atoms
=
tgt_atoms
[
mask
,
:]
tgt_atoms
=
tgt_atoms
[
mask
,
:]
src_center
=
src_atoms
.
mean
(
-
2
,
keepdim
=
True
)
src_center
=
src_atoms
.
mean
(
-
2
,
keepdim
=
True
)
tgt_center
=
tgt_atoms
.
mean
(
-
2
,
keepdim
=
True
)
tgt_center
=
tgt_atoms
.
mean
(
-
2
,
keepdim
=
True
)
r
=
kabsch_rotation
(
src_atoms
-
src_center
,
tgt_atoms
-
tgt_center
)
r
=
kabsch_rotation
(
src_atoms
,
tgt_atoms
)
tgt_center
,
src_center
=
tgt_center
.
to
(
'cpu'
),
src_center
.
to
(
'cpu'
)
# load to cpu memory just in case
tgt_center
,
src_center
=
tgt_center
.
to
(
'cpu'
),
src_center
.
to
(
'cpu'
)
# load to cpu memory just in case
x
=
tgt_center
-
src_center
@
r
x
=
tgt_center
-
src_center
@
r
return
r
,
x
return
r
,
x
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment