Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
7f2a3267
"...llm/vllm/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "b9a0ce2cabd274675137a5aeb09b42fd437509b1"
Commit
7f2a3267
authored
Jul 20, 2023
by
Geoffrey Yu
Browse files
update kabsch rotation calculation to avoid svd not converge error
parent
5348936a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
22 deletions
+16
-22
openfold/utils/loss.py
openfold/utils/loss.py
+16
-22
No files found.
openfold/utils/loss.py
View file @
7f2a3267
...
@@ -1710,36 +1710,30 @@ def kabsch_rotation(P, Q):
...
@@ -1710,36 +1710,30 @@ def kabsch_rotation(P, Q):
"""
"""
assert
P
.
shape
==
torch
.
Size
([
Q
.
shape
[
0
],
Q
.
shape
[
1
]])
assert
P
.
shape
==
torch
.
Size
([
Q
.
shape
[
0
],
Q
.
shape
[
1
]])
finished_rotation
=
False
rotation
=
procrustes
.
rotational
(
P
.
detach
().
cpu
().
numpy
(),
while
not
finished_rotation
:
Q
.
detach
().
cpu
().
numpy
(),
translate
=
False
,
scale
=
False
)
#
rotation
=
torch
.
tensor
(
rotation
.
t
,
dtype
=
torch
.
float
)
# rotation.t doesn't mean transpose, t only means get the matrix out of the procruste object
# Add a try-except block cuz sometimes SVD fails to converge and crashes the programme
# Will continue trying SVD until the optimal rotaion is calculated
# #
try
:
# first need to load P and Q to cpu otherwise cannot extract the numpy matrices
rotation
=
procrustes
.
rotational
(
P
.
to
(
'cpu'
).
numpy
(),
Q
.
to
(
'cpu'
).
numpy
(),
translate
=
True
)
finished_rotation
=
True
except
:
print
(
f
"svd failed."
)
import
sys
sys
.
exit
()
rotation
=
torch
.
tensor
(
rotation
.
t
,
dtype
=
torch
.
float
)
assert
rotation
.
shape
==
torch
.
Size
([
3
,
3
])
assert
rotation
.
shape
==
torch
.
Size
([
3
,
3
])
return
rotation
return
rotation
.
to
(
'cuda'
)
def
get_optimal_transform
(
def
get_optimal_transform
(
src_atoms
:
torch
.
Tensor
,
src_atoms
:
torch
.
Tensor
,
tgt_atoms
:
torch
.
Tensor
,
tgt_atoms
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
=
None
,
mask
:
torch
.
Tensor
=
None
,
):
):
"""
src_atoms: predicted CA positions, shape:[num_res,3]
tgt_atoms: ground-truth CA positions, shape:[num_res,3]
mask: a vector of boolean values, shape:[num_res]
"""
assert
src_atoms
.
shape
==
tgt_atoms
.
shape
,
(
src_atoms
.
shape
,
tgt_atoms
.
shape
)
assert
src_atoms
.
shape
==
tgt_atoms
.
shape
,
(
src_atoms
.
shape
,
tgt_atoms
.
shape
)
assert
src_atoms
.
shape
[
-
1
]
==
3
assert
src_atoms
.
shape
[
-
1
]
==
3
if
torch
.
isnan
(
src_atoms
).
any
():
assert
len
(
mask
.
shape
)
==
1
,
"mask should have the shape of [num_res]"
if
torch
.
isnan
(
src_atoms
).
any
()
or
torch
.
isinf
(
src_atoms
).
any
():
#
#
# sometimes using fake test inputs generates NaN in the predicted atom positions
# sometimes using fake test inputs generates NaN in the predicted atom positions
# #
# #
logging
.
warning
(
f
"src_atom has nan or inf"
)
src_atoms
=
torch
.
nan_to_num
(
src_atoms
,
nan
=
0.0
,
posinf
=
1.0
,
neginf
=
1.0
)
src_atoms
=
torch
.
nan_to_num
(
src_atoms
,
nan
=
0.0
,
posinf
=
1.0
,
neginf
=
1.0
)
if
mask
is
not
None
:
if
mask
is
not
None
:
...
@@ -1749,8 +1743,8 @@ def get_optimal_transform(
...
@@ -1749,8 +1743,8 @@ def get_optimal_transform(
src_atoms
=
torch
.
zeros
((
1
,
3
),
device
=
src_atoms
.
device
).
float
()
src_atoms
=
torch
.
zeros
((
1
,
3
),
device
=
src_atoms
.
device
).
float
()
tgt_atoms
=
src_atoms
tgt_atoms
=
src_atoms
else
:
else
:
src_atoms
=
src_atoms
.
to
(
'cuda:0'
)
[
mask
,
:]
src_atoms
=
src_atoms
[
mask
,
:]
tgt_atoms
=
tgt_atoms
.
to
(
'cuda:0'
)
[
mask
,
:]
tgt_atoms
=
tgt_atoms
[
mask
,
:]
src_center
=
src_atoms
.
mean
(
-
2
,
keepdim
=
True
)
src_center
=
src_atoms
.
mean
(
-
2
,
keepdim
=
True
)
tgt_center
=
tgt_atoms
.
mean
(
-
2
,
keepdim
=
True
)
tgt_center
=
tgt_atoms
.
mean
(
-
2
,
keepdim
=
True
)
r
=
kabsch_rotation
(
src_atoms
,
tgt_atoms
)
r
=
kabsch_rotation
(
src_atoms
,
tgt_atoms
)
...
@@ -2069,8 +2063,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2069,8 +2063,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
# anchor_pred_mask = anchor_pred_mask.to('cuda')
# anchor_pred_mask = anchor_pred_mask.to('cuda')
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
r
,
x
=
get_optimal_transform
(
r
,
x
=
get_optimal_transform
(
anchor_true_pos
[
0
],
anchor_pred_pos
,
anchor_true_pos
[
0
],
anchor_pred_pos
,
mask
=
input_mask
mask
=
input_mask
[
0
]
)
)
del
input_mask
# just to save memory
del
input_mask
# just to save memory
del
anchor_pred_mask
del
anchor_pred_mask
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment