Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
ea7fcced
Commit
ea7fcced
authored
Sep 21, 2023
by
Geoffrey Yu
Browse files
fixed merge_labels index error. Now working on cleaning up
parent
fe01bb0c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
9 deletions
+4
-9
openfold/utils/loss.py
openfold/utils/loss.py
+4
-9
No files found.
openfold/utils/loss.py
View file @
ea7fcced
...
...
@@ -1830,7 +1830,6 @@ def get_least_asym_entity_or_longest_length(batch):
def
greedy_align
(
batch
,
per_asym_residue_index
,
unique_asym_ids
,
entity_2_asym_list
,
pred_ca_pos
,
...
...
@@ -1851,7 +1850,6 @@ def greedy_align(
best_rmsd
=
torch
.
inf
best_idx
=
None
cur_asym_list
=
entity_2_asym_list
[
int
(
cur_entity_ids
)]
cur_residue_index
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
cur_pred_pos
=
pred_ca_pos
[
asym_mask
]
cur_pred_mask
=
pred_ca_mask
[
asym_mask
]
for
next_asym_id
in
cur_asym_list
:
...
...
@@ -1859,10 +1857,7 @@ def greedy_align(
if
not
used
[
j
]:
# possible candidate
cropped_pos
=
true_ca_poses
[
j
]
cropped_pos
=
torch
.
squeeze
(
cropped_pos
,
0
)
if
not
cropped_pos
.
shape
==
cur_pred_pos
.
shape
:
# this means selected candidte is not the correct one. Skip
used
[
j
]
=
True
else
:
if
cropped_pos
.
shape
==
cur_pred_pos
.
shape
:
mask
=
true_ca_masks
[
j
]
mask
=
torch
.
squeeze
(
mask
,
0
)
print
(
f
"cropped_pos shape:
{
cropped_pos
.
shape
}
cur_pred_pos shape:
{
cur_pred_pos
.
shape
}
"
)
...
...
@@ -1871,9 +1866,11 @@ def greedy_align(
torch
.
squeeze
(
cropped_pos
,
0
),
torch
.
squeeze
(
cur_pred_pos
,
0
),
(
cur_pred_mask
*
mask
).
bool
()
)
print
(
f
"rmsd is
{
rmsd
}
"
)
if
(
rmsd
is
not
None
)
and
(
rmsd
<
best_rmsd
):
best_rmsd
=
rmsd
best_idx
=
j
print
(
f
"best_idx is
{
best_idx
}
"
)
assert
best_idx
is
not
None
used
[
best_idx
]
=
True
...
...
@@ -1906,14 +1903,13 @@ def merge_labels(per_asym_residue_index, labels, align,original_nres):
label
=
labels
[
j
][
k
]
cur_num_res
=
labels
[
j
][
'aatype'
].
shape
[
-
1
]
# to 1-based
cur_residue_index
=
per_asym_residue_index
[
i
+
1
]
if
len
(
v
.
shape
)
<=
1
or
"template"
in
k
or
"row_mask"
in
k
:
continue
else
:
dimension_to_merge
=
label
.
shape
.
index
(
cur_num_res
)
if
cur_num_res
in
label
.
shape
else
0
if
k
==
'all_atom_positions'
:
dimension_to_merge
=
1
cur_out
[
i
]
=
label
.
index_select
(
dimension_to_merge
,
cur_residue_index
)
cur_out
[
i
]
=
label
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
if
len
(
cur_out
)
>
0
:
new_v
=
torch
.
concat
(
cur_out
,
dim
=
dimension_to_merge
)
...
...
@@ -2138,7 +2134,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
gc
.
collect
()
align
=
greedy_align
(
batch
,
per_asym_residue_index
,
unique_asym_ids
,
entity_2_asym_list
,
pred_ca_pos
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment