Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
bf8788c7
"vscode:/vscode.git/clone" did not exist on "dba446122164aee843ac3d7d303f09ad1f29a0f6"
Commit
bf8788c7
authored
Feb 06, 2024
by
Jennifer
Browse files
debugging for permutation test
parent
de12c0ea
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
10 deletions
+39
-10
openfold/utils/multi_chain_permutation.py
openfold/utils/multi_chain_permutation.py
+5
-5
tests/test_permutation.py
tests/test_permutation.py
+34
-5
No files found.
openfold/utils/multi_chain_permutation.py
View file @
bf8788c7
...
...
@@ -90,15 +90,15 @@ def get_optimal_transform(
def
get_least_asym_entity_or_longest_length
(
batch
,
input_asym_id
):
"""
First check how many subunit(s) one sequence has
, if there is no tie, e.g. AABBB then select
one of the A as anchor
First check how many subunit(s) one sequence has
. Select the subunit that is less
common, e.g. if the protein was AABBB then select
one of the A as anchor
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor
Args:
batch: in this funtion batch is the full ground truth features
input_asym_id: A list of aym_ids that are in the cropped input features
batch: in this fun
c
tion batch is the full ground truth features
input_asym_id: A list of a
s
ym_ids that are in the cropped input features
Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
...
...
@@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
min_asym_count
=
min
(
entity_asym_count
.
values
())
least_asym_entities
=
[
entity
for
entity
,
count
in
entity_asym_count
.
items
()
if
count
==
min_asym_count
]
# If multiple entities have the least asym_id count, return those with the
short
est length
# If multiple entities have the least asym_id count, return those with the
long
est length
if
len
(
least_asym_entities
)
>
1
:
max_length
=
max
([
entity_length
[
entity
]
for
entity
in
least_asym_entities
])
least_asym_entities
=
[
entity
for
entity
in
least_asym_entities
if
entity_length
[
entity
]
==
max_length
]
...
...
tests/test_permutation.py
View file @
bf8788c7
...
...
@@ -21,7 +21,7 @@ from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym
merge_labels
)
@
unittest
.
skip
(
"Tests need to be fixed post-refactor"
)
#
@unittest.skip("Tests need to be fixed post-refactor")
class
TestPermutation
(
unittest
.
TestCase
):
def
setUp
(
self
):
"""
...
...
@@ -65,10 +65,39 @@ class TestPermutation(unittest.TestCase):
'seq_length'
:
torch
.
tensor
([
57
])
}
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
,
batch
[
'asym_id'
])
self
.
assertIn
(
int
(
anchor_gt_asym
),
[
1
,
2
])
self
.
assertNotIn
(
int
(
anchor_gt_asym
),
[
3
,
4
,
5
])
self
.
assertIn
(
int
(
anchor_pred_asym
),
[
1
,
2
])
self
.
assertNotIn
(
int
(
anchor_pred_asym
),
[
3
,
4
,
5
])
anchor_gt_asym
=
int
(
anchor_gt_asym
)
anchor_pred_asym
=
{
int
(
i
)
for
i
in
anchor_pred_asym
}
expected_anchors
=
{
1
,
2
}
expected_non_anchors
=
{
3
,
4
,
5
}
self
.
assertIn
(
anchor_gt_asym
,
expected_anchors
)
self
.
assertNotIn
(
anchor_gt_asym
,
expected_non_anchors
)
# Check that predicted anchors are within expected anchor set
self
.
assertEqual
(
anchor_pred_asym
,
expected_anchors
&
anchor_pred_asym
)
self
.
assertEqual
(
set
(),
anchor_pred_asym
&
expected_non_anchors
)
def
test_1_selecting_anchors_with_padding
(
self
):
# This test fails because it's looking for 0 as the
nres_pad
=
325
-
57
# suppose the cropping size is 325
batch
=
{
'asym_id'
:
pad_features
(
self
.
asym_id
,
nres_pad
,
pad_dim
=
1
),
'sym_id'
:
pad_features
(
self
.
sym_id
,
nres_pad
,
pad_dim
=
1
),
'entity_id'
:
pad_features
(
self
.
entity_id
,
nres_pad
,
pad_dim
=
1
),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
325
)),
'seq_length'
:
torch
.
tensor
([
57
])
}
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
,
batch
[
'asym_id'
])
anchor_gt_asym
=
int
(
anchor_gt_asym
)
anchor_pred_asym
=
{
int
(
i
)
for
i
in
anchor_pred_asym
}
expected_anchors
=
{
1
,
2
}
expected_non_anchors
=
{
3
,
4
,
5
}
self
.
assertIn
(
anchor_gt_asym
,
expected_anchors
)
self
.
assertNotIn
(
anchor_gt_asym
,
expected_non_anchors
)
# Check that predicted anchors are within expected anchor set
self
.
assertEqual
(
anchor_pred_asym
,
expected_anchors
&
anchor_pred_asym
)
self
.
assertEqual
(
set
(),
anchor_pred_asym
&
expected_non_anchors
)
def
test_2_permutation_pentamer
(
self
):
batch
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment