Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
f563944a
Commit
f563944a
authored
Jun 22, 2023
by
Geoffrey Yu
Browse files
fixed get_anchor_candidates error
parent
c6ac105d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
15 deletions
+16
-15
tests/unifold_permutation.py
tests/unifold_permutation.py
+16
-15
No files found.
tests/unifold_permutation.py
View file @
f563944a
...
...
@@ -118,12 +118,14 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
per_asym_residue_index
=
{}
for
cur_asym_id
in
unique_asym_ids
:
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
asym_mask
=
asym_mask
[:,
0
]
# somehow need to adjust the
asym_mask
shape
print
(
f
"line 121 asym_mask is
{
asym_mask
}
"
)
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
batch
[
"residue_index"
][
asym_mask
]
anchor_gt_asym
,
anchor_pred_asym
=
get_anchor_candidates
(
batch
,
per_asym_residue_index
,
true_ca_masks
)
print
(
f
"anchor_gt_asym is
{
anchor_gt_asym
}
, anchor_pred_asym is
{
anchor_pred_asym
}
"
)
import
sys
sys
.
exit
()
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
best_rmsd
=
1e9
...
...
@@ -200,20 +202,18 @@ def get_anchor_candidates(batch, per_asym_residue_index, true_masks):
def
find_by_num_sym
(
min_num_sym
):
best_len
=
-
1
best_gt_asym
=
None
asym_ids
=
batch
[
"asym_id"
][
batch
[
"num_sym"
]
==
min_num_sym
]
asym_ids
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"num_sym"
]
==
min_num_sym
])
for
cur_asym_id
in
asym_ids
:
assert
cur_asym_id
>
0
cur_residue_index
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
j
=
int
(
cur_asym_id
-
2
)
# somehow have to change from -1 to -2
j
=
int
(
cur_asym_id
-
1
)
cur_true_mask
=
true_masks
[
j
][
cur_residue_index
]
cur_len
=
cur_true_mask
.
s
um
()
cur_len
=
cur_true_mask
.
s
hape
[
0
]
if
cur_len
>
best_len
:
best_len
=
cur_len
best_gt_asym
=
cur_asym_id
print
(
f
"finished selected the best anchor
\n
best_gt_asym is
{
best_gt_asym
}
and best_len is
{
best_len
}
"
)
return
best_gt_asym
,
best_len
sorted_num_sym
=
batch
[
"num_sym"
][
batch
[
"num_sym"
]
>
0
].
sort
()[
0
]
best_gt_asym
=
None
best_len
=
-
1
...
...
@@ -227,7 +227,6 @@ def get_anchor_candidates(batch, per_asym_residue_index, true_masks):
if
best_len
>=
3
:
break
best_entity
=
batch
[
"entity_id"
][
batch
[
"asym_id"
]
==
best_gt_asym
][
0
]
print
(
f
"best_entity is
{
best_entity
}
\n
"
)
best_pred_asym
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
best_entity
])
return
best_gt_asym
,
best_pred_asym
...
...
@@ -258,13 +257,13 @@ def greedy_align(
used
[
i
]
=
True
continue
cur_entity_ids
=
batch
[
"entity_id"
][
asym_mask
][
0
]
best_rmsd
=
1e
10
best_rmsd
=
1e
20
best_idx
=
None
cur_asym_list
=
entity_2_asym_list
[
int
(
cur_entity_ids
)]
cur_residue_index
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
cur_pred_pos
=
pred_ca_pos
[
asym_mask
[:,
0
],:]
# only need the first 1 column of asym_mask
print
(
f
"line 266 cur_pred_pos shape:
{
cur_pred_pos
.
shape
}
and pred_ca_pos is
{
pred_ca_pos
.
shape
}
"
)
print
(
f
"line 266
pred_ca_pos shape:
{
pred_ca_pos
.
shape
}
cur_pred_pos shape:
{
cur_pred_pos
.
shape
}
and pred_ca_pos is
{
pred_ca_pos
.
shape
}
"
)
cur_pred_mask
=
pred_ca_mask
[
asym_mask
[:,
0
]]
# only need the first column of asym_mask
for
next_asym_id
in
cur_asym_list
:
if
next_asym_id
==
0
:
...
...
@@ -273,19 +272,21 @@ def greedy_align(
if
not
used
[
j
]:
# possible candidate
print
(
f
"line 265 curr_residue_index is
{
cur_residue_index
}
and j is
{
j
}
"
)
print
(
f
"true_ca_poses shape:
{
true_ca_poses
[
j
].
shape
}
"
)
cropped_pos
=
true_ca_poses
[
j
][
cur_residue_index
]
mask
=
true_ca_masks
[
j
][
cur_residue_index
]
cropped_pos
=
true_ca_poses
[
j
]
mask
=
true_ca_masks
[
j
][
cur_residue_index
[:,
0
]]
print
(
f
"line 278 cur_pred_mask shape:
{
cur_pred_mask
.
shape
}
\n
mask shape:
{
mask
.
shape
}
"
)
print
(
f
"cropped_pos shape
{
cropped_pos
.
shape
}
and cur_pred_pos shape
{
cur_pred_pos
.
shape
}
"
)
rmsd
=
compute_rmsd
(
cropped_pos
,
cur_pred_pos
,
(
cur_pred_mask
*
mask
).
bool
()
cropped_pos
,
cur_pred_pos
,
(
cur_pred_mask
.
to
(
'cpu'
)
*
mask
.
to
(
'cpu'
)
).
bool
()
)
if
rmsd
<
best_rmsd
:
best_rmsd
=
rmsd
best_idx
=
j
print
(
f
"rmds is now
{
rmsd
}
and best_idx is
{
best_idx
}
"
)
assert
best_idx
is
not
None
used
[
best_idx
]
=
True
align
.
append
((
i
,
best_idx
))
print
(
f
"align is
{
align
}
"
)
return
align
...
...
@@ -319,4 +320,4 @@ def merge_labels(batch, per_asym_residue_index, labels, align):
pad_v
=
new_v
.
new_zeros
((
num_res
-
merged_nres
,
*
pad_dim
))
new_v
=
torch
.
concat
((
new_v
,
pad_v
),
dim
=
0
)
outs
[
k
]
=
new_v
return
outs
return
outs
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment