Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
fd748a0d
"tests/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "3a41e8304e1ec5ff1688d1967fea5376581c5a5c"
Commit
fd748a0d
authored
Sep 25, 2023
by
Geoffrey Yu
Browse files
update loss to accomodate new input data pipeline
parent
02ce77c5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
12 deletions
+10
-12
openfold/utils/loss.py
openfold/utils/loss.py
+10
-12
No files found.
openfold/utils/loss.py
View file @
fd748a0d
...
...
@@ -1813,14 +1813,17 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
if
len
(
least_asym_entities
)
>
1
:
least_asym_entities
=
random
.
choice
(
least_asym_entities
)
assert
len
(
least_asym_entities
)
==
1
best_pred_asym
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
least_asym_entities
[
0
]])
#
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
# # If there is more than one chain in the predicted output that has the same sequence
# # as the chosen ground truth anchor, then randomly picke one
if
len
(
best_pred_asym
)
>
1
:
while
best_pred_asym
not
in
input_asym_id
:
best_pred_asym
=
random
.
choice
(
best_pred_asym
)
# # # If there is more than one chain in the predicted output that has the same sequence
# # # as the chosen ground truth anchor, then randomly picke one
# if len(best_pred_asym) > 1:
# selected_best_pred_asym = random.choice(best_pred_asym)
# while selected_best_pred_asym not in input_asym_id:
# selected_best_pred_asym = random.choice(best_pred_asym)
# else:
# selected_best_pred_asym = best_pred_asym
best_pred_asym
=
least_asym_entities
[
0
]
return
least_asym_entities
[
0
],
best_pred_asym
...
...
@@ -2100,10 +2103,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
"""
pred_ca_mask
=
torch
.
squeeze
(
pred_ca_mask
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
print
(
f
"##### line 2102 asym_mask is
{
asym_mask
}
and shape:
{
asym_mask
.
shape
}
"
)
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
anchor_true_mask
=
torch
.
index_select
(
true_ca_masks
[
anchor_gt_idx
],
1
,
anchor_gt_residue
)
print
(
f
"##### line 2104 anchor_pred_mask:
{
anchor_pred_mask
.
shape
}
and anchor_true_mask :
{
anchor_true_mask
.
shape
}
"
)
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
return
input_mask
...
...
@@ -2139,12 +2140,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
feature
,
ground_truth
=
batch
print
(
f
"###### line 2140 feature asym_id is :
{
feature
[
'asym_id'
]
}
"
)
del
batch
if
permutate_chains
:
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
ground_truth
,
feature
[
'asym_id'
])
print
(
f
"###### anchor_gt_asym:
{
anchor_gt_asym
}
and anchor_pred_asym:
{
anchor_pred_asym
}
"
)
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
ground_truth
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
...
...
@@ -2189,7 +2188,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del
true_ca_masks
,
aligned_true_ca_poses
del
pred_ca_pos
,
pred_ca_mask
gc
.
collect
()
print
(
f
"finished permutation align. Align is
{
align
}
"
)
else
:
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment