Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
5fcd6ed2
"...dynamo-run/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "f6d03f2f81f50d6a17bc58e02100b179cb1fb18f"
Commit
5fcd6ed2
authored
Nov 02, 2023
by
Christina Floristean
Browse files
Unit test fixes for when AF2 is not installed
parent
f95d9a57
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
249 additions
and
229 deletions
+249
-229
openfold/utils/geometry/test_utils.py
openfold/utils/geometry/test_utils.py
+40
-40
tests/test_feats.py
tests/test_feats.py
+11
-10
tests/test_loss.py
tests/test_loss.py
+11
-10
tests/test_model.py
tests/test_model.py
+11
-10
tests/test_multimer_datamodule.py
tests/test_multimer_datamodule.py
+17
-5
tests/test_permutation.py
tests/test_permutation.py
+115
-114
tests/test_structure_module.py
tests/test_structure_module.py
+22
-20
tests/test_template.py
tests/test_template.py
+22
-20
No files found.
openfold/utils/geometry/test_utils.py
View file @
5fcd6ed2
...
...
@@ -14,84 +14,84 @@
"""Shared utils for tests."""
import
dataclasses
import
torch
from
alphafold.model.geometry
import
rigid_matrix_vector
from
alphafold.model.geometry
import
rotation_matrix
from
alphafold.model.geometry
import
vector
import
numpy
as
np
from
openfold.utils.geometry
import
rigid_matrix_vector
from
openfold.utils.geometry
import
rotation_matrix
from
openfold.utils.geometry
import
vector
def
assert_rotation_matrix_equal
(
matrix1
:
rotation_matrix
.
Rot3Array
,
matrix2
:
rotation_matrix
.
Rot3Array
):
for
field
in
dataclasses
.
fields
(
rotation_matrix
.
Rot3Array
):
field
=
field
.
name
np
.
testing
.
assert_array_
equal
(
getattr
(
matrix1
,
field
),
getattr
(
matrix2
,
field
))
for
field
in
dataclasses
.
fields
(
rotation_matrix
.
Rot3Array
):
field
=
field
.
name
assert
torch
.
equal
(
getattr
(
matrix1
,
field
),
getattr
(
matrix2
,
field
))
def
assert_rotation_matrix_close
(
mat1
:
rotation_matrix
.
Rot3Array
,
mat2
:
rotation_matrix
.
Rot3Array
):
np
.
testing
.
assert_array_almost_equal
(
mat1
.
to_
array
(),
mat2
.
to_
array
(),
6
)
assert
torch
.
allclose
(
mat1
.
to_
tensor
(),
mat2
.
to_
tensor
(),
atol
=
1e-
6
)
def
assert_array_equal_to_rotation_matrix
(
array
:
np
.
ndarray
,
def
assert_array_equal_to_rotation_matrix
(
array
:
torch
.
Tensor
,
matrix
:
rotation_matrix
.
Rot3Array
):
"""Check that array and Matrix match."""
np
.
testing
.
assert_array_
equal
(
matrix
.
xx
,
array
[...,
0
,
0
])
np
.
testing
.
assert_array_
equal
(
matrix
.
xy
,
array
[...,
0
,
1
])
np
.
testing
.
assert_array_
equal
(
matrix
.
xz
,
array
[...,
0
,
2
])
np
.
testing
.
assert_array_
equal
(
matrix
.
yx
,
array
[...,
1
,
0
])
np
.
testing
.
assert_array_
equal
(
matrix
.
yy
,
array
[...,
1
,
1
])
np
.
testing
.
assert_array_
equal
(
matrix
.
yz
,
array
[...,
1
,
2
])
np
.
testing
.
assert_array_
equal
(
matrix
.
zx
,
array
[...,
2
,
0
])
np
.
testing
.
assert_array_
equal
(
matrix
.
zy
,
array
[...,
2
,
1
])
np
.
testing
.
assert_array_
equal
(
matrix
.
zz
,
array
[...,
2
,
2
])
def
assert_array_close_to_rotation_matrix
(
array
:
np
.
ndarray
,
"""Check that array and Matrix match."""
assert
torch
.
equal
(
matrix
.
xx
,
array
[...,
0
,
0
])
assert
torch
.
equal
(
matrix
.
xy
,
array
[...,
0
,
1
])
assert
torch
.
equal
(
matrix
.
xz
,
array
[...,
0
,
2
])
assert
torch
.
equal
(
matrix
.
yx
,
array
[...,
1
,
0
])
assert
torch
.
equal
(
matrix
.
yy
,
array
[...,
1
,
1
])
assert
torch
.
equal
(
matrix
.
yz
,
array
[...,
1
,
2
])
assert
torch
.
equal
(
matrix
.
zx
,
array
[...,
2
,
0
])
assert
torch
.
equal
(
matrix
.
zy
,
array
[...,
2
,
1
])
assert
torch
.
equal
(
matrix
.
zz
,
array
[...,
2
,
2
])
def
assert_array_close_to_rotation_matrix
(
array
:
torch
.
Tensor
,
matrix
:
rotation_matrix
.
Rot3Array
):
np
.
testing
.
assert_array_almost_equal
(
matrix
.
to_
array
(),
array
,
6
)
assert
torch
.
allclose
(
matrix
.
to_
tensor
(),
array
,
atol
=
1e-
6
)
def
assert_vectors_equal
(
vec1
:
vector
.
Vec3Array
,
vec2
:
vector
.
Vec3Array
):
np
.
testing
.
assert_array_
equal
(
vec1
.
x
,
vec2
.
x
)
np
.
testing
.
assert_array_
equal
(
vec1
.
y
,
vec2
.
y
)
np
.
testing
.
assert_array_
equal
(
vec1
.
z
,
vec2
.
z
)
assert
torch
.
equal
(
vec1
.
x
,
vec2
.
x
)
assert
torch
.
equal
(
vec1
.
y
,
vec2
.
y
)
assert
torch
.
equal
(
vec1
.
z
,
vec2
.
z
)
def
assert_vectors_close
(
vec1
:
vector
.
Vec3Array
,
vec2
:
vector
.
Vec3Array
):
np
.
testing
.
assert_
allclose
(
vec1
.
x
,
vec2
.
x
,
atol
=
1e-6
,
rtol
=
0.
)
np
.
testing
.
assert_
allclose
(
vec1
.
y
,
vec2
.
y
,
atol
=
1e-6
,
rtol
=
0.
)
np
.
testing
.
assert_
allclose
(
vec1
.
z
,
vec2
.
z
,
atol
=
1e-6
,
rtol
=
0.
)
assert
torch
.
allclose
(
vec1
.
x
,
vec2
.
x
,
atol
=
1e-6
,
rtol
=
0.
)
assert
torch
.
allclose
(
vec1
.
y
,
vec2
.
y
,
atol
=
1e-6
,
rtol
=
0.
)
assert
torch
.
allclose
(
vec1
.
z
,
vec2
.
z
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_close_to_vector
(
array
:
np
.
ndarray
,
vec
:
vector
.
Vec3Array
):
np
.
testing
.
assert_
allclose
(
vec
.
to_
array
(),
array
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_close_to_vector
(
array
:
torch
.
Tensor
,
vec
:
vector
.
Vec3Array
):
assert
torch
.
allclose
(
vec
.
to_
tensor
(),
array
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_equal_to_vector
(
array
:
np
.
ndarray
,
vec
:
vector
.
Vec3Array
):
np
.
testing
.
assert_array_
equal
(
vec
.
to_
array
(),
array
)
def
assert_array_equal_to_vector
(
array
:
torch
.
Tensor
,
vec
:
vector
.
Vec3Array
):
assert
torch
.
equal
(
vec
.
to_
tensor
(),
array
)
def
assert_rigid_equal_to_rigid
(
rigid1
:
rigid_matrix_vector
.
Rigid3Array
,
rigid2
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rot_trans_equal_to_rigid
(
rigid1
.
rotation
,
rigid1
.
translation
,
rigid2
)
assert_rot_trans_equal_to_rigid
(
rigid1
.
rotation
,
rigid1
.
translation
,
rigid2
)
def
assert_rigid_close_to_rigid
(
rigid1
:
rigid_matrix_vector
.
Rigid3Array
,
rigid2
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rot_trans_close_to_rigid
(
rigid1
.
rotation
,
rigid1
.
translation
,
rigid2
)
assert_rot_trans_close_to_rigid
(
rigid1
.
rotation
,
rigid1
.
translation
,
rigid2
)
def
assert_rot_trans_equal_to_rigid
(
rot
:
rotation_matrix
.
Rot3Array
,
trans
:
vector
.
Vec3Array
,
rigid
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rotation_matrix_equal
(
rot
,
rigid
.
rotation
)
assert_vectors_equal
(
trans
,
rigid
.
translation
)
assert_rotation_matrix_equal
(
rot
,
rigid
.
rotation
)
assert_vectors_equal
(
trans
,
rigid
.
translation
)
def
assert_rot_trans_close_to_rigid
(
rot
:
rotation_matrix
.
Rot3Array
,
trans
:
vector
.
Vec3Array
,
rigid
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rotation_matrix_close
(
rot
,
rigid
.
rotation
)
assert_vectors_close
(
trans
,
rigid
.
translation
)
assert_rotation_matrix_close
(
rot
,
rigid
.
rotation
)
assert_vectors_close
(
trans
,
rigid
.
translation
)
tests/test_feats.py
View file @
5fcd6ed2
...
...
@@ -45,16 +45,17 @@ if compare_utils.alphafold_is_installed():
class
TestFeats
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_pseudo_beta_fn_compare
(
self
):
...
...
tests/test_loss.py
View file @
5fcd6ed2
...
...
@@ -79,16 +79,17 @@ def affine_vector_to_rigid(am_rigid, affine):
class
TestLoss
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_run_torsion_angle_loss
(
self
):
batch_size
=
consts
.
batch_size
...
...
tests/test_model.py
View file @
5fcd6ed2
...
...
@@ -38,16 +38,17 @@ if compare_utils.alphafold_is_installed():
class
TestModel
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_dry_run
(
self
):
n_seq
=
consts
.
n_seq
...
...
tests/test_multimer_datamodule.py
View file @
5fcd6ed2
...
...
@@ -20,7 +20,8 @@ from openfold.utils.tensor_utils import tensor_tree_map
from
openfold.config
import
model_config
from
openfold.data.data_modules
import
OpenFoldMultimerDataModule
from
openfold.model.model
import
AlphaFold
from
openfold.utils.loss
import
AlphaFoldMultimerLoss
from
openfold.utils.loss
import
AlphaFoldLoss
from
openfold.utils.multi_chain_permutation
import
multi_chain_permutation_align
from
tests.config
import
consts
import
logging
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -61,17 +62,28 @@ class TestMultimerDataModule(unittest.TestCase):
self
.
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
# deepspeed for this test
self
.
model
=
AlphaFold
(
self
.
c
)
self
.
multimer_
loss
=
AlphaFold
Multimer
Loss
(
self
.
c
.
loss
)
self
.
loss
=
AlphaFoldLoss
(
self
.
c
.
loss
)
def
testPrepareData
(
self
):
self
.
data_module
.
prepare_data
()
self
.
data_module
.
setup
()
train_dataset
=
self
.
data_module
.
train_dataset
all_chain_features
,
ground_truth
=
train_dataset
[
1
]
all_chain_features
=
train_dataset
[
1
]
add_batch_size_dimension
=
lambda
t
:
(
t
.
unsqueeze
(
0
)
)
all_chain_features
=
tensor_tree_map
(
add_batch_size_dimension
,
all_chain_features
)
all_chain_features
=
tensor_tree_map
(
add_batch_size_dimension
,
all_chain_features
)
with
torch
.
no_grad
():
ground_truth
=
all_chain_features
.
pop
(
'gt_features'
,
None
)
# Run the model
out
=
self
.
model
(
all_chain_features
)
self
.
multimer_loss
(
out
,(
all_chain_features
,
ground_truth
))
\ No newline at end of file
# Remove the recycling dimension
all_chain_features
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
all_chain_features
)
all_chain_features
=
multi_chain_permutation_align
(
out
=
out
,
features
=
all_chain_features
,
ground_truth
=
ground_truth
)
self
.
loss
(
out
,
all_chain_features
)
\ No newline at end of file
tests/test_permutation.py
View file @
5fcd6ed2
...
...
@@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
torch
import
unittest
from
openfold.utils.
loss
import
AlphaFoldMultimerLoss
from
openfold.utils.loss
import
get_least_asym_entity_or_longest_length
,
merge_labels
,
pad_features
from
openfold.utils.tensor_utils
import
tensor_tree_map
import
math
from
openfold.utils.
multi_chain_permutation
import
(
pad_features
,
get_least_asym_entity_or_longest_length
,
compute_permutation_alignment
,
split_ground_truth_labels
,
merge_labels
)
@
unittest
.
skip
(
"Tests need to be fixed post-refactor"
)
class
TestPermutation
(
unittest
.
TestCase
):
def
setUp
(
self
):
"""
...
...
@@ -27,144 +29,143 @@ class TestPermutation(unittest.TestCase):
and rotation matrices
"""
theta
=
math
.
pi
/
4
theta
=
math
.
pi
/
4
device
=
'cpu'
self
.
rotation_matrix_z
=
torch
.
tensor
([
[
math
.
cos
(
theta
),
-
math
.
sin
(
theta
),
0
],
[
math
.
sin
(
theta
),
math
.
cos
(
theta
),
0
],
[
0
,
0
,
1
]
],
device
=
'cuda'
)
[
math
.
cos
(
theta
),
-
math
.
sin
(
theta
),
0
],
[
math
.
sin
(
theta
),
math
.
cos
(
theta
),
0
],
[
0
,
0
,
1
]
],
device
=
device
)
self
.
rotation_matrix_x
=
torch
.
tensor
([
[
1
,
0
,
0
],
[
0
,
math
.
cos
(
theta
),
-
math
.
sin
(
theta
)],
[
0
,
math
.
sin
(
theta
),
math
.
cos
(
theta
)],
],
device
=
'cuda'
)
[
1
,
0
,
0
],
[
0
,
math
.
cos
(
theta
),
-
math
.
sin
(
theta
)],
[
0
,
math
.
sin
(
theta
),
math
.
cos
(
theta
)],
],
device
=
device
)
self
.
rotation_matrix_y
=
torch
.
tensor
([
[
math
.
cos
(
theta
),
0
,
math
.
sin
(
theta
)],
[
0
,
1
,
0
],
[
-
math
.
sin
(
theta
),
1
,
math
.
cos
(
theta
)],
],
device
=
'cuda'
)
self
.
chain_a_num_res
=
9
self
.
chain_b_num_res
=
13
[
math
.
cos
(
theta
),
0
,
math
.
sin
(
theta
)],
[
0
,
1
,
0
],
[
-
math
.
sin
(
theta
),
1
,
math
.
cos
(
theta
)],
],
device
=
device
)
self
.
chain_a_num_res
=
9
self
.
chain_b_num_res
=
13
# below create default fake ground truth structures for a hetero-pentamer A2B3
self
.
residue_index
=
list
(
range
(
self
.
chain_a_num_res
))
*
2
+
list
(
range
(
self
.
chain_b_num_res
))
*
3
self
.
num_res
=
self
.
chain_a_num_res
*
2
+
self
.
chain_b_num_res
*
3
self
.
asym_id
=
torch
.
tensor
([[
1
]
*
self
.
chain_a_num_res
+
[
2
]
*
self
.
chain_a_num_res
+
[
3
]
*
self
.
chain_b_num_res
+
[
4
]
*
self
.
chain_b_num_res
+
[
5
]
*
self
.
chain_b_num_res
],
device
=
'cuda'
)
self
.
residue_index
=
list
(
range
(
self
.
chain_a_num_res
))
*
2
+
list
(
range
(
self
.
chain_b_num_res
))
*
3
self
.
num_res
=
self
.
chain_a_num_res
*
2
+
self
.
chain_b_num_res
*
3
self
.
asym_id
=
torch
.
tensor
([[
1
]
*
self
.
chain_a_num_res
+
[
2
]
*
self
.
chain_a_num_res
+
[
3
]
*
self
.
chain_b_num_res
+
[
4
]
*
self
.
chain_b_num_res
+
[
5
]
*
self
.
chain_b_num_res
],
device
=
device
)
self
.
sym_id
=
self
.
asym_id
self
.
entity_id
=
torch
.
tensor
([[
1
]
*
(
self
.
chain_a_num_res
*
2
)
+
[
2
]
*
(
self
.
chain_b_num_res
*
3
)],
device
=
'cuda'
)
self
.
entity_id
=
torch
.
tensor
([[
1
]
*
(
self
.
chain_a_num_res
*
2
)
+
[
2
]
*
(
self
.
chain_b_num_res
*
3
)],
device
=
device
)
def
test_1_selecting_anchors
(
self
):
self
.
batch
=
{
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
self
.
sym_id
,
'entity_id'
:
self
.
entity_id
,
'seq_length'
:
torch
.
tensor
([
57
])
batch
=
{
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
self
.
sym_id
,
'entity_id'
:
self
.
entity_id
,
'seq_length'
:
torch
.
tensor
([
57
])
}
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
self
.
batch
)
self
.
assertIn
(
int
(
anchor_gt_asym
),[
1
,
2
])
self
.
assertNotIn
(
int
(
anchor_gt_asym
),[
3
,
4
,
5
])
self
.
assertIn
(
int
(
anchor_pred_asym
),[
1
,
2
])
self
.
assertNotIn
(
int
(
anchor_pred_asym
),[
3
,
4
,
5
])
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
,
batch
[
'asym_id'
]
)
self
.
assertIn
(
int
(
anchor_gt_asym
),
[
1
,
2
])
self
.
assertNotIn
(
int
(
anchor_gt_asym
),
[
3
,
4
,
5
])
self
.
assertIn
(
int
(
anchor_pred_asym
),
[
1
,
2
])
self
.
assertNotIn
(
int
(
anchor_pred_asym
),
[
3
,
4
,
5
])
def
test_2_permutation_pentamer
(
self
):
batch
=
{
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
self
.
sym_id
,
'entity_id'
:
self
.
entity_id
,
'seq_length'
:
torch
.
tensor
([
57
]),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
57
))
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
self
.
sym_id
,
'entity_id'
:
self
.
entity_id
,
'seq_length'
:
torch
.
tensor
([
57
]),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
57
))
}
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
self
.
num_res
)
batch
[
"residue_index"
]
=
torch
.
tensor
([
self
.
residue_index
]
,
device
=
'cuda'
)
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
self
.
num_res
)
batch
[
"residue_index"
]
=
torch
.
tensor
([
self
.
residue_index
])
# create fake ground truth atom positions
chain_a1_pos
=
torch
.
randint
(
15
,(
self
.
chain_a_num_res
,
3
*
37
),
device
=
'cuda'
,
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
chain_a2_pos
=
torch
.
matmul
(
chain_a1_pos
,
self
.
rotation_matrix_x
)
+
10
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
device
=
'cuda'
,
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
chain_a1_pos
=
torch
.
randint
(
15
,
(
self
.
chain_a_num_res
,
3
*
37
),
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
chain_a2_pos
=
torch
.
matmul
(
chain_a1_pos
,
self
.
rotation_matrix_x
)
+
10
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
# Below permutate predicted chain positions
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
)
,
device
=
'cuda'
)
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
))
out
=
{
'final_atom_positions'
:
pred_atom_position
,
'final_atom_mask'
:
pred_atom_mask
'final_atom_positions'
:
pred_atom_position
,
'final_atom_mask'
:
pred_atom_mask
}
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)
,
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)
,
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)
,
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)
,
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)
,
device
=
'cuda'
)),
dim
=
1
)
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
))),
dim
=
1
)
batch
[
'all_atom_positions'
]
=
true_atom_position
batch
[
'all_atom_mask'
]
=
true_atom_mask
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
aligns
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
True
)
aligns
,
_
=
compute_permutation_alignment
(
out
,
batch
,
batch
)
print
(
f
"##### aligns is
{
aligns
}
"
)
possible_outcome
=
[[(
0
,
1
),(
1
,
0
),(
2
,
3
),(
3
,
4
),(
4
,
2
)],[(
0
,
0
),(
1
,
1
),(
2
,
3
),(
3
,
4
),(
4
,
2
)]]
wrong_outcome
=
[[(
0
,
1
),(
1
,
0
),(
2
,
4
),(
3
,
2
),(
4
,
3
)],[(
0
,
0
),(
1
,
1
),(
2
,
2
),(
3
,
3
),(
4
,
4
)]]
self
.
assertIn
(
aligns
,
possible_outcome
)
self
.
assertNotIn
(
aligns
,
wrong_outcome
)
possible_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)]]
wrong_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
4
),
(
3
,
2
),
(
4
,
3
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
)]]
self
.
assertIn
(
aligns
,
possible_outcome
)
self
.
assertNotIn
(
aligns
,
wrong_outcome
)
def
test_3_merge_labels
(
self
):
nres_pad
=
325
-
57
# suppose the cropping size is 325
nres_pad
=
325
-
57
# suppose the cropping size is 325
batch
=
{
'asym_id'
:
pad_features
(
self
.
asym_id
,
nres_pad
,
pad_dim
=
1
),
'sym_id'
:
pad_features
(
self
.
sym_id
,
nres_pad
,
pad_dim
=
1
),
'entity_id'
:
pad_features
(
self
.
entity_id
,
nres_pad
,
pad_dim
=
1
),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
325
)),
'seq_length'
:
torch
.
tensor
([
57
])
'asym_id'
:
pad_features
(
self
.
asym_id
,
nres_pad
,
pad_dim
=
1
),
'sym_id'
:
pad_features
(
self
.
sym_id
,
nres_pad
,
pad_dim
=
1
),
'entity_id'
:
pad_features
(
self
.
entity_id
,
nres_pad
,
pad_dim
=
1
),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
325
)),
'seq_length'
:
torch
.
tensor
([
57
])
}
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
325
)
batch
[
"residue_index"
]
=
pad_features
(
torch
.
tensor
(
self
.
residue_index
).
reshape
(
1
,
57
),
nres_pad
,
pad_dim
=
1
)
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
325
)
batch
[
"residue_index"
]
=
pad_features
(
torch
.
tensor
(
self
.
residue_index
).
reshape
(
1
,
57
),
nres_pad
,
pad_dim
=
1
)
# create fake ground truth atom positions
chain_a1_pos
=
torch
.
randint
(
15
,(
self
.
chain_a_num_res
,
3
*
37
),
device
=
'cuda'
,
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
chain_a2_pos
=
torch
.
matmul
(
chain_a1_pos
,
self
.
rotation_matrix_x
)
+
10
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
device
=
'cuda'
,
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
chain_a1_pos
=
torch
.
randint
(
15
,
(
self
.
chain_a_num_res
,
3
*
37
),
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
chain_a2_pos
=
torch
.
matmul
(
chain_a1_pos
,
self
.
rotation_matrix_x
)
+
10
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
# Below permutate predicted chain positions
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
)
,
device
=
'cuda'
)
pred_atom_position
=
pad_features
(
pred_atom_position
,
nres_pad
,
pad_dim
=
1
)
pred_atom_mask
=
pad_features
(
pred_atom_mask
,
nres_pad
,
pad_dim
=
1
)
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
))
pred_atom_position
=
pad_features
(
pred_atom_position
,
nres_pad
,
pad_dim
=
1
)
pred_atom_mask
=
pad_features
(
pred_atom_mask
,
nres_pad
,
pad_dim
=
1
)
out
=
{
'final_atom_positions'
:
pred_atom_position
,
'final_atom_mask'
:
pred_atom_mask
'final_atom_positions'
:
pred_atom_position
,
'final_atom_mask'
:
pred_atom_mask
}
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
),
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
),
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
),
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
),
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
),
device
=
'cuda'
)),
dim
=
1
)
batch
[
'all_atom_positions'
]
=
pad_features
(
true_atom_position
,
nres_pad
,
pad_dim
=
1
)
batch
[
'all_atom_mask'
]
=
pad_features
(
true_atom_mask
,
nres_pad
=
nres_pad
,
pad_dim
=
1
)
tensor_to_cuda
=
lambda
t
:
t
.
to
(
'cuda'
)
batch
=
tensor_tree_map
(
tensor_to_cuda
,
batch
)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
aligns
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
True
)
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
))),
dim
=
1
)
batch
[
'all_atom_positions'
]
=
pad_features
(
true_atom_position
,
nres_pad
,
pad_dim
=
1
)
batch
[
'all_atom_mask'
]
=
pad_features
(
true_atom_mask
,
nres_pad
=
nres_pad
,
pad_dim
=
1
)
# tensor_to_cuda = lambda t: t.to('cuda')
# ground_truth = tensor_tree_map(tensor_to_cuda,ground_truth)
aligns
,
per_asym_residue_index
=
compute_permutation_alignment
(
out
,
batch
,
batch
)
print
(
f
"##### aligns is
{
aligns
}
"
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
i
for
i
in
batch
.
keys
()
if
i
in
dim_dict
])
labels
=
merge_labels
(
labels
,
aligns
,
labels
=
split_ground_truth_labels
(
batch
)
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
aligns
,
original_nres
=
batch
[
'aatype'
].
shape
[
-
1
])
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
expected_permutated_gt_pos
=
pad_features
(
expected_permutated_gt_pos
,
nres_pad
,
pad_dim
=
1
)
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
\ No newline at end of file
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
expected_permutated_gt_pos
=
pad_features
(
expected_permutated_gt_pos
,
nres_pad
,
pad_dim
=
1
)
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
tests/test_structure_module.py
View file @
5fcd6ed2
...
...
@@ -46,16 +46,17 @@ if compare_utils.alphafold_is_installed():
class
TestStructureModule
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_structure_module_shape
(
self
):
batch_size
=
consts
.
batch_size
...
...
@@ -202,16 +203,17 @@ class TestStructureModule(unittest.TestCase):
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
c_m
=
13
...
...
tests/test_template.py
View file @
5fcd6ed2
...
...
@@ -56,16 +56,17 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class
TestTemplatePairStack
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
batch_size
=
consts
.
batch_size
...
...
@@ -196,16 +197,17 @@ class TestTemplatePairStack(unittest.TestCase):
class
Template
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_compare
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment