Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
0df04f33
Commit
0df04f33
authored
Feb 15, 2024
by
Geoffrey Yu
Committed by
Jennifer Wei
May 11, 2024
Browse files
fixed bugs in unittests for multi-chain permutation. now working on extra subtests
parent
17b8c142
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
15 deletions
+33
-15
tests/test_permutation.py
tests/test_permutation.py
+33
-15
No files found.
tests/test_permutation.py
View file @
0df04f33
...
@@ -56,6 +56,7 @@ class TestPermutation(unittest.TestCase):
...
@@ -56,6 +56,7 @@ class TestPermutation(unittest.TestCase):
self
.
entity_id
=
torch
.
tensor
([[
1
]
*
(
self
.
chain_a_num_res
*
2
)
+
[
2
]
*
(
self
.
chain_b_num_res
*
3
)],
self
.
entity_id
=
torch
.
tensor
([[
1
]
*
(
self
.
chain_a_num_res
*
2
)
+
[
2
]
*
(
self
.
chain_b_num_res
*
3
)],
device
=
device
)
device
=
device
)
# @unittest.skip("skip for now")
def
test_1_selecting_anchors
(
self
):
def
test_1_selecting_anchors
(
self
):
batch
=
{
batch
=
{
'asym_id'
:
self
.
asym_id
,
'asym_id'
:
self
.
asym_id
,
...
@@ -75,6 +76,7 @@ class TestPermutation(unittest.TestCase):
...
@@ -75,6 +76,7 @@ class TestPermutation(unittest.TestCase):
self
.
assertEqual
(
anchor_pred_asym
,
expected_anchors
&
anchor_pred_asym
)
self
.
assertEqual
(
anchor_pred_asym
,
expected_anchors
&
anchor_pred_asym
)
self
.
assertEqual
(
set
(),
anchor_pred_asym
&
expected_non_anchors
)
self
.
assertEqual
(
set
(),
anchor_pred_asym
&
expected_non_anchors
)
# @unittest.skip("skip for now")
def
test_2_permutation_pentamer
(
self
):
def
test_2_permutation_pentamer
(
self
):
batch
=
{
batch
=
{
'asym_id'
:
self
.
asym_id
,
'asym_id'
:
self
.
asym_id
,
...
@@ -111,25 +113,25 @@ class TestPermutation(unittest.TestCase):
...
@@ -111,25 +113,25 @@ class TestPermutation(unittest.TestCase):
batch
[
'all_atom_positions'
]
=
true_atom_position
batch
[
'all_atom_positions'
]
=
true_atom_position
batch
[
'all_atom_mask'
]
=
true_atom_mask
batch
[
'all_atom_mask'
]
=
true_atom_mask
aligns
,
_
=
compute_permutation_alignment
(
out
,
batch
,
aligns
,
per_asym_residue_index
=
compute_permutation_alignment
(
out
,
batch
,
batch
)
batch
)
possible_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)]]
possible_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)]]
wrong_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
4
),
(
3
,
2
),
(
4
,
3
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
)]]
wrong_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
4
),
(
3
,
2
),
(
4
,
3
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
)]]
self
.
assertIn
(
aligns
,
possible_outcome
)
self
.
assertIn
(
aligns
,
possible_outcome
)
self
.
assertNotIn
(
aligns
,
wrong_outcome
)
self
.
assertNotIn
(
aligns
,
wrong_outcome
)
@
unittest
.
skip
(
"Test needs to be fixed post-refactor"
)
#
@unittest.skip("Test needs to be fixed post-refactor")
def
test_3_merge_labels
(
self
):
def
test_3_merge_labels
(
self
):
nres_pad
=
325
-
57
# suppose the cropping size is 325
nres_pad
=
325
-
57
# suppose the cropping size is 325
batch
=
{
batch
=
{
'asym_id'
:
pad_features
(
self
.
asym_id
,
nres_pad
,
pad_dim
=
1
)
,
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
pad_features
(
self
.
sym_id
,
nres_pad
,
pad_dim
=
1
)
,
'sym_id'
:
self
.
sym_id
,
'entity_id'
:
pad_features
(
self
.
entity_id
,
nres_pad
,
pad_dim
=
1
),
'entity_id'
:
self
.
entity_id
,
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
32
5
)),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
5
7
)),
'seq_length'
:
torch
.
tensor
([
57
])
'seq_length'
:
torch
.
tensor
([
57
])
}
}
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
32
5
)
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
5
7
)
batch
[
"residue_index"
]
=
pad_features
(
torch
.
tensor
(
self
.
residue_index
).
reshape
(
1
,
57
),
nres_pad
,
pad_dim
=
1
)
batch
[
"residue_index"
]
=
torch
.
tensor
(
[
self
.
residue_index
]
)
# create fake ground truth atom positions
# create fake ground truth atom positions
chain_a1_pos
=
torch
.
randint
(
15
,
(
self
.
chain_a_num_res
,
3
*
37
),
chain_a1_pos
=
torch
.
randint
(
15
,
(
self
.
chain_a_num_res
,
3
*
37
),
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
...
@@ -154,13 +156,29 @@ class TestPermutation(unittest.TestCase):
...
@@ -154,13 +156,29 @@ class TestPermutation(unittest.TestCase):
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
))),
dim
=
1
)
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
))),
dim
=
1
)
batch
[
'all_atom_positions'
]
=
pad_features
(
true_atom_position
,
nres_pad
,
pad_dim
=
1
)
batch
[
'all_atom_mask'
]
=
pad_features
(
true_atom_mask
,
nres_pad
=
nres_pad
,
pad_dim
=
1
)
# tensor_to_cuda = lambda t: t.to('cuda')
batch
[
'all_atom_positions'
]
=
true_atom_position
# ground_truth = tensor_tree_map(tensor_to_cuda,ground_truth)
batch
[
'all_atom_mask'
]
=
true_atom_mask
# Below create a fake_input_features
fake_input_features
=
{
'asym_id'
:
pad_features
(
self
.
asym_id
,
nres_pad
,
pad_dim
=
1
),
'sym_id'
:
pad_features
(
self
.
sym_id
,
nres_pad
,
pad_dim
=
1
),
'entity_id'
:
pad_features
(
self
.
entity_id
,
nres_pad
,
pad_dim
=
1
),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
325
)),
'seq_length'
:
torch
.
tensor
([
57
])
}
fake_input_features
[
'asym_id'
]
=
fake_input_features
[
'asym_id'
].
reshape
(
1
,
325
)
fake_input_features
[
"residue_index"
]
=
pad_features
(
torch
.
tensor
(
self
.
residue_index
).
reshape
(
1
,
57
),
nres_pad
,
pad_dim
=
1
)
fake_input_features
[
'all_atom_positions'
]
=
pad_features
(
true_atom_position
,
nres_pad
,
pad_dim
=
1
)
fake_input_features
[
'all_atom_mask'
]
=
pad_features
(
true_atom_mask
,
nres_pad
=
nres_pad
,
pad_dim
=
1
)
# NOTE
# batch: simulates ground_truth features
# fake_input_features: simulates the data that gonna be used as input for model.forward(fake_input_features)
# out: simulates the output of model.forward(fake_input_features)
aligns
,
per_asym_residue_index
=
compute_permutation_alignment
(
out
,
aligns
,
per_asym_residue_index
=
compute_permutation_alignment
(
out
,
batch
,
fake_input_features
,
batch
)
batch
)
labels
=
split_ground_truth_labels
(
batch
)
labels
=
split_ground_truth_labels
(
batch
)
...
@@ -171,5 +189,5 @@ class TestPermutation(unittest.TestCase):
...
@@ -171,5 +189,5 @@ class TestPermutation(unittest.TestCase):
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
dim
=
1
)
expected_permutated_gt_pos
=
pad_features
(
expected_permutated_gt_pos
,
nres_pad
,
pad_dim
=
1
)
#
expected_permutated_gt_pos = pad_features(expected_permutated_gt_pos, nres_pad, pad_dim=1)
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment