Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
d909c707
Commit
d909c707
authored
Sep 19, 2023
by
Geoffrey Yu
Browse files
further cleaned up functions
parent
faca088f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
14 deletions
+16
-14
openfold/utils/loss.py
openfold/utils/loss.py
+16
-14
No files found.
openfold/utils/loss.py
View file @
d909c707
...
...
@@ -2095,14 +2095,18 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
@
staticmethod
def
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
anchor_residue_idx
,
true_ca_masks
,
pred_ca_mask
,
asym_mask
,
pred_ca_pos
):
true_ca_masks
,
ca_idx
,
out
,
asym_mask
):
pred_ca_mask
=
out
[
"final_atom_mask"
][...,
ca_idx
]
# [bsz, nres]
input_mask
=
AlphaFoldMultimerLoss
.
calculate_input_mask
(
true_ca_masks
,
anchor_gt_idx
,
asym_mask
,
pred_ca_mask
,
anchor_residue_idx
)
anchor_true_pos
=
torch
.
index_select
(
true_ca_poses
[
anchor_gt_idx
],
1
,
anchor_residue_idx
)
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
print
(
f
"line 2109 is nan
{
torch
.
isnan
(
pred_ca_pos
).
any
()
}
is inf :
{
torch
.
isinf
(
pred_ca_pos
).
any
()
}
"
)
anchor_pred_pos
=
pred_ca_pos
[
0
][
asym_mask
[
0
]]
r
,
x
=
get_optimal_transform
(
anchor_pred_pos
,
anchor_true_pos
[
0
],
...
...
@@ -2124,12 +2128,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
ca_idx
=
rc
.
atom_order
[
"CA"
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
pred_ca_mask
=
out
[
"final_atom_mask"
][...,
ca_idx
].
to
(
dtype
=
pred_ca_pos
.
dtype
)
# [bsz, nres]
true_ca_masks
=
[
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
if
permutate_chains
:
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
...
...
@@ -2141,18 +2139,23 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
anchor_residue_idx
=
per_asym_residue_index
[
int
(
anchor_pred_asym
)]
true_ca_poses
=
[
l
[
"all_atom_positions"
][...,
ca_idx
,
:]
for
l
in
labels
]
# list([nres, 3])
true_ca_masks
=
[
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
anchor_residue_idx
,
true_ca_masks
,
pred_ca_mask
,
asym_mask
,
pred_ca_pos
)
true_ca_masks
,
ca_idx
,
out
,
asym_mask
)
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
del
true_ca_poses
gc
.
collect
()
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
batch
)
pred_ca_mask
=
out
[
"final_atom_mask"
][...,
ca_idx
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
print
(
f
"line 2157 is nan
{
torch
.
isnan
(
pred_ca_pos
).
any
()
}
is inf : is nan
{
torch
.
isnan
(
pred_ca_pos
).
any
()
}
is nan
{
torch
.
isinf
(
pred_ca_pos
).
any
()
}
"
)
align
=
greedy_align
(
batch
,
per_asym_residue_index
,
...
...
@@ -2165,7 +2168,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del
aligned_true_ca_poses
,
true_ca_masks
del
r
,
x
del
pred_ca_pos
,
pred_ca_mask
gc
.
collect
()
print
(
f
"finished multi-chain permutation and final align is
{
align
}
"
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment