Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
cdfb0c75
Commit
cdfb0c75
authored
Oct 12, 2023
by
Geoffrey Yu
Browse files
move the kabsch rotation step to gpu
parent
a9d65037
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
11 deletions
+12
-11
openfold/utils/loss.py
openfold/utils/loss.py
+12
-11
No files found.
openfold/utils/loss.py
View file @
cdfb0c75
...
@@ -37,7 +37,6 @@ from openfold.utils.tensor_utils import (
...
@@ -37,7 +37,6 @@ from openfold.utils.tensor_utils import (
import
random
import
random
from
openfold.np
import
residue_constants
as
rc
from
openfold.np
import
residue_constants
as
rc
import
logging
import
logging
import
procrustes
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.tensor_utils
import
tensor_tree_map
import
gc
import
gc
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -1716,9 +1715,8 @@ def kabsch_rotation(P, Q):
...
@@ -1716,9 +1715,8 @@ def kabsch_rotation(P, Q):
Use procrustes package to calculate best rotation that minimises
Use procrustes package to calculate best rotation that minimises
the RMSD betwee P and Q
the RMSD betwee P and Q
The optimal rotation matrix was calculated using
The optimal rotation matrix was calculated using Kabsch algorithm:
the rotational() function from procrustes package. Details can be found here:
https://en.wikipedia.org/wiki/Kabsch_algorithm
https://procrustes.qcdevs.org/api/rotational.html#rotational
Args:
Args:
P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
...
@@ -1727,15 +1725,18 @@ def kabsch_rotation(P, Q):
...
@@ -1727,15 +1725,18 @@ def kabsch_rotation(P, Q):
return:
return:
A 3*3 rotation matrix
A 3*3 rotation matrix
"""
"""
assert
P
.
shape
==
torch
.
Size
([
Q
.
shape
[
0
],
Q
.
shape
[
1
]])
assert
P
.
shape
==
torch
.
Size
([
Q
.
shape
[
0
],
Q
.
shape
[
1
]])
rotation
=
procrustes
.
rotational
(
P
.
detach
().
cpu
().
float
().
numpy
(),
Q
.
detach
().
cpu
().
float
().
numpy
(),
translate
=
False
,
scale
=
False
)
# Rotation.t doesn't mean transpose, t only means get the matrix out of the procruste object
rotation
=
torch
.
tensor
(
rotation
.
t
,
dtype
=
torch
.
float
)
assert
rotation
.
shape
==
torch
.
Size
([
3
,
3
])
return
rotation
.
to
(
device
=
P
.
device
,
dtype
=
P
.
dtype
)
# Firstly, compute SVD of P.T * Q
u
,
_
,
vt
=
torch
.
linalg
.
svd
(
torch
.
matmul
(
P
.
to
(
torch
.
float32
).
T
,
Q
.
to
(
torch
.
float32
)),
driver
=
'gesvd'
)
# Then construct s matrix
s
=
torch
.
eye
(
P
.
shape
[
1
],
device
=
P
.
device
)
# correct the rotation matrix to ensure a right-handed coordinate
s
[
-
1
,
-
1
]
=
torch
.
sign
(
torch
.
linalg
.
det
(
torch
.
matmul
(
u
,
vt
)))
# finally compute the rotation matrix
r_opt
=
torch
.
matmul
(
torch
.
matmul
(
u
,
s
),
vt
)
assert
r_opt
.
shape
==
torch
.
Size
([
3
,
3
])
return
r_opt
.
to
(
device
=
P
.
device
,
dtype
=
P
.
dtype
)
def
get_optimal_transform
(
def
get_optimal_transform
(
src_atoms
:
torch
.
Tensor
,
src_atoms
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment