Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
fd748a0d
Commit
fd748a0d
authored
Sep 25, 2023
by
Geoffrey Yu
Browse files
update loss to accomodate new input data pipeline
parent
02ce77c5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
12 deletions
+10
-12
openfold/utils/loss.py
openfold/utils/loss.py
+10
-12
No files found.
openfold/utils/loss.py
View file @
fd748a0d
...
@@ -1813,14 +1813,17 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
...
@@ -1813,14 +1813,17 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
if
len
(
least_asym_entities
)
>
1
:
if
len
(
least_asym_entities
)
>
1
:
least_asym_entities
=
random
.
choice
(
least_asym_entities
)
least_asym_entities
=
random
.
choice
(
least_asym_entities
)
assert
len
(
least_asym_entities
)
==
1
assert
len
(
least_asym_entities
)
==
1
best_pred_asym
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
least_asym_entities
[
0
]])
#
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
# # If there is more than one chain in the predicted output that has the same sequence
# # # If there is more than one chain in the predicted output that has the same sequence
# # as the chosen ground truth anchor, then randomly picke one
# # # as the chosen ground truth anchor, then randomly picke one
if
len
(
best_pred_asym
)
>
1
:
# if len(best_pred_asym) > 1:
while
best_pred_asym
not
in
input_asym_id
:
# selected_best_pred_asym = random.choice(best_pred_asym)
best_pred_asym
=
random
.
choice
(
best_pred_asym
)
# while selected_best_pred_asym not in input_asym_id:
# selected_best_pred_asym = random.choice(best_pred_asym)
# else:
# selected_best_pred_asym = best_pred_asym
best_pred_asym
=
least_asym_entities
[
0
]
return
least_asym_entities
[
0
],
best_pred_asym
return
least_asym_entities
[
0
],
best_pred_asym
...
@@ -2100,10 +2103,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2100,10 +2103,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
"""
"""
pred_ca_mask
=
torch
.
squeeze
(
pred_ca_mask
,
0
)
pred_ca_mask
=
torch
.
squeeze
(
pred_ca_mask
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
print
(
f
"##### line 2102 asym_mask is
{
asym_mask
}
and shape:
{
asym_mask
.
shape
}
"
)
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
anchor_true_mask
=
torch
.
index_select
(
true_ca_masks
[
anchor_gt_idx
],
1
,
anchor_gt_residue
)
anchor_true_mask
=
torch
.
index_select
(
true_ca_masks
[
anchor_gt_idx
],
1
,
anchor_gt_residue
)
print
(
f
"##### line 2104 anchor_pred_mask:
{
anchor_pred_mask
.
shape
}
and anchor_true_mask :
{
anchor_true_mask
.
shape
}
"
)
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
return
input_mask
return
input_mask
...
@@ -2139,12 +2140,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2139,12 +2140,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
"""
feature
,
ground_truth
=
batch
feature
,
ground_truth
=
batch
print
(
f
"###### line 2140 feature asym_id is :
{
feature
[
'asym_id'
]
}
"
)
del
batch
del
batch
if
permutate_chains
:
if
permutate_chains
:
# First select anchors from predicted structures and ground truths
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
ground_truth
,
feature
[
'asym_id'
])
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
ground_truth
,
feature
[
'asym_id'
])
print
(
f
"###### anchor_gt_asym:
{
anchor_gt_asym
}
and anchor_pred_asym:
{
anchor_pred_asym
}
"
)
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
ground_truth
)
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
ground_truth
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
...
@@ -2189,7 +2188,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2189,7 +2188,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del
true_ca_masks
,
aligned_true_ca_poses
del
true_ca_masks
,
aligned_true_ca_poses
del
pred_ca_pos
,
pred_ca_mask
del
pred_ca_pos
,
pred_ca_mask
gc
.
collect
()
gc
.
collect
()
print
(
f
"finished permutation align. Align is
{
align
}
"
)
else
:
else
:
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment