Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
6f1329ef
Commit
6f1329ef
authored
May 10, 2024
by
Dingquan Yu
Committed by
Jennifer Wei
May 11, 2024
Browse files
Update tests and comments
parent
9a6eb649
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
99 additions
and
38 deletions
+99
-38
tests/test_permutation.py
tests/test_permutation.py
+99
-38
No files found.
tests/test_permutation.py
View file @
6f1329ef
...
...
@@ -48,7 +48,8 @@ class TestPermutation(unittest.TestCase):
self
.
chain_a_num_res
=
9
self
.
chain_b_num_res
=
13
# below create default fake ground truth structures for a hetero-pentamer A2B3
self
.
residue_index
=
list
(
range
(
self
.
chain_a_num_res
))
*
2
+
list
(
range
(
self
.
chain_b_num_res
))
*
3
self
.
residue_index
=
list
(
range
(
self
.
chain_a_num_res
))
*
2
+
list
(
range
(
self
.
chain_b_num_res
))
*
3
self
.
num_res
=
self
.
chain_a_num_res
*
2
+
self
.
chain_b_num_res
*
3
self
.
asym_id
=
torch
.
tensor
([[
1
]
*
self
.
chain_a_num_res
+
[
2
]
*
self
.
chain_a_num_res
+
[
3
]
*
self
.
chain_b_num_res
+
[
4
]
*
self
.
chain_b_num_res
+
[
5
]
*
self
.
chain_b_num_res
],
device
=
device
)
...
...
@@ -56,7 +57,6 @@ class TestPermutation(unittest.TestCase):
self
.
entity_id
=
torch
.
tensor
([[
1
]
*
(
self
.
chain_a_num_res
*
2
)
+
[
2
]
*
(
self
.
chain_b_num_res
*
3
)],
device
=
device
)
# @unittest.skip("skip for now")
def
test_1_selecting_anchors
(
self
):
batch
=
{
'asym_id'
:
self
.
asym_id
,
...
...
@@ -64,7 +64,8 @@ class TestPermutation(unittest.TestCase):
'entity_id'
:
self
.
entity_id
,
'seq_length'
:
torch
.
tensor
([
57
])
}
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
,
batch
[
'asym_id'
])
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
,
batch
[
'asym_id'
])
anchor_gt_asym
=
int
(
anchor_gt_asym
)
anchor_pred_asym
=
{
int
(
i
)
for
i
in
anchor_pred_asym
}
expected_anchors
=
{
1
,
2
}
...
...
@@ -76,8 +77,31 @@ class TestPermutation(unittest.TestCase):
self
.
assertEqual
(
anchor_pred_asym
,
expected_anchors
&
anchor_pred_asym
)
self
.
assertEqual
(
set
(),
anchor_pred_asym
&
expected_non_anchors
)
# @unittest.skip("skip for now")
def
test_2_permutation_pentamer
(
self
):
"""
Test the permutation results on a pentamer A2B3, in which protein A has 9 residues
and protein B has 13 residues.
Expected outputs:
Only protein A should be selected as an anchor thus, in the output list, either [(0,1), (1,0)] or [(0,0), (1,1)] are allowed
The 3 chains from protein B should ALWAYS be aligned in a way that predicted b1 to be aligned with ground truth b1, pred b2 to ground truth b2
as shown below:
predicted structure: a2 - a1 - b2 - b3 - b1
indexes in the predicted list: 0 1 2 3 4
ground truth structure: a1 - a2 - b1 - b2 - b3
indexes in the ground truth list: 0 1 2 3 4
then the 2 protein A chains are free to be aligned by either order, thus either [(0,1),(1,0)] or [(0,0),(1,1)] is valid.
However, the 3 protein B chains should be strictly aligned in the following order:
[(2,3), (3,4), (4,1)], regardless of how protein A chains are aligned.
Therefore, the only 2 correct permutations are :
[(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)] and
[(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]
"""
batch
=
{
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
self
.
sym_id
,
...
...
@@ -95,16 +119,22 @@ class TestPermutation(unittest.TestCase):
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
# Below permutate predicted chain positions
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
# here the b2 chain from the ground truth is deliberately put in b1 chain's position, and predicted b3 chain to b2's position
# and predicted b1 chain to b3's position
pred_atom_position
=
torch
.
cat
(
(
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
))
out
=
{
'final_atom_positions'
:
pred_atom_position
,
'final_atom_mask'
:
pred_atom_mask
}
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_position
=
torch
.
cat
(
(
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
...
...
@@ -115,12 +145,33 @@ class TestPermutation(unittest.TestCase):
aligns
,
per_asym_residue_index
=
compute_permutation_alignment
(
out
,
batch
,
batch
)
possible_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)]]
wrong_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
4
),
(
3
,
2
),
(
4
,
3
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
)]]
self
.
assertIn
(
aligns
,
possible_outcome
)
self
.
assertNotIn
(
aligns
,
wrong_outcome
)
# @unittest.skip("Test needs to be fixed post-refactor")
expected_asym_residue_index
=
{
1
:
torch
.
tensor
(
list
(
range
(
self
.
chain_a_num_res
))),
2
:
torch
.
tensor
(
list
(
range
(
self
.
chain_a_num_res
))),
3
:
torch
.
tensor
(
list
(
range
(
self
.
chain_b_num_res
))),
4
:
torch
.
tensor
(
list
(
range
(
self
.
chain_b_num_res
))),
5
:
torch
.
tensor
(
list
(
range
(
self
.
chain_b_num_res
)))
}
chain_a_permutated_chain_b_permutated
=
[
(
0
,
1
),
(
1
,
0
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)]
chain_a_not_permutated_chain_b_permutated
=
[
(
0
,
0
),
(
1
,
1
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)]
chain_a_permutated_chain_b_not_permuated
=
[
(
0
,
1
),
(
1
,
0
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
)]
chain_a_not_permutated_chain_b_not_permuated
=
[
(
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
)]
# test on the permutation alignments
self
.
assertIn
(
aligns
,
[
chain_a_permutated_chain_b_permutated
,
chain_a_not_permutated_chain_b_permutated
])
self
.
assertNotIn
(
aligns
,
[
chain_a_permutated_chain_b_not_permuated
,
chain_a_not_permutated_chain_b_not_permuated
])
# test on the per_aysm_residue_index
for
k
,
v
in
expected_asym_residue_index
.
items
():
self
.
assertTrue
(
torch
.
equal
(
v
,
per_asym_residue_index
[
k
]))
def
test_3_merge_labels
(
self
):
nres_pad
=
325
-
57
# suppose the cropping size is 325
batch
=
{
...
...
@@ -140,17 +191,21 @@ class TestPermutation(unittest.TestCase):
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
# Below permutate predicted chain positions
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_position
=
torch
.
cat
(
(
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
))
pred_atom_position
=
pad_features
(
pred_atom_position
,
nres_pad
,
pad_dim
=
1
)
pred_atom_position
=
pad_features
(
pred_atom_position
,
nres_pad
,
pad_dim
=
1
)
pred_atom_mask
=
pad_features
(
pred_atom_mask
,
nres_pad
,
pad_dim
=
1
)
out
=
{
'final_atom_positions'
:
pred_atom_position
,
'final_atom_mask'
:
pred_atom_mask
}
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_position
=
torch
.
cat
(
(
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
...
...
@@ -168,14 +223,18 @@ class TestPermutation(unittest.TestCase):
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
325
)),
'seq_length'
:
torch
.
tensor
([
57
])
}
fake_input_features
[
'asym_id'
]
=
fake_input_features
[
'asym_id'
].
reshape
(
1
,
325
)
fake_input_features
[
"residue_index"
]
=
pad_features
(
torch
.
tensor
(
self
.
residue_index
).
reshape
(
1
,
57
),
nres_pad
,
pad_dim
=
1
)
fake_input_features
[
'all_atom_positions'
]
=
pad_features
(
true_atom_position
,
nres_pad
,
pad_dim
=
1
)
fake_input_features
[
'all_atom_mask'
]
=
pad_features
(
true_atom_mask
,
nres_pad
=
nres_pad
,
pad_dim
=
1
)
fake_input_features
[
'asym_id'
]
=
fake_input_features
[
'asym_id'
].
reshape
(
1
,
325
)
fake_input_features
[
"residue_index"
]
=
pad_features
(
torch
.
tensor
(
self
.
residue_index
).
reshape
(
1
,
57
),
nres_pad
,
pad_dim
=
1
)
fake_input_features
[
'all_atom_positions'
]
=
pad_features
(
true_atom_position
,
nres_pad
,
pad_dim
=
1
)
fake_input_features
[
'all_atom_mask'
]
=
pad_features
(
true_atom_mask
,
nres_pad
=
nres_pad
,
pad_dim
=
1
)
# NOTE
# batch: simulates ground_truth features
# fake_input_features: simulates the data that
gonna
be used as input for model.forward(fake_input_features)
# fake_input_features: simulates the data that
are going
be used as input for model.forward(fake_input_features)
# out: simulates the output of model.forward(fake_input_features)
aligns
,
per_asym_residue_index
=
compute_permutation_alignment
(
out
,
fake_input_features
,
...
...
@@ -185,9 +244,11 @@ class TestPermutation(unittest.TestCase):
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
aligns
,
original_nres
=
batch
[
'aatype'
].
shape
[
-
1
])
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment