Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
d74b09cc
Commit
d74b09cc
authored
Feb 15, 2024
by
Geoffrey Yu
Browse files
fixed bugs in unittests for multi-chain permutation. now working on extra subtests
parent
df96b586
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
17 deletions
+33
-17
tests/test_permutation.py
tests/test_permutation.py
+33
-17
No files found.
tests/test_permutation.py
View file @
d74b09cc
...
...
@@ -56,6 +56,7 @@ class TestPermutation(unittest.TestCase):
self
.
entity_id
=
torch
.
tensor
([[
1
]
*
(
self
.
chain_a_num_res
*
2
)
+
[
2
]
*
(
self
.
chain_b_num_res
*
3
)],
device
=
device
)
# @unittest.skip("skip for now")
def
test_1_selecting_anchors
(
self
):
batch
=
{
'asym_id'
:
self
.
asym_id
,
...
...
@@ -75,6 +76,7 @@ class TestPermutation(unittest.TestCase):
self
.
assertEqual
(
anchor_pred_asym
,
expected_anchors
&
anchor_pred_asym
)
self
.
assertEqual
(
set
(),
anchor_pred_asym
&
expected_non_anchors
)
# @unittest.skip("skip for now")
def
test_2_permutation_pentamer
(
self
):
batch
=
{
'asym_id'
:
self
.
asym_id
,
...
...
@@ -111,26 +113,25 @@ class TestPermutation(unittest.TestCase):
batch
[
'all_atom_positions'
]
=
true_atom_position
batch
[
'all_atom_mask'
]
=
true_atom_mask
aligns
,
_
=
compute_permutation_alignment
(
out
,
batch
,
aligns
,
per_asym_residue_index
=
compute_permutation_alignment
(
out
,
batch
,
batch
)
print
(
f
"##### aligns is
{
aligns
}
"
)
possible_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)]]
wrong_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
4
),
(
3
,
2
),
(
4
,
3
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
)]]
self
.
assertIn
(
aligns
,
possible_outcome
)
self
.
assertNotIn
(
aligns
,
wrong_outcome
)
@
unittest
.
skip
(
"Test needs to be fixed post-refactor"
)
#
@unittest.skip("Test needs to be fixed post-refactor")
def
test_3_merge_labels
(
self
):
nres_pad
=
325
-
57
# suppose the cropping size is 325
batch
=
{
'asym_id'
:
pad_features
(
self
.
asym_id
,
nres_pad
,
pad_dim
=
1
)
,
'sym_id'
:
pad_features
(
self
.
sym_id
,
nres_pad
,
pad_dim
=
1
)
,
'entity_id'
:
pad_features
(
self
.
entity_id
,
nres_pad
,
pad_dim
=
1
),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
32
5
)),
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
self
.
sym_id
,
'entity_id'
:
self
.
entity_id
,
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
5
7
)),
'seq_length'
:
torch
.
tensor
([
57
])
}
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
32
5
)
batch
[
"residue_index"
]
=
pad_features
(
torch
.
tensor
(
self
.
residue_index
).
reshape
(
1
,
57
),
nres_pad
,
pad_dim
=
1
)
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
5
7
)
batch
[
"residue_index"
]
=
torch
.
tensor
(
[
self
.
residue_index
]
)
# create fake ground truth atom positions
chain_a1_pos
=
torch
.
randint
(
15
,
(
self
.
chain_a_num_res
,
3
*
37
),
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
...
...
@@ -155,15 +156,30 @@ class TestPermutation(unittest.TestCase):
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
))),
dim
=
1
)
batch
[
'all_atom_positions'
]
=
pad_features
(
true_atom_position
,
nres_pad
,
pad_dim
=
1
)
batch
[
'all_atom_mask'
]
=
pad_features
(
true_atom_mask
,
nres_pad
=
nres_pad
,
pad_dim
=
1
)
# tensor_to_cuda = lambda t: t.to('cuda')
# ground_truth = tensor_tree_map(tensor_to_cuda,ground_truth)
batch
[
'all_atom_positions'
]
=
true_atom_position
batch
[
'all_atom_mask'
]
=
true_atom_mask
# Below create a fake_input_features
fake_input_features
=
{
'asym_id'
:
pad_features
(
self
.
asym_id
,
nres_pad
,
pad_dim
=
1
),
'sym_id'
:
pad_features
(
self
.
sym_id
,
nres_pad
,
pad_dim
=
1
),
'entity_id'
:
pad_features
(
self
.
entity_id
,
nres_pad
,
pad_dim
=
1
),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
325
)),
'seq_length'
:
torch
.
tensor
([
57
])
}
fake_input_features
[
'asym_id'
]
=
fake_input_features
[
'asym_id'
].
reshape
(
1
,
325
)
fake_input_features
[
"residue_index"
]
=
pad_features
(
torch
.
tensor
(
self
.
residue_index
).
reshape
(
1
,
57
),
nres_pad
,
pad_dim
=
1
)
fake_input_features
[
'all_atom_positions'
]
=
pad_features
(
true_atom_position
,
nres_pad
,
pad_dim
=
1
)
fake_input_features
[
'all_atom_mask'
]
=
pad_features
(
true_atom_mask
,
nres_pad
=
nres_pad
,
pad_dim
=
1
)
# NOTE
# batch: simulates ground_truth features
# fake_input_features: simulates the data that gonna be used as input for model.forward(fake_input_features)
# out: simulates the output of model.forward(fake_input_features)
aligns
,
per_asym_residue_index
=
compute_permutation_alignment
(
out
,
batch
,
fake_input_features
,
batch
)
print
(
f
"##### aligns is
{
aligns
}
"
)
labels
=
split_ground_truth_labels
(
batch
)
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
aligns
,
...
...
@@ -173,5 +189,5 @@ class TestPermutation(unittest.TestCase):
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
expected_permutated_gt_pos
=
pad_features
(
expected_permutated_gt_pos
,
nres_pad
,
pad_dim
=
1
)
#
expected_permutated_gt_pos = pad_features(expected_permutated_gt_pos, nres_pad, pad_dim=1)
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment