OpenDAS / OpenFold · Commits · 5f782370

Commit 5f782370, authored May 10, 2024 by Dingquan Yu
Parent: 61191bff

    Update tests and comments

Showing 1 changed file with 99 additions and 38 deletions:
tests/test_permutation.py (+99 / -38)
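The diff below updates tests for OpenFold's multi-chain permutation alignment: when a complex contains several copies of the same protein, the predicted chains can be matched to the ground-truth chains in any copy-preserving order, and the loss should use the best such matching. As a minimal, hypothetical sketch of that idea (a brute-force stand-in, not the OpenFold implementation, assuming all chains passed in belong to one entity and share one shape):

    from itertools import permutations
    import torch

    def best_chain_permutation(pred_chains, gt_chains):
        # pred_chains, gt_chains: lists of [n_res, 3] tensors of identical shape
        best, best_err = None, float('inf')
        for perm in permutations(range(len(gt_chains))):
            err = sum(torch.mean((pred_chains[i] - gt_chains[j]) ** 2).item()
                      for i, j in enumerate(perm))
            if err < best_err:
                best, best_err = list(enumerate(perm)), err
        return best  # list of (pred_index, gt_index) pairs, like `aligns` in the tests

The real code appears to avoid the factorial search by first selecting an unambiguous anchor chain, which is what test_1_selecting_anchors below exercises.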
@@ -48,15 +48,15 @@ class TestPermutation(unittest.TestCase):
        self.chain_a_num_res = 9
        self.chain_b_num_res = 13
        # below create default fake ground truth structures for a hetero-pentamer A2B3
        self.residue_index = list(range(self.chain_a_num_res)) * 2 + list(range(self.chain_b_num_res)) * 3
        self.num_res = self.chain_a_num_res * 2 + self.chain_b_num_res * 3
        self.asym_id = torch.tensor(
            [[1] * self.chain_a_num_res + [2] * self.chain_a_num_res +
             [3] * self.chain_b_num_res + [4] * self.chain_b_num_res + [5] * self.chain_b_num_res],
            device=device)
        self.sym_id = self.asym_id
        self.entity_id = torch.tensor(
            [[1] * (self.chain_a_num_res * 2) + [2] * (self.chain_b_num_res * 3)],
            device=device)

    # @unittest.skip("skip for now")
    def test_1_selecting_anchors(self):
        batch = {
            'asym_id': self.asym_id,
@@ -64,20 +64,44 @@ class TestPermutation(unittest.TestCase):
            'entity_id': self.entity_id,
            'seq_length': torch.tensor([57])
        }
        anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(
            batch, batch['asym_id'])
        anchor_gt_asym = int(anchor_gt_asym)
        anchor_pred_asym = {int(i) for i in anchor_pred_asym}
        expected_anchors = {1, 2}
        expected_non_anchors = {3, 4, 5}
        self.assertIn(anchor_gt_asym, expected_anchors)
        self.assertNotIn(anchor_gt_asym, expected_non_anchors)
        # Check that predicted anchors are within expected anchor set
        self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym)
        self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)

    # @unittest.skip("skip for now")
    def test_2_permutation_pentamer(self):
        """
        Test the permutation results on a pentamer A2B3, in which protein A has 9 residues
        and protein B has 13 residues.

        Expected outputs:
        Only protein A should be selected as an anchor; thus, for the two A chains, either
        [(0, 1), (1, 0)] or [(0, 0), (1, 1)] is allowed in the output list.
        The 3 chains from protein B should ALWAYS be aligned so that predicted b1 is aligned
        with ground truth b1, predicted b2 with ground truth b2, and so on, as shown below:

        predicted structure:               a2 - a1 - b2 - b3 - b1
        indexes in the predicted list:      0    1    2    3    4
        ground truth structure:            a1 - a2 - b1 - b2 - b3
        indexes in the ground truth list:   0    1    2    3    4

        The 2 protein A chains are free to be aligned in either order, so both
        [(0, 1), (1, 0)] and [(0, 0), (1, 1)] are valid. However, the 3 protein B chains
        must be strictly aligned as [(2, 3), (3, 4), (4, 2)], regardless of how the
        protein A chains are aligned. Therefore, the only 2 correct permutations are:
        [(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)] and
        [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]
        """
        batch = {
            'asym_id': self.asym_id,
            'sym_id': self.sym_id,
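The anchor expectation in test_1_selecting_anchors above follows from the A2B3 stoichiometry: protein A (entity 1, asym ids 1 and 2) has fewer copies than protein B, so it is the less ambiguous anchor. A rough, hypothetical sketch of such a selection heuristic (not the OpenFold source; the real helper's name suggests chain length is used as a tie-break, which this sketch omits):

    from collections import Counter

    def pick_anchor_candidates(asym_id, entity_id):
        # Map each chain (asym id) to its entity via the per-residue id tensors
        entity_of_asym = dict(zip(asym_id.flatten().tolist(),
                                  entity_id.flatten().tolist()))
        copies = Counter(entity_of_asym.values())   # entity -> number of chains
        least = min(copies, key=copies.get)         # entity with fewest copies
        return {a for a, e in entity_of_asym.items() if e == least}

For the fixtures above this returns {1, 2}, matching expected_anchors.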
@@ -87,7 +111,7 @@ class TestPermutation(unittest.TestCase):
        }
        batch['asym_id'] = batch['asym_id'].reshape(1, self.num_res)
        batch["residue_index"] = torch.tensor([self.residue_index])
        # create fake ground truth atom positions
        chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37),
                                     dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3)
        chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10
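Here chain_a2_pos is a rigid transform of chain_a1_pos: a rotation (via self.rotation_matrix_x, defined in setUp but not shown in this diff) plus a translation of 10, so the two A chains share identical internal geometry and differ only in pose. A minimal sketch of what such 3x3 rotation matrices presumably look like (the actual angles used in setUp are an assumption here):

    import math
    import torch

    def rotation_matrices(theta=math.pi / 4):
        c, s = math.cos(theta), math.sin(theta)
        rx = torch.tensor([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])   # about x
        ry = torch.tensor([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])   # about y
        rz = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])   # about z
        return rx, ry, rz

Because the positions have shape (1, n_res, 37, 3), torch.matmul broadcasts the (3, 3) matrix over the leading dimensions and rotates each atom's coordinate row vector.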
@@ -95,16 +119,22 @@ class TestPermutation(unittest.TestCase):
        chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37),
                                     dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3)
        chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10
        chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos, self.rotation_matrix_z),
                                    self.rotation_matrix_x) + 30
        # Below permutate predicted chain positions:
        # the ground-truth b2 chain is deliberately put in the b1 chain's position,
        # the predicted b3 chain in b2's position, and the predicted b1 chain in b3's position
        pred_atom_position = torch.cat(
            (chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1)
        pred_atom_mask = torch.ones((1, self.num_res, 37))
        out = {
            'final_atom_positions': pred_atom_position,
            'final_atom_mask': pred_atom_mask
        }
        true_atom_position = torch.cat(
            (chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1)
        true_atom_mask = torch.cat((torch.ones((1, self.chain_a_num_res, 37)),
                                    torch.ones((1, self.chain_a_num_res, 37)),
                                    torch.ones((1, self.chain_b_num_res, 37)),
                                    torch.ones((1, self.chain_b_num_res, 37)),
                                    torch.ones((1, self.chain_b_num_res, 37))), dim=1)
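Given the chain orders above, the expected alignment pairs asserted in the next hunk can be read off directly: each predicted slot must map to the ground-truth slot holding the same coordinates. A small illustrative check (plain Python, not part of the test file):

    pred_order = ['a2', 'a1', 'b2', 'b3', 'b1']
    gt_order = ['a1', 'a2', 'b1', 'b2', 'b3']
    b_alignment = [(i, gt_order.index(name))
                   for i, name in enumerate(pred_order) if name.startswith('b')]
    assert b_alignment == [(2, 3), (3, 4), (4, 2)]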
@@ -114,13 +144,34 @@ class TestPermutation(unittest.TestCase):
...
@@ -114,13 +144,34 @@ class TestPermutation(unittest.TestCase):
batch
[
'all_atom_mask'
]
=
true_atom_mask
batch
[
'all_atom_mask'
]
=
true_atom_mask
aligns
,
per_asym_residue_index
=
compute_permutation_alignment
(
out
,
batch
,
aligns
,
per_asym_residue_index
=
compute_permutation_alignment
(
out
,
batch
,
batch
)
batch
)
possible_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)]]
wrong_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
4
),
(
3
,
2
),
(
4
,
3
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
)]]
expected_asym_residue_index
=
{
self
.
assertIn
(
aligns
,
possible_outcome
)
1
:
torch
.
tensor
(
list
(
range
(
self
.
chain_a_num_res
))),
self
.
assertNotIn
(
aligns
,
wrong_outcome
)
2
:
torch
.
tensor
(
list
(
range
(
self
.
chain_a_num_res
))),
3
:
torch
.
tensor
(
list
(
range
(
self
.
chain_b_num_res
))),
4
:
torch
.
tensor
(
list
(
range
(
self
.
chain_b_num_res
))),
5
:
torch
.
tensor
(
list
(
range
(
self
.
chain_b_num_res
)))
}
chain_a_permutated_chain_b_permutated
=
[
(
0
,
1
),
(
1
,
0
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)]
chain_a_not_permutated_chain_b_permutated
=
[
(
0
,
0
),
(
1
,
1
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)]
chain_a_permutated_chain_b_not_permuated
=
[
(
0
,
1
),
(
1
,
0
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
)]
chain_a_not_permutated_chain_b_not_permuated
=
[
(
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
)]
# test on the permutation alignments
self
.
assertIn
(
aligns
,
[
chain_a_permutated_chain_b_permutated
,
chain_a_not_permutated_chain_b_permutated
])
self
.
assertNotIn
(
aligns
,
[
chain_a_permutated_chain_b_not_permuated
,
chain_a_not_permutated_chain_b_not_permuated
])
# test on the per_aysm_residue_index
for
k
,
v
in
expected_asym_residue_index
.
items
():
self
.
assertTrue
(
torch
.
equal
(
v
,
per_asym_residue_index
[
k
]))
# @unittest.skip("Test needs to be fixed post-refactor")
def
test_3_merge_labels
(
self
):
def
test_3_merge_labels
(
self
):
nres_pad
=
325
-
57
# suppose the cropping size is 325
nres_pad
=
325
-
57
# suppose the cropping size is 325
batch
=
{
batch
=
{
@@ -132,7 +183,7 @@ class TestPermutation(unittest.TestCase):
        }
        batch['asym_id'] = batch['asym_id'].reshape(1, 57)
        batch["residue_index"] = torch.tensor([self.residue_index])
        # create fake ground truth atom positions
        chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37),
                                     dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3)
        chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10
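The hunk below pads every feature up to the crop size of 325 residues with pad_features. The helper itself is not shown in this diff; a guessed minimal implementation consistent with these call sites (zero-padding on the right along pad_dim) would be:

    import torch

    def pad_features(tensor, nres_pad, pad_dim):
        # Append nres_pad zero entries along pad_dim, e.g. (1, 57, 37, 3) -> (1, 325, 37, 3)
        pad_shape = list(tensor.shape)
        pad_shape[pad_dim] = nres_pad
        return torch.cat((tensor, tensor.new_zeros(pad_shape)), dim=pad_dim)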
@@ -140,42 +191,50 @@ class TestPermutation(unittest.TestCase):
        chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37),
                                     dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3)
        chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10
        chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos, self.rotation_matrix_z),
                                    self.rotation_matrix_x) + 30
        # Below permutate predicted chain positions
        pred_atom_position = torch.cat(
            (chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1)
        pred_atom_mask = torch.ones((1, self.num_res, 37))
        pred_atom_position = pad_features(pred_atom_position, nres_pad, pad_dim=1)
        pred_atom_mask = pad_features(pred_atom_mask, nres_pad, pad_dim=1)
        out = {
            'final_atom_positions': pred_atom_position,
            'final_atom_mask': pred_atom_mask
        }
        true_atom_position = torch.cat(
            (chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1)
        true_atom_mask = torch.cat((torch.ones((1, self.chain_a_num_res, 37)),
                                    torch.ones((1, self.chain_a_num_res, 37)),
                                    torch.ones((1, self.chain_b_num_res, 37)),
                                    torch.ones((1, self.chain_b_num_res, 37)),
                                    torch.ones((1, self.chain_b_num_res, 37))), dim=1)
        batch['all_atom_positions'] = true_atom_position
        batch['all_atom_mask'] = true_atom_mask
        # Below create a fake_input_features
        fake_input_features = {
            'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1),
            'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1),
            'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1),
            'aatype': torch.randint(21, size=(1, 325)),
            'seq_length': torch.tensor([57])
        }
        fake_input_features['asym_id'] = fake_input_features['asym_id'].reshape(1, 325)
        fake_input_features["residue_index"] = pad_features(
            torch.tensor(self.residue_index).reshape(1, 57), nres_pad, pad_dim=1)
        fake_input_features['all_atom_positions'] = pad_features(true_atom_position, nres_pad, pad_dim=1)
        fake_input_features['all_atom_mask'] = pad_features(true_atom_mask, nres_pad=nres_pad, pad_dim=1)
        # NOTE
        # batch: simulates ground_truth features
        # fake_input_features: simulates the data that are going to be used as input for model.forward(fake_input_features)
        # out: simulates the output of model.forward(fake_input_features)
        aligns, per_asym_residue_index = compute_permutation_alignment(out,
                                                                       fake_input_features,
@@ -185,9 +244,11 @@ class TestPermutation(unittest.TestCase):
...
@@ -185,9 +244,11 @@ class TestPermutation(unittest.TestCase):
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
aligns
,
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
aligns
,
original_nres
=
batch
[
'aatype'
].
shape
[
-
1
])
original_nres
=
batch
[
'aatype'
].
shape
[
-
1
])
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
dim
=
1
)
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
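The final assertions encode what merge_labels is for: reordering the ground-truth labels according to aligns so that they line up with the prediction's chain order. A schematic sketch, assuming a hypothetical per-chain label layout (the real function also handles residue indexing and padding):

    import torch

    def merge_labels_sketch(per_chain_labels, aligns):
        # per_chain_labels: [1, n_res_i, 37, 3] tensors in ground-truth order;
        # aligns: (pred_idx, gt_idx) pairs as produced by compute_permutation_alignment
        ordered = [per_chain_labels[gt_idx] for _, gt_idx in sorted(aligns)]
        return torch.cat(ordered, dim=1)

With aligns = [(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)] and ground-truth chains (a1, a2, b1, b2, b3), this yields (a2, a1, b2, b3, b1), i.e. the expected_permutated_gt_pos asserted above.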