Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
5fcd6ed2
Commit
5fcd6ed2
authored
Nov 02, 2023
by
Christina Floristean
Browse files
Unit test fixes for when AF2 is not installed
parent
f95d9a57
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
249 additions
and
229 deletions
+249
-229
openfold/utils/geometry/test_utils.py
openfold/utils/geometry/test_utils.py
+40
-40
tests/test_feats.py
tests/test_feats.py
+11
-10
tests/test_loss.py
tests/test_loss.py
+11
-10
tests/test_model.py
tests/test_model.py
+11
-10
tests/test_multimer_datamodule.py
tests/test_multimer_datamodule.py
+17
-5
tests/test_permutation.py
tests/test_permutation.py
+115
-114
tests/test_structure_module.py
tests/test_structure_module.py
+22
-20
tests/test_template.py
tests/test_template.py
+22
-20
No files found.
openfold/utils/geometry/test_utils.py
View file @
5fcd6ed2
...
...
@@ -14,63 +14,63 @@
"""Shared utils for tests."""
import
dataclasses
import
torch
from
alphafold.model.geometry
import
rigid_matrix_vector
from
alphafold.model.geometry
import
rotation_matrix
from
alphafold.model.geometry
import
vector
import
numpy
as
np
from
openfold.utils.geometry
import
rigid_matrix_vector
from
openfold.utils.geometry
import
rotation_matrix
from
openfold.utils.geometry
import
vector
def
assert_rotation_matrix_equal
(
matrix1
:
rotation_matrix
.
Rot3Array
,
matrix2
:
rotation_matrix
.
Rot3Array
):
for
field
in
dataclasses
.
fields
(
rotation_matrix
.
Rot3Array
):
field
=
field
.
name
np
.
testing
.
assert_array_
equal
(
assert
torch
.
equal
(
getattr
(
matrix1
,
field
),
getattr
(
matrix2
,
field
))
def
assert_rotation_matrix_close
(
mat1
:
rotation_matrix
.
Rot3Array
,
mat2
:
rotation_matrix
.
Rot3Array
):
np
.
testing
.
assert_array_almost_equal
(
mat1
.
to_
array
(),
mat2
.
to_
array
(),
6
)
assert
torch
.
allclose
(
mat1
.
to_
tensor
(),
mat2
.
to_
tensor
(),
atol
=
1e-
6
)
def
assert_array_equal_to_rotation_matrix
(
array
:
np
.
ndarray
,
def
assert_array_equal_to_rotation_matrix
(
array
:
torch
.
Tensor
,
matrix
:
rotation_matrix
.
Rot3Array
):
"""Check that array and Matrix match."""
np
.
testing
.
assert_array_
equal
(
matrix
.
xx
,
array
[...,
0
,
0
])
np
.
testing
.
assert_array_
equal
(
matrix
.
xy
,
array
[...,
0
,
1
])
np
.
testing
.
assert_array_
equal
(
matrix
.
xz
,
array
[...,
0
,
2
])
np
.
testing
.
assert_array_
equal
(
matrix
.
yx
,
array
[...,
1
,
0
])
np
.
testing
.
assert_array_
equal
(
matrix
.
yy
,
array
[...,
1
,
1
])
np
.
testing
.
assert_array_
equal
(
matrix
.
yz
,
array
[...,
1
,
2
])
np
.
testing
.
assert_array_
equal
(
matrix
.
zx
,
array
[...,
2
,
0
])
np
.
testing
.
assert_array_
equal
(
matrix
.
zy
,
array
[...,
2
,
1
])
np
.
testing
.
assert_array_
equal
(
matrix
.
zz
,
array
[...,
2
,
2
])
def
assert_array_close_to_rotation_matrix
(
array
:
np
.
ndarray
,
assert
torch
.
equal
(
matrix
.
xx
,
array
[...,
0
,
0
])
assert
torch
.
equal
(
matrix
.
xy
,
array
[...,
0
,
1
])
assert
torch
.
equal
(
matrix
.
xz
,
array
[...,
0
,
2
])
assert
torch
.
equal
(
matrix
.
yx
,
array
[...,
1
,
0
])
assert
torch
.
equal
(
matrix
.
yy
,
array
[...,
1
,
1
])
assert
torch
.
equal
(
matrix
.
yz
,
array
[...,
1
,
2
])
assert
torch
.
equal
(
matrix
.
zx
,
array
[...,
2
,
0
])
assert
torch
.
equal
(
matrix
.
zy
,
array
[...,
2
,
1
])
assert
torch
.
equal
(
matrix
.
zz
,
array
[...,
2
,
2
])
def
assert_array_close_to_rotation_matrix
(
array
:
torch
.
Tensor
,
matrix
:
rotation_matrix
.
Rot3Array
):
np
.
testing
.
assert_array_almost_equal
(
matrix
.
to_
array
(),
array
,
6
)
assert
torch
.
allclose
(
matrix
.
to_
tensor
(),
array
,
atol
=
1e-
6
)
def
assert_vectors_equal
(
vec1
:
vector
.
Vec3Array
,
vec2
:
vector
.
Vec3Array
):
np
.
testing
.
assert_array_
equal
(
vec1
.
x
,
vec2
.
x
)
np
.
testing
.
assert_array_
equal
(
vec1
.
y
,
vec2
.
y
)
np
.
testing
.
assert_array_
equal
(
vec1
.
z
,
vec2
.
z
)
assert
torch
.
equal
(
vec1
.
x
,
vec2
.
x
)
assert
torch
.
equal
(
vec1
.
y
,
vec2
.
y
)
assert
torch
.
equal
(
vec1
.
z
,
vec2
.
z
)
def
assert_vectors_close
(
vec1
:
vector
.
Vec3Array
,
vec2
:
vector
.
Vec3Array
):
np
.
testing
.
assert_
allclose
(
vec1
.
x
,
vec2
.
x
,
atol
=
1e-6
,
rtol
=
0.
)
np
.
testing
.
assert_
allclose
(
vec1
.
y
,
vec2
.
y
,
atol
=
1e-6
,
rtol
=
0.
)
np
.
testing
.
assert_
allclose
(
vec1
.
z
,
vec2
.
z
,
atol
=
1e-6
,
rtol
=
0.
)
assert
torch
.
allclose
(
vec1
.
x
,
vec2
.
x
,
atol
=
1e-6
,
rtol
=
0.
)
assert
torch
.
allclose
(
vec1
.
y
,
vec2
.
y
,
atol
=
1e-6
,
rtol
=
0.
)
assert
torch
.
allclose
(
vec1
.
z
,
vec2
.
z
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_close_to_vector
(
array
:
np
.
ndarray
,
vec
:
vector
.
Vec3Array
):
np
.
testing
.
assert_
allclose
(
vec
.
to_
array
(),
array
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_close_to_vector
(
array
:
torch
.
Tensor
,
vec
:
vector
.
Vec3Array
):
assert
torch
.
allclose
(
vec
.
to_
tensor
(),
array
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_equal_to_vector
(
array
:
np
.
ndarray
,
vec
:
vector
.
Vec3Array
):
np
.
testing
.
assert_array_
equal
(
vec
.
to_
array
(),
array
)
def
assert_array_equal_to_vector
(
array
:
torch
.
Tensor
,
vec
:
vector
.
Vec3Array
):
assert
torch
.
equal
(
vec
.
to_
tensor
(),
array
)
def
assert_rigid_equal_to_rigid
(
rigid1
:
rigid_matrix_vector
.
Rigid3Array
,
...
...
tests/test_feats.py
View file @
5fcd6ed2
...
...
@@ -45,6 +45,7 @@ if compare_utils.alphafold_is_installed():
class
TestFeats
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
...
...
tests/test_loss.py
View file @
5fcd6ed2
...
...
@@ -79,6 +79,7 @@ def affine_vector_to_rigid(am_rigid, affine):
class
TestLoss
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
...
...
tests/test_model.py
View file @
5fcd6ed2
...
...
@@ -38,6 +38,7 @@ if compare_utils.alphafold_is_installed():
class
TestModel
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
...
...
tests/test_multimer_datamodule.py
View file @
5fcd6ed2
...
...
@@ -20,7 +20,8 @@ from openfold.utils.tensor_utils import tensor_tree_map
from
openfold.config
import
model_config
from
openfold.data.data_modules
import
OpenFoldMultimerDataModule
from
openfold.model.model
import
AlphaFold
from
openfold.utils.loss
import
AlphaFoldMultimerLoss
from
openfold.utils.loss
import
AlphaFoldLoss
from
openfold.utils.multi_chain_permutation
import
multi_chain_permutation_align
from
tests.config
import
consts
import
logging
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -61,17 +62,28 @@ class TestMultimerDataModule(unittest.TestCase):
self
.
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
# deepspeed for this test
self
.
model
=
AlphaFold
(
self
.
c
)
self
.
multimer_
loss
=
AlphaFold
Multimer
Loss
(
self
.
c
.
loss
)
self
.
loss
=
AlphaFoldLoss
(
self
.
c
.
loss
)
def
testPrepareData
(
self
):
self
.
data_module
.
prepare_data
()
self
.
data_module
.
setup
()
train_dataset
=
self
.
data_module
.
train_dataset
all_chain_features
,
ground_truth
=
train_dataset
[
1
]
all_chain_features
=
train_dataset
[
1
]
add_batch_size_dimension
=
lambda
t
:
(
t
.
unsqueeze
(
0
)
)
all_chain_features
=
tensor_tree_map
(
add_batch_size_dimension
,
all_chain_features
)
all_chain_features
=
tensor_tree_map
(
add_batch_size_dimension
,
all_chain_features
)
with
torch
.
no_grad
():
ground_truth
=
all_chain_features
.
pop
(
'gt_features'
,
None
)
# Run the model
out
=
self
.
model
(
all_chain_features
)
self
.
multimer_loss
(
out
,(
all_chain_features
,
ground_truth
))
\ No newline at end of file
# Remove the recycling dimension
all_chain_features
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
all_chain_features
)
all_chain_features
=
multi_chain_permutation_align
(
out
=
out
,
features
=
all_chain_features
,
ground_truth
=
ground_truth
)
self
.
loss
(
out
,
all_chain_features
)
\ No newline at end of file
tests/test_permutation.py
View file @
5fcd6ed2
...
...
@@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
torch
import
unittest
from
openfold.utils.loss
import
AlphaFoldMultimerLoss
from
openfold.utils.loss
import
get_least_asym_entity_or_longest_length
,
merge_labels
,
pad_features
from
openfold.utils.tensor_utils
import
tensor_tree_map
import
math
from
openfold.utils.multi_chain_permutation
import
(
pad_features
,
get_least_asym_entity_or_longest_length
,
compute_permutation_alignment
,
split_ground_truth_labels
,
merge_labels
)
@
unittest
.
skip
(
"Tests need to be fixed post-refactor"
)
class
TestPermutation
(
unittest
.
TestCase
):
def
setUp
(
self
):
"""
...
...
@@ -27,144 +29,143 @@ class TestPermutation(unittest.TestCase):
and rotation matrices
"""
theta
=
math
.
pi
/
4
theta
=
math
.
pi
/
4
device
=
'cpu'
self
.
rotation_matrix_z
=
torch
.
tensor
([
[
math
.
cos
(
theta
),
-
math
.
sin
(
theta
),
0
],
[
math
.
sin
(
theta
),
math
.
cos
(
theta
),
0
],
[
0
,
0
,
1
]
],
device
=
'cuda'
)
[
math
.
cos
(
theta
),
-
math
.
sin
(
theta
),
0
],
[
math
.
sin
(
theta
),
math
.
cos
(
theta
),
0
],
[
0
,
0
,
1
]
],
device
=
device
)
self
.
rotation_matrix_x
=
torch
.
tensor
([
[
1
,
0
,
0
],
[
0
,
math
.
cos
(
theta
),
-
math
.
sin
(
theta
)],
[
0
,
math
.
sin
(
theta
),
math
.
cos
(
theta
)],
],
device
=
'cuda'
)
[
1
,
0
,
0
],
[
0
,
math
.
cos
(
theta
),
-
math
.
sin
(
theta
)],
[
0
,
math
.
sin
(
theta
),
math
.
cos
(
theta
)],
],
device
=
device
)
self
.
rotation_matrix_y
=
torch
.
tensor
([
[
math
.
cos
(
theta
),
0
,
math
.
sin
(
theta
)],
[
0
,
1
,
0
],
[
-
math
.
sin
(
theta
),
1
,
math
.
cos
(
theta
)],
],
device
=
'cuda'
)
self
.
chain_a_num_res
=
9
self
.
chain_b_num_res
=
13
[
math
.
cos
(
theta
),
0
,
math
.
sin
(
theta
)],
[
0
,
1
,
0
],
[
-
math
.
sin
(
theta
),
1
,
math
.
cos
(
theta
)],
],
device
=
device
)
self
.
chain_a_num_res
=
9
self
.
chain_b_num_res
=
13
# below create default fake ground truth structures for a hetero-pentamer A2B3
self
.
residue_index
=
list
(
range
(
self
.
chain_a_num_res
))
*
2
+
list
(
range
(
self
.
chain_b_num_res
))
*
3
self
.
num_res
=
self
.
chain_a_num_res
*
2
+
self
.
chain_b_num_res
*
3
self
.
asym_id
=
torch
.
tensor
([[
1
]
*
self
.
chain_a_num_res
+
[
2
]
*
self
.
chain_a_num_res
+
[
3
]
*
self
.
chain_b_num_res
+
[
4
]
*
self
.
chain_b_num_res
+
[
5
]
*
self
.
chain_b_num_res
],
device
=
'cuda'
)
self
.
residue_index
=
list
(
range
(
self
.
chain_a_num_res
))
*
2
+
list
(
range
(
self
.
chain_b_num_res
))
*
3
self
.
num_res
=
self
.
chain_a_num_res
*
2
+
self
.
chain_b_num_res
*
3
self
.
asym_id
=
torch
.
tensor
([[
1
]
*
self
.
chain_a_num_res
+
[
2
]
*
self
.
chain_a_num_res
+
[
3
]
*
self
.
chain_b_num_res
+
[
4
]
*
self
.
chain_b_num_res
+
[
5
]
*
self
.
chain_b_num_res
],
device
=
device
)
self
.
sym_id
=
self
.
asym_id
self
.
entity_id
=
torch
.
tensor
([[
1
]
*
(
self
.
chain_a_num_res
*
2
)
+
[
2
]
*
(
self
.
chain_b_num_res
*
3
)],
device
=
'cuda'
)
self
.
entity_id
=
torch
.
tensor
([[
1
]
*
(
self
.
chain_a_num_res
*
2
)
+
[
2
]
*
(
self
.
chain_b_num_res
*
3
)],
device
=
device
)
def
test_1_selecting_anchors
(
self
):
self
.
batch
=
{
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
self
.
sym_id
,
'entity_id'
:
self
.
entity_id
,
'seq_length'
:
torch
.
tensor
([
57
])
batch
=
{
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
self
.
sym_id
,
'entity_id'
:
self
.
entity_id
,
'seq_length'
:
torch
.
tensor
([
57
])
}
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
self
.
batch
)
self
.
assertIn
(
int
(
anchor_gt_asym
),[
1
,
2
])
self
.
assertNotIn
(
int
(
anchor_gt_asym
),[
3
,
4
,
5
])
self
.
assertIn
(
int
(
anchor_pred_asym
),[
1
,
2
])
self
.
assertNotIn
(
int
(
anchor_pred_asym
),[
3
,
4
,
5
])
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
,
batch
[
'asym_id'
]
)
self
.
assertIn
(
int
(
anchor_gt_asym
),
[
1
,
2
])
self
.
assertNotIn
(
int
(
anchor_gt_asym
),
[
3
,
4
,
5
])
self
.
assertIn
(
int
(
anchor_pred_asym
),
[
1
,
2
])
self
.
assertNotIn
(
int
(
anchor_pred_asym
),
[
3
,
4
,
5
])
def
test_2_permutation_pentamer
(
self
):
batch
=
{
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
self
.
sym_id
,
'entity_id'
:
self
.
entity_id
,
'seq_length'
:
torch
.
tensor
([
57
]),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
57
))
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
self
.
sym_id
,
'entity_id'
:
self
.
entity_id
,
'seq_length'
:
torch
.
tensor
([
57
]),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
57
))
}
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
self
.
num_res
)
batch
[
"residue_index"
]
=
torch
.
tensor
([
self
.
residue_index
]
,
device
=
'cuda'
)
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
self
.
num_res
)
batch
[
"residue_index"
]
=
torch
.
tensor
([
self
.
residue_index
])
# create fake ground truth atom positions
chain_a1_pos
=
torch
.
randint
(
15
,(
self
.
chain_a_num_res
,
3
*
37
),
device
=
'cuda'
,
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
chain_a2_pos
=
torch
.
matmul
(
chain_a1_pos
,
self
.
rotation_matrix_x
)
+
10
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
device
=
'cuda'
,
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
chain_a1_pos
=
torch
.
randint
(
15
,
(
self
.
chain_a_num_res
,
3
*
37
),
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
chain_a2_pos
=
torch
.
matmul
(
chain_a1_pos
,
self
.
rotation_matrix_x
)
+
10
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
# Below permutate predicted chain positions
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
)
,
device
=
'cuda'
)
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
))
out
=
{
'final_atom_positions'
:
pred_atom_position
,
'final_atom_mask'
:
pred_atom_mask
'final_atom_positions'
:
pred_atom_position
,
'final_atom_mask'
:
pred_atom_mask
}
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)
,
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)
,
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)
,
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)
,
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)
,
device
=
'cuda'
)),
dim
=
1
)
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
))),
dim
=
1
)
batch
[
'all_atom_positions'
]
=
true_atom_position
batch
[
'all_atom_mask'
]
=
true_atom_mask
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
aligns
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
True
)
aligns
,
_
=
compute_permutation_alignment
(
out
,
batch
,
batch
)
print
(
f
"##### aligns is
{
aligns
}
"
)
possible_outcome
=
[[(
0
,
1
),(
1
,
0
),(
2
,
3
),(
3
,
4
),(
4
,
2
)],[(
0
,
0
),(
1
,
1
),(
2
,
3
),(
3
,
4
),(
4
,
2
)]]
wrong_outcome
=
[[(
0
,
1
),(
1
,
0
),(
2
,
4
),(
3
,
2
),(
4
,
3
)],[(
0
,
0
),(
1
,
1
),(
2
,
2
),(
3
,
3
),(
4
,
4
)]]
self
.
assertIn
(
aligns
,
possible_outcome
)
self
.
assertNotIn
(
aligns
,
wrong_outcome
)
possible_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)]]
wrong_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
4
),
(
3
,
2
),
(
4
,
3
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
)]]
self
.
assertIn
(
aligns
,
possible_outcome
)
self
.
assertNotIn
(
aligns
,
wrong_outcome
)
def
test_3_merge_labels
(
self
):
nres_pad
=
325
-
57
# suppose the cropping size is 325
batch
=
{
'asym_id'
:
pad_features
(
self
.
asym_id
,
nres_pad
,
pad_dim
=
1
),
'sym_id'
:
pad_features
(
self
.
sym_id
,
nres_pad
,
pad_dim
=
1
),
'entity_id'
:
pad_features
(
self
.
entity_id
,
nres_pad
,
pad_dim
=
1
),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
325
)),
'seq_length'
:
torch
.
tensor
([
57
])
'asym_id'
:
pad_features
(
self
.
asym_id
,
nres_pad
,
pad_dim
=
1
),
'sym_id'
:
pad_features
(
self
.
sym_id
,
nres_pad
,
pad_dim
=
1
),
'entity_id'
:
pad_features
(
self
.
entity_id
,
nres_pad
,
pad_dim
=
1
),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
325
)),
'seq_length'
:
torch
.
tensor
([
57
])
}
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
325
)
batch
[
"residue_index"
]
=
pad_features
(
torch
.
tensor
(
self
.
residue_index
).
reshape
(
1
,
57
),
nres_pad
,
pad_dim
=
1
)
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
325
)
batch
[
"residue_index"
]
=
pad_features
(
torch
.
tensor
(
self
.
residue_index
).
reshape
(
1
,
57
),
nres_pad
,
pad_dim
=
1
)
# create fake ground truth atom positions
chain_a1_pos
=
torch
.
randint
(
15
,(
self
.
chain_a_num_res
,
3
*
37
),
device
=
'cuda'
,
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
chain_a2_pos
=
torch
.
matmul
(
chain_a1_pos
,
self
.
rotation_matrix_x
)
+
10
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
device
=
'cuda'
,
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
chain_a1_pos
=
torch
.
randint
(
15
,
(
self
.
chain_a_num_res
,
3
*
37
),
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
chain_a2_pos
=
torch
.
matmul
(
chain_a1_pos
,
self
.
rotation_matrix_x
)
+
10
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
# Below permutate predicted chain positions
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
)
,
device
=
'cuda'
)
pred_atom_position
=
pad_features
(
pred_atom_position
,
nres_pad
,
pad_dim
=
1
)
pred_atom_mask
=
pad_features
(
pred_atom_mask
,
nres_pad
,
pad_dim
=
1
)
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
))
pred_atom_position
=
pad_features
(
pred_atom_position
,
nres_pad
,
pad_dim
=
1
)
pred_atom_mask
=
pad_features
(
pred_atom_mask
,
nres_pad
,
pad_dim
=
1
)
out
=
{
'final_atom_positions'
:
pred_atom_position
,
'final_atom_mask'
:
pred_atom_mask
'final_atom_positions'
:
pred_atom_position
,
'final_atom_mask'
:
pred_atom_mask
}
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
),
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
),
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
),
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
),
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
),
device
=
'cuda'
)),
dim
=
1
)
batch
[
'all_atom_positions'
]
=
pad_features
(
true_atom_position
,
nres_pad
,
pad_dim
=
1
)
batch
[
'all_atom_mask'
]
=
pad_features
(
true_atom_mask
,
nres_pad
=
nres_pad
,
pad_dim
=
1
)
tensor_to_cuda
=
lambda
t
:
t
.
to
(
'cuda'
)
batch
=
tensor_tree_map
(
tensor_to_cuda
,
batch
)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
aligns
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
))),
dim
=
1
)
batch
[
'all_atom_positions'
]
=
pad_features
(
true_atom_position
,
nres_pad
,
pad_dim
=
1
)
batch
[
'all_atom_mask'
]
=
pad_features
(
true_atom_mask
,
nres_pad
=
nres_pad
,
pad_dim
=
1
)
# tensor_to_cuda = lambda t: t.to('cuda')
# ground_truth = tensor_tree_map(tensor_to_cuda,ground_truth)
aligns
,
per_asym_residue_index
=
compute_permutation_alignment
(
out
,
batch
,
dim_dict
,
permutate_chains
=
True
)
batch
)
print
(
f
"##### aligns is
{
aligns
}
"
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
i
for
i
in
batch
.
keys
()
if
i
in
dim_dict
])
labels
=
split_ground_truth_labels
(
batch
)
labels
=
merge_labels
(
labels
,
aligns
,
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
aligns
,
original_nres
=
batch
[
'aatype'
].
shape
[
-
1
])
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
expected_permutated_gt_pos
=
pad_features
(
expected_permutated_gt_pos
,
nres_pad
,
pad_dim
=
1
)
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
\ No newline at end of file
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
expected_permutated_gt_pos
=
pad_features
(
expected_permutated_gt_pos
,
nres_pad
,
pad_dim
=
1
)
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
tests/test_structure_module.py
View file @
5fcd6ed2
...
...
@@ -46,6 +46,7 @@ if compare_utils.alphafold_is_installed():
class
TestStructureModule
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
...
...
@@ -202,6 +203,7 @@ class TestStructureModule(unittest.TestCase):
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
...
...
tests/test_template.py
View file @
5fcd6ed2
...
...
@@ -56,6 +56,7 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class
TestTemplatePairStack
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
...
...
@@ -196,6 +197,7 @@ class TestTemplatePairStack(unittest.TestCase):
class
Template
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment