Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
2184eff0
Commit
2184eff0
authored
Sep 21, 2023
by
Geoffrey Yu
Browse files
revert back to
e097da95
for now
parent
01ef215e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
68 deletions
+39
-68
openfold/utils/loss.py
openfold/utils/loss.py
+39
-68
No files found.
openfold/utils/loss.py
View file @
2184eff0
...
@@ -1834,6 +1834,7 @@ def get_least_asym_entity_or_longest_length(batch):
...
@@ -1834,6 +1834,7 @@ def get_least_asym_entity_or_longest_length(batch):
def
greedy_align
(
def
greedy_align
(
batch
,
batch
,
per_asym_residue_index
,
per_asym_residue_index
,
unique_asym_ids
,
entity_2_asym_list
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_pos
,
pred_ca_mask
,
pred_ca_mask
,
...
@@ -1846,7 +1847,6 @@ def greedy_align(
...
@@ -1846,7 +1847,6 @@ def greedy_align(
"""
"""
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
align
=
[]
align
=
[]
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
for
cur_asym_id
in
unique_asym_ids
:
for
cur_asym_id
in
unique_asym_ids
:
i
=
int
(
cur_asym_id
-
1
)
i
=
int
(
cur_asym_id
-
1
)
asym_mask
=
batch
[
"asym_id"
]
==
cur_asym_id
asym_mask
=
batch
[
"asym_id"
]
==
cur_asym_id
...
@@ -2064,57 +2064,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2064,57 +2064,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
dim_dict
[
k
])]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
dim_dict
[
k
])]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
return
labels
return
labels
@
staticmethod
def
get_per_asym_residue_idex
(
batch
):
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
per_asym_residue_index
=
{}
for
cur_asym_id
in
unique_asym_ids
:
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
return
per_asym_residue_index
@
staticmethod
def
get_entity_2_asym_list
(
batch
):
entity_2_asym_list
=
{}
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
for
cur_ent_id
in
unique_entity_ids
:
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
return
entity_2_asym_list
@
staticmethod
def
calculate_input_mask
(
true_ca_masks
,
anchor_gt_idx
,
asym_mask
,
pred_ca_mask
,
anchor_residue_idx
):
anchor_pred_mask
=
pred_ca_mask
[
0
][
asym_mask
[
0
]]
anchor_true_mask
=
torch
.
index_select
(
true_ca_masks
[
anchor_gt_idx
],
1
,
anchor_residue_idx
)
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
return
input_mask
@
staticmethod
def
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
anchor_residue_idx
,
true_ca_masks
,
ca_idx
,
out
,
asym_mask
):
pred_ca_mask
=
out
[
"final_atom_mask"
][...,
ca_idx
]
# [bsz, nres]
input_mask
=
AlphaFoldMultimerLoss
.
calculate_input_mask
(
true_ca_masks
,
anchor_gt_idx
,
asym_mask
,
pred_ca_mask
,
anchor_residue_idx
)
anchor_true_pos
=
torch
.
index_select
(
true_ca_poses
[
anchor_gt_idx
],
1
,
anchor_residue_idx
)
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
print
(
f
"line 2109 is nan
{
torch
.
isnan
(
pred_ca_pos
).
any
()
}
is inf :
{
torch
.
isinf
(
pred_ca_pos
).
any
()
}
"
)
anchor_pred_pos
=
pred_ca_pos
[
0
][
asym_mask
[
0
]]
r
,
x
=
get_optimal_transform
(
anchor_pred_pos
,
anchor_true_pos
[
0
],
mask
=
input_mask
[
0
]
)
return
r
,
x
@
staticmethod
@
staticmethod
def
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
True
):
def
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
True
):
"""
"""
...
@@ -2128,37 +2077,57 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2128,37 +2077,57 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
assert
isinstance
(
labels
,
list
)
ca_idx
=
rc
.
atom_order
[
"CA"
]
ca_idx
=
rc
.
atom_order
[
"CA"
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
pred_ca_mask
=
out
[
"final_atom_mask"
][...,
ca_idx
].
to
(
dtype
=
pred_ca_pos
.
dtype
)
# [bsz, nres]
true_ca_poses
=
[
l
[
"all_atom_positions"
][...,
ca_idx
,
:]
for
l
in
labels
]
# list([nres, 3])
true_ca_masks
=
[
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
per_asym_residue_index
=
{}
for
cur_asym_id
in
unique_asym_ids
:
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
if
permutate_chains
:
if
permutate_chains
:
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
print
(
f
"anchor_gt_asym:
{
anchor_gt_asym
}
anchor_pred_asym:
{
anchor_pred_asym
}
"
)
print
(
f
"anchor_gt_asym:
{
anchor_gt_asym
}
anchor_pred_asym:
{
anchor_pred_asym
}
"
)
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
entity_2_asym_list
=
{}
for
cur_ent_id
in
unique_entity_ids
:
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
asym_mask
=
(
batch
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
asym_mask
=
(
batch
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_idex
(
batch
)
anchor_residue_idx
=
per_asym_residue_index
[
int
(
anchor_pred_asym
)]
anchor_residue_idx
=
per_asym_residue_index
[
int
(
anchor_pred_asym
)]
anchor_true_pos
=
torch
.
index_select
(
true_ca_poses
[
anchor_gt_idx
],
1
,
anchor_residue_idx
)
anchor_pred_pos
=
pred_ca_pos
[
0
][
asym_mask
[
0
]]
true_ca_poses
=
[
l
[
"all_atom_positions"
][...,
ca_idx
,
:]
for
l
in
labels
]
# list([nres, 3])
anchor_true_mask
=
torch
.
index_select
(
true_ca_masks
[
anchor_gt_idx
],
1
,
anchor_residue_idx
)
true_ca_masks
=
[
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
anchor_pred_mask
=
pred_ca_mask
[
0
][
asym_mask
[
0
]]
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
anchor_residue_idx
,
true_ca_masks
,
ca_idx
,
out
,
asym_mask
)
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
r
,
x
=
get_optimal_transform
(
anchor_pred_pos
,
anchor_true_pos
[
0
],
mask
=
input_mask
[
0
]
)
del
input_mask
# just to save memory
del
anchor_pred_mask
del
anchor_true_mask
gc
.
collect
()
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
del
true_ca_poses
del
true_ca_poses
gc
.
collect
()
gc
.
collect
()
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
batch
)
pred_ca_mask
=
out
[
"final_atom_mask"
][...,
ca_idx
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
print
(
f
"line 2157 is nan
{
torch
.
isnan
(
pred_ca_pos
).
any
()
}
is inf : is nan
{
torch
.
isnan
(
pred_ca_pos
).
any
()
}
is nan
{
torch
.
isinf
(
pred_ca_pos
).
any
()
}
"
)
align
=
greedy_align
(
align
=
greedy_align
(
batch
,
batch
,
per_asym_residue_index
,
per_asym_residue_index
,
unique_asym_ids
,
entity_2_asym_list
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_pos
,
pred_ca_mask
,
pred_ca_mask
,
...
@@ -2168,6 +2137,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2168,6 +2137,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del
aligned_true_ca_poses
,
true_ca_masks
del
aligned_true_ca_poses
,
true_ca_masks
del
r
,
x
del
r
,
x
del
pred_ca_pos
,
pred_ca_mask
del
anchor_pred_pos
,
anchor_true_pos
gc
.
collect
()
gc
.
collect
()
print
(
f
"finished multi-chain permutation and final align is
{
align
}
"
)
print
(
f
"finished multi-chain permutation and final align is
{
align
}
"
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment