Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
5fcd6ed2
Commit
5fcd6ed2
authored
Nov 02, 2023
by
Christina Floristean
Browse files
Unit test fixes for when AF2 is not installed
parent
f95d9a57
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
249 additions
and
229 deletions
+249
-229
openfold/utils/geometry/test_utils.py
openfold/utils/geometry/test_utils.py
+40
-40
tests/test_feats.py
tests/test_feats.py
+11
-10
tests/test_loss.py
tests/test_loss.py
+11
-10
tests/test_model.py
tests/test_model.py
+11
-10
tests/test_multimer_datamodule.py
tests/test_multimer_datamodule.py
+17
-5
tests/test_permutation.py
tests/test_permutation.py
+115
-114
tests/test_structure_module.py
tests/test_structure_module.py
+22
-20
tests/test_template.py
tests/test_template.py
+22
-20
No files found.
openfold/utils/geometry/test_utils.py
View file @
5fcd6ed2
...
@@ -14,84 +14,84 @@
...
@@ -14,84 +14,84 @@
"""Shared utils for tests."""
"""Shared utils for tests."""
import
dataclasses
import
dataclasses
import
torch
from
alphafold.model.geometry
import
rigid_matrix_vector
from
openfold.utils.geometry
import
rigid_matrix_vector
from
alphafold.model.geometry
import
rotation_matrix
from
openfold.utils.geometry
import
rotation_matrix
from
alphafold.model.geometry
import
vector
from
openfold.utils.geometry
import
vector
import
numpy
as
np
def
assert_rotation_matrix_equal
(
matrix1
:
rotation_matrix
.
Rot3Array
,
def
assert_rotation_matrix_equal
(
matrix1
:
rotation_matrix
.
Rot3Array
,
matrix2
:
rotation_matrix
.
Rot3Array
):
matrix2
:
rotation_matrix
.
Rot3Array
):
for
field
in
dataclasses
.
fields
(
rotation_matrix
.
Rot3Array
):
for
field
in
dataclasses
.
fields
(
rotation_matrix
.
Rot3Array
):
field
=
field
.
name
field
=
field
.
name
np
.
testing
.
assert_array_
equal
(
assert
torch
.
equal
(
getattr
(
matrix1
,
field
),
getattr
(
matrix2
,
field
))
getattr
(
matrix1
,
field
),
getattr
(
matrix2
,
field
))
def
assert_rotation_matrix_close
(
mat1
:
rotation_matrix
.
Rot3Array
,
def
assert_rotation_matrix_close
(
mat1
:
rotation_matrix
.
Rot3Array
,
mat2
:
rotation_matrix
.
Rot3Array
):
mat2
:
rotation_matrix
.
Rot3Array
):
np
.
testing
.
assert_array_almost_equal
(
mat1
.
to_
array
(),
mat2
.
to_
array
(),
6
)
assert
torch
.
allclose
(
mat1
.
to_
tensor
(),
mat2
.
to_
tensor
(),
atol
=
1e-
6
)
def
assert_array_equal_to_rotation_matrix
(
array
:
np
.
ndarray
,
def
assert_array_equal_to_rotation_matrix
(
array
:
torch
.
Tensor
,
matrix
:
rotation_matrix
.
Rot3Array
):
matrix
:
rotation_matrix
.
Rot3Array
):
"""Check that array and Matrix match."""
"""Check that array and Matrix match."""
np
.
testing
.
assert_array_
equal
(
matrix
.
xx
,
array
[...,
0
,
0
])
assert
torch
.
equal
(
matrix
.
xx
,
array
[...,
0
,
0
])
np
.
testing
.
assert_array_
equal
(
matrix
.
xy
,
array
[...,
0
,
1
])
assert
torch
.
equal
(
matrix
.
xy
,
array
[...,
0
,
1
])
np
.
testing
.
assert_array_
equal
(
matrix
.
xz
,
array
[...,
0
,
2
])
assert
torch
.
equal
(
matrix
.
xz
,
array
[...,
0
,
2
])
np
.
testing
.
assert_array_
equal
(
matrix
.
yx
,
array
[...,
1
,
0
])
assert
torch
.
equal
(
matrix
.
yx
,
array
[...,
1
,
0
])
np
.
testing
.
assert_array_
equal
(
matrix
.
yy
,
array
[...,
1
,
1
])
assert
torch
.
equal
(
matrix
.
yy
,
array
[...,
1
,
1
])
np
.
testing
.
assert_array_
equal
(
matrix
.
yz
,
array
[...,
1
,
2
])
assert
torch
.
equal
(
matrix
.
yz
,
array
[...,
1
,
2
])
np
.
testing
.
assert_array_
equal
(
matrix
.
zx
,
array
[...,
2
,
0
])
assert
torch
.
equal
(
matrix
.
zx
,
array
[...,
2
,
0
])
np
.
testing
.
assert_array_
equal
(
matrix
.
zy
,
array
[...,
2
,
1
])
assert
torch
.
equal
(
matrix
.
zy
,
array
[...,
2
,
1
])
np
.
testing
.
assert_array_
equal
(
matrix
.
zz
,
array
[...,
2
,
2
])
assert
torch
.
equal
(
matrix
.
zz
,
array
[...,
2
,
2
])
def
assert_array_close_to_rotation_matrix
(
array
:
np
.
ndarray
,
def
assert_array_close_to_rotation_matrix
(
array
:
torch
.
Tensor
,
matrix
:
rotation_matrix
.
Rot3Array
):
matrix
:
rotation_matrix
.
Rot3Array
):
np
.
testing
.
assert_array_almost_equal
(
matrix
.
to_
array
(),
array
,
6
)
assert
torch
.
allclose
(
matrix
.
to_
tensor
(),
array
,
atol
=
1e-
6
)
def
assert_vectors_equal
(
vec1
:
vector
.
Vec3Array
,
vec2
:
vector
.
Vec3Array
):
def
assert_vectors_equal
(
vec1
:
vector
.
Vec3Array
,
vec2
:
vector
.
Vec3Array
):
np
.
testing
.
assert_array_
equal
(
vec1
.
x
,
vec2
.
x
)
assert
torch
.
equal
(
vec1
.
x
,
vec2
.
x
)
np
.
testing
.
assert_array_
equal
(
vec1
.
y
,
vec2
.
y
)
assert
torch
.
equal
(
vec1
.
y
,
vec2
.
y
)
np
.
testing
.
assert_array_
equal
(
vec1
.
z
,
vec2
.
z
)
assert
torch
.
equal
(
vec1
.
z
,
vec2
.
z
)
def
assert_vectors_close
(
vec1
:
vector
.
Vec3Array
,
vec2
:
vector
.
Vec3Array
):
def
assert_vectors_close
(
vec1
:
vector
.
Vec3Array
,
vec2
:
vector
.
Vec3Array
):
np
.
testing
.
assert_
allclose
(
vec1
.
x
,
vec2
.
x
,
atol
=
1e-6
,
rtol
=
0.
)
assert
torch
.
allclose
(
vec1
.
x
,
vec2
.
x
,
atol
=
1e-6
,
rtol
=
0.
)
np
.
testing
.
assert_
allclose
(
vec1
.
y
,
vec2
.
y
,
atol
=
1e-6
,
rtol
=
0.
)
assert
torch
.
allclose
(
vec1
.
y
,
vec2
.
y
,
atol
=
1e-6
,
rtol
=
0.
)
np
.
testing
.
assert_
allclose
(
vec1
.
z
,
vec2
.
z
,
atol
=
1e-6
,
rtol
=
0.
)
assert
torch
.
allclose
(
vec1
.
z
,
vec2
.
z
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_close_to_vector
(
array
:
np
.
ndarray
,
vec
:
vector
.
Vec3Array
):
def
assert_array_close_to_vector
(
array
:
torch
.
Tensor
,
vec
:
vector
.
Vec3Array
):
np
.
testing
.
assert_
allclose
(
vec
.
to_
array
(),
array
,
atol
=
1e-6
,
rtol
=
0.
)
assert
torch
.
allclose
(
vec
.
to_
tensor
(),
array
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_equal_to_vector
(
array
:
np
.
ndarray
,
vec
:
vector
.
Vec3Array
):
def
assert_array_equal_to_vector
(
array
:
torch
.
Tensor
,
vec
:
vector
.
Vec3Array
):
np
.
testing
.
assert_array_
equal
(
vec
.
to_
array
(),
array
)
assert
torch
.
equal
(
vec
.
to_
tensor
(),
array
)
def
assert_rigid_equal_to_rigid
(
rigid1
:
rigid_matrix_vector
.
Rigid3Array
,
def
assert_rigid_equal_to_rigid
(
rigid1
:
rigid_matrix_vector
.
Rigid3Array
,
rigid2
:
rigid_matrix_vector
.
Rigid3Array
):
rigid2
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rot_trans_equal_to_rigid
(
rigid1
.
rotation
,
rigid1
.
translation
,
rigid2
)
assert_rot_trans_equal_to_rigid
(
rigid1
.
rotation
,
rigid1
.
translation
,
rigid2
)
def
assert_rigid_close_to_rigid
(
rigid1
:
rigid_matrix_vector
.
Rigid3Array
,
def
assert_rigid_close_to_rigid
(
rigid1
:
rigid_matrix_vector
.
Rigid3Array
,
rigid2
:
rigid_matrix_vector
.
Rigid3Array
):
rigid2
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rot_trans_close_to_rigid
(
rigid1
.
rotation
,
rigid1
.
translation
,
rigid2
)
assert_rot_trans_close_to_rigid
(
rigid1
.
rotation
,
rigid1
.
translation
,
rigid2
)
def
assert_rot_trans_equal_to_rigid
(
rot
:
rotation_matrix
.
Rot3Array
,
def
assert_rot_trans_equal_to_rigid
(
rot
:
rotation_matrix
.
Rot3Array
,
trans
:
vector
.
Vec3Array
,
trans
:
vector
.
Vec3Array
,
rigid
:
rigid_matrix_vector
.
Rigid3Array
):
rigid
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rotation_matrix_equal
(
rot
,
rigid
.
rotation
)
assert_rotation_matrix_equal
(
rot
,
rigid
.
rotation
)
assert_vectors_equal
(
trans
,
rigid
.
translation
)
assert_vectors_equal
(
trans
,
rigid
.
translation
)
def
assert_rot_trans_close_to_rigid
(
rot
:
rotation_matrix
.
Rot3Array
,
def
assert_rot_trans_close_to_rigid
(
rot
:
rotation_matrix
.
Rot3Array
,
trans
:
vector
.
Vec3Array
,
trans
:
vector
.
Vec3Array
,
rigid
:
rigid_matrix_vector
.
Rigid3Array
):
rigid
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rotation_matrix_close
(
rot
,
rigid
.
rotation
)
assert_rotation_matrix_close
(
rot
,
rigid
.
rotation
)
assert_vectors_close
(
trans
,
rigid
.
translation
)
assert_vectors_close
(
trans
,
rigid
.
translation
)
tests/test_feats.py
View file @
5fcd6ed2
...
@@ -45,16 +45,17 @@ if compare_utils.alphafold_is_installed():
...
@@ -45,16 +45,17 @@ if compare_utils.alphafold_is_installed():
class
TestFeats
(
unittest
.
TestCase
):
class
TestFeats
(
unittest
.
TestCase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
if
compare_utils
.
alphafold_is_installed
():
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
if
consts
.
is_multimer
:
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
else
:
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_atom
=
alphafold
.
model
.
all_atom
else
:
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_rigid
=
alphafold
.
model
.
r3
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_pseudo_beta_fn_compare
(
self
):
def
test_pseudo_beta_fn_compare
(
self
):
...
...
tests/test_loss.py
View file @
5fcd6ed2
...
@@ -79,16 +79,17 @@ def affine_vector_to_rigid(am_rigid, affine):
...
@@ -79,16 +79,17 @@ def affine_vector_to_rigid(am_rigid, affine):
class
TestLoss
(
unittest
.
TestCase
):
class
TestLoss
(
unittest
.
TestCase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
if
compare_utils
.
alphafold_is_installed
():
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
if
consts
.
is_multimer
:
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
else
:
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_atom
=
alphafold
.
model
.
all_atom
else
:
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_rigid
=
alphafold
.
model
.
r3
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_run_torsion_angle_loss
(
self
):
def
test_run_torsion_angle_loss
(
self
):
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
...
...
tests/test_model.py
View file @
5fcd6ed2
...
@@ -38,16 +38,17 @@ if compare_utils.alphafold_is_installed():
...
@@ -38,16 +38,17 @@ if compare_utils.alphafold_is_installed():
class
TestModel
(
unittest
.
TestCase
):
class
TestModel
(
unittest
.
TestCase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
if
compare_utils
.
alphafold_is_installed
():
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
if
consts
.
is_multimer
:
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
else
:
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_atom
=
alphafold
.
model
.
all_atom
else
:
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_rigid
=
alphafold
.
model
.
r3
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_dry_run
(
self
):
def
test_dry_run
(
self
):
n_seq
=
consts
.
n_seq
n_seq
=
consts
.
n_seq
...
...
tests/test_multimer_datamodule.py
View file @
5fcd6ed2
...
@@ -20,7 +20,8 @@ from openfold.utils.tensor_utils import tensor_tree_map
...
@@ -20,7 +20,8 @@ from openfold.utils.tensor_utils import tensor_tree_map
from
openfold.config
import
model_config
from
openfold.config
import
model_config
from
openfold.data.data_modules
import
OpenFoldMultimerDataModule
from
openfold.data.data_modules
import
OpenFoldMultimerDataModule
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.utils.loss
import
AlphaFoldMultimerLoss
from
openfold.utils.loss
import
AlphaFoldLoss
from
openfold.utils.multi_chain_permutation
import
multi_chain_permutation_align
from
tests.config
import
consts
from
tests.config
import
consts
import
logging
import
logging
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -61,17 +62,28 @@ class TestMultimerDataModule(unittest.TestCase):
...
@@ -61,17 +62,28 @@ class TestMultimerDataModule(unittest.TestCase):
self
.
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
self
.
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
# deepspeed for this test
# deepspeed for this test
self
.
model
=
AlphaFold
(
self
.
c
)
self
.
model
=
AlphaFold
(
self
.
c
)
self
.
multimer_
loss
=
AlphaFold
Multimer
Loss
(
self
.
c
.
loss
)
self
.
loss
=
AlphaFoldLoss
(
self
.
c
.
loss
)
def
testPrepareData
(
self
):
def
testPrepareData
(
self
):
self
.
data_module
.
prepare_data
()
self
.
data_module
.
prepare_data
()
self
.
data_module
.
setup
()
self
.
data_module
.
setup
()
train_dataset
=
self
.
data_module
.
train_dataset
train_dataset
=
self
.
data_module
.
train_dataset
all_chain_features
,
ground_truth
=
train_dataset
[
1
]
all_chain_features
=
train_dataset
[
1
]
add_batch_size_dimension
=
lambda
t
:
(
add_batch_size_dimension
=
lambda
t
:
(
t
.
unsqueeze
(
0
)
t
.
unsqueeze
(
0
)
)
)
all_chain_features
=
tensor_tree_map
(
add_batch_size_dimension
,
all_chain_features
)
all_chain_features
=
tensor_tree_map
(
add_batch_size_dimension
,
all_chain_features
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
ground_truth
=
all_chain_features
.
pop
(
'gt_features'
,
None
)
# Run the model
out
=
self
.
model
(
all_chain_features
)
out
=
self
.
model
(
all_chain_features
)
self
.
multimer_loss
(
out
,(
all_chain_features
,
ground_truth
))
\ No newline at end of file
# Remove the recycling dimension
all_chain_features
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
all_chain_features
)
all_chain_features
=
multi_chain_permutation_align
(
out
=
out
,
features
=
all_chain_features
,
ground_truth
=
ground_truth
)
self
.
loss
(
out
,
all_chain_features
)
\ No newline at end of file
tests/test_permutation.py
View file @
5fcd6ed2
...
@@ -12,14 +12,16 @@
...
@@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
math
import
torch
import
torch
import
unittest
import
unittest
from
openfold.utils.
loss
import
AlphaFoldMultimerLoss
from
openfold.utils.
multi_chain_permutation
import
(
pad_features
,
get_least_asym_entity_or_longest_length
,
from
openfold.utils.loss
import
get_least_asym_entity_or_longest_length
,
merge_labels
,
pad_features
compute_permutation_alignment
,
split_ground_truth_labels
,
from
openfold.utils.tensor_utils
import
tensor_tree_map
merge_labels
)
import
math
@
unittest
.
skip
(
"Tests need to be fixed post-refactor"
)
class
TestPermutation
(
unittest
.
TestCase
):
class
TestPermutation
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
"""
"""
...
@@ -27,144 +29,143 @@ class TestPermutation(unittest.TestCase):
...
@@ -27,144 +29,143 @@ class TestPermutation(unittest.TestCase):
and rotation matrices
and rotation matrices
"""
"""
theta
=
math
.
pi
/
4
theta
=
math
.
pi
/
4
device
=
'cpu'
self
.
rotation_matrix_z
=
torch
.
tensor
([
self
.
rotation_matrix_z
=
torch
.
tensor
([
[
math
.
cos
(
theta
),
-
math
.
sin
(
theta
),
0
],
[
math
.
cos
(
theta
),
-
math
.
sin
(
theta
),
0
],
[
math
.
sin
(
theta
),
math
.
cos
(
theta
),
0
],
[
math
.
sin
(
theta
),
math
.
cos
(
theta
),
0
],
[
0
,
0
,
1
]
[
0
,
0
,
1
]
],
device
=
'cuda'
)
],
device
=
device
)
self
.
rotation_matrix_x
=
torch
.
tensor
([
self
.
rotation_matrix_x
=
torch
.
tensor
([
[
1
,
0
,
0
],
[
1
,
0
,
0
],
[
0
,
math
.
cos
(
theta
),
-
math
.
sin
(
theta
)],
[
0
,
math
.
cos
(
theta
),
-
math
.
sin
(
theta
)],
[
0
,
math
.
sin
(
theta
),
math
.
cos
(
theta
)],
[
0
,
math
.
sin
(
theta
),
math
.
cos
(
theta
)],
],
device
=
'cuda'
)
],
device
=
device
)
self
.
rotation_matrix_y
=
torch
.
tensor
([
self
.
rotation_matrix_y
=
torch
.
tensor
([
[
math
.
cos
(
theta
),
0
,
math
.
sin
(
theta
)],
[
math
.
cos
(
theta
),
0
,
math
.
sin
(
theta
)],
[
0
,
1
,
0
],
[
0
,
1
,
0
],
[
-
math
.
sin
(
theta
),
1
,
math
.
cos
(
theta
)],
[
-
math
.
sin
(
theta
),
1
,
math
.
cos
(
theta
)],
],
device
=
'cuda'
)
],
device
=
device
)
self
.
chain_a_num_res
=
9
self
.
chain_a_num_res
=
9
self
.
chain_b_num_res
=
13
self
.
chain_b_num_res
=
13
# below create default fake ground truth structures for a hetero-pentamer A2B3
# below create default fake ground truth structures for a hetero-pentamer A2B3
self
.
residue_index
=
list
(
range
(
self
.
chain_a_num_res
))
*
2
+
list
(
range
(
self
.
chain_b_num_res
))
*
3
self
.
residue_index
=
list
(
range
(
self
.
chain_a_num_res
))
*
2
+
list
(
range
(
self
.
chain_b_num_res
))
*
3
self
.
num_res
=
self
.
chain_a_num_res
*
2
+
self
.
chain_b_num_res
*
3
self
.
num_res
=
self
.
chain_a_num_res
*
2
+
self
.
chain_b_num_res
*
3
self
.
asym_id
=
torch
.
tensor
([[
1
]
*
self
.
chain_a_num_res
+
[
2
]
*
self
.
chain_a_num_res
+
[
3
]
*
self
.
chain_b_num_res
+
[
4
]
*
self
.
chain_b_num_res
+
[
5
]
*
self
.
chain_b_num_res
],
device
=
'cuda'
)
self
.
asym_id
=
torch
.
tensor
([[
1
]
*
self
.
chain_a_num_res
+
[
2
]
*
self
.
chain_a_num_res
+
[
3
]
*
self
.
chain_b_num_res
+
[
4
]
*
self
.
chain_b_num_res
+
[
5
]
*
self
.
chain_b_num_res
],
device
=
device
)
self
.
sym_id
=
self
.
asym_id
self
.
sym_id
=
self
.
asym_id
self
.
entity_id
=
torch
.
tensor
([[
1
]
*
(
self
.
chain_a_num_res
*
2
)
+
[
2
]
*
(
self
.
chain_b_num_res
*
3
)],
device
=
'cuda'
)
self
.
entity_id
=
torch
.
tensor
([[
1
]
*
(
self
.
chain_a_num_res
*
2
)
+
[
2
]
*
(
self
.
chain_b_num_res
*
3
)],
device
=
device
)
def
test_1_selecting_anchors
(
self
):
def
test_1_selecting_anchors
(
self
):
self
.
batch
=
{
batch
=
{
'asym_id'
:
self
.
asym_id
,
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
self
.
sym_id
,
'sym_id'
:
self
.
sym_id
,
'entity_id'
:
self
.
entity_id
,
'entity_id'
:
self
.
entity_id
,
'seq_length'
:
torch
.
tensor
([
57
])
'seq_length'
:
torch
.
tensor
([
57
])
}
}
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
self
.
batch
)
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
,
batch
[
'asym_id'
]
)
self
.
assertIn
(
int
(
anchor_gt_asym
),[
1
,
2
])
self
.
assertIn
(
int
(
anchor_gt_asym
),
[
1
,
2
])
self
.
assertNotIn
(
int
(
anchor_gt_asym
),[
3
,
4
,
5
])
self
.
assertNotIn
(
int
(
anchor_gt_asym
),
[
3
,
4
,
5
])
self
.
assertIn
(
int
(
anchor_pred_asym
),[
1
,
2
])
self
.
assertIn
(
int
(
anchor_pred_asym
),
[
1
,
2
])
self
.
assertNotIn
(
int
(
anchor_pred_asym
),[
3
,
4
,
5
])
self
.
assertNotIn
(
int
(
anchor_pred_asym
),
[
3
,
4
,
5
])
def
test_2_permutation_pentamer
(
self
):
def
test_2_permutation_pentamer
(
self
):
batch
=
{
batch
=
{
'asym_id'
:
self
.
asym_id
,
'asym_id'
:
self
.
asym_id
,
'sym_id'
:
self
.
sym_id
,
'sym_id'
:
self
.
sym_id
,
'entity_id'
:
self
.
entity_id
,
'entity_id'
:
self
.
entity_id
,
'seq_length'
:
torch
.
tensor
([
57
]),
'seq_length'
:
torch
.
tensor
([
57
]),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
57
))
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
57
))
}
}
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
self
.
num_res
)
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
self
.
num_res
)
batch
[
"residue_index"
]
=
torch
.
tensor
([
self
.
residue_index
]
,
device
=
'cuda'
)
batch
[
"residue_index"
]
=
torch
.
tensor
([
self
.
residue_index
])
# create fake ground truth atom positions
# create fake ground truth atom positions
chain_a1_pos
=
torch
.
randint
(
15
,(
self
.
chain_a_num_res
,
3
*
37
),
chain_a1_pos
=
torch
.
randint
(
15
,
(
self
.
chain_a_num_res
,
3
*
37
),
device
=
'cuda'
,
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
chain_a2_pos
=
torch
.
matmul
(
chain_a1_pos
,
self
.
rotation_matrix_x
)
+
10
chain_a2_pos
=
torch
.
matmul
(
chain_a1_pos
,
self
.
rotation_matrix_x
)
+
10
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
device
=
'cuda'
,
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
# Below permutate predicted chain positions
# Below permutate predicted chain positions
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
)
,
device
=
'cuda'
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
))
out
=
{
out
=
{
'final_atom_positions'
:
pred_atom_position
,
'final_atom_positions'
:
pred_atom_position
,
'final_atom_mask'
:
pred_atom_mask
'final_atom_mask'
:
pred_atom_mask
}
}
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)
,
device
=
'cuda'
),
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)
,
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)
,
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)
,
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)
,
device
=
'cuda'
)),
dim
=
1
)
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
))),
dim
=
1
)
batch
[
'all_atom_positions'
]
=
true_atom_position
batch
[
'all_atom_positions'
]
=
true_atom_position
batch
[
'all_atom_mask'
]
=
true_atom_mask
batch
[
'all_atom_mask'
]
=
true_atom_mask
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
aligns
,
_
=
compute_permutation_alignment
(
out
,
batch
,
aligns
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
batch
,
batch
)
dim_dict
,
permutate_chains
=
True
)
print
(
f
"##### aligns is
{
aligns
}
"
)
print
(
f
"##### aligns is
{
aligns
}
"
)
possible_outcome
=
[[(
0
,
1
),(
1
,
0
),(
2
,
3
),(
3
,
4
),(
4
,
2
)],[(
0
,
0
),(
1
,
1
),(
2
,
3
),(
3
,
4
),(
4
,
2
)]]
possible_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
3
),
(
3
,
4
),
(
4
,
2
)]]
wrong_outcome
=
[[(
0
,
1
),(
1
,
0
),(
2
,
4
),(
3
,
2
),(
4
,
3
)],[(
0
,
0
),(
1
,
1
),(
2
,
2
),(
3
,
3
),(
4
,
4
)]]
wrong_outcome
=
[[(
0
,
1
),
(
1
,
0
),
(
2
,
4
),
(
3
,
2
),
(
4
,
3
)],
[(
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
)]]
self
.
assertIn
(
aligns
,
possible_outcome
)
self
.
assertIn
(
aligns
,
possible_outcome
)
self
.
assertNotIn
(
aligns
,
wrong_outcome
)
self
.
assertNotIn
(
aligns
,
wrong_outcome
)
def
test_3_merge_labels
(
self
):
def
test_3_merge_labels
(
self
):
nres_pad
=
325
-
57
# suppose the cropping size is 325
nres_pad
=
325
-
57
# suppose the cropping size is 325
batch
=
{
batch
=
{
'asym_id'
:
pad_features
(
self
.
asym_id
,
nres_pad
,
pad_dim
=
1
),
'asym_id'
:
pad_features
(
self
.
asym_id
,
nres_pad
,
pad_dim
=
1
),
'sym_id'
:
pad_features
(
self
.
sym_id
,
nres_pad
,
pad_dim
=
1
),
'sym_id'
:
pad_features
(
self
.
sym_id
,
nres_pad
,
pad_dim
=
1
),
'entity_id'
:
pad_features
(
self
.
entity_id
,
nres_pad
,
pad_dim
=
1
),
'entity_id'
:
pad_features
(
self
.
entity_id
,
nres_pad
,
pad_dim
=
1
),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
325
)),
'aatype'
:
torch
.
randint
(
21
,
size
=
(
1
,
325
)),
'seq_length'
:
torch
.
tensor
([
57
])
'seq_length'
:
torch
.
tensor
([
57
])
}
}
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
325
)
batch
[
'asym_id'
]
=
batch
[
'asym_id'
].
reshape
(
1
,
325
)
batch
[
"residue_index"
]
=
pad_features
(
torch
.
tensor
(
self
.
residue_index
).
reshape
(
1
,
57
),
nres_pad
,
pad_dim
=
1
)
batch
[
"residue_index"
]
=
pad_features
(
torch
.
tensor
(
self
.
residue_index
).
reshape
(
1
,
57
),
nres_pad
,
pad_dim
=
1
)
# create fake ground truth atom positions
# create fake ground truth atom positions
chain_a1_pos
=
torch
.
randint
(
15
,(
self
.
chain_a_num_res
,
3
*
37
),
chain_a1_pos
=
torch
.
randint
(
15
,
(
self
.
chain_a_num_res
,
3
*
37
),
device
=
'cuda'
,
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_a_num_res
,
37
,
3
)
chain_a2_pos
=
torch
.
matmul
(
chain_a1_pos
,
self
.
rotation_matrix_x
)
+
10
chain_a2_pos
=
torch
.
matmul
(
chain_a1_pos
,
self
.
rotation_matrix_x
)
+
10
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
chain_b1_pos
=
torch
.
randint
(
low
=
15
,
high
=
30
,
size
=
(
self
.
chain_b_num_res
,
3
*
37
),
device
=
'cuda'
,
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
dtype
=
torch
.
float
).
reshape
(
1
,
self
.
chain_b_num_res
,
37
,
3
)
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b2_pos
=
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_y
)
+
10
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
chain_b3_pos
=
torch
.
matmul
(
torch
.
matmul
(
chain_b1_pos
,
self
.
rotation_matrix_z
),
self
.
rotation_matrix_x
)
+
30
# Below permutate predicted chain positions
# Below permutate predicted chain positions
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_position
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
)
,
device
=
'cuda'
)
pred_atom_mask
=
torch
.
ones
((
1
,
self
.
num_res
,
37
))
pred_atom_position
=
pad_features
(
pred_atom_position
,
nres_pad
,
pad_dim
=
1
)
pred_atom_position
=
pad_features
(
pred_atom_position
,
nres_pad
,
pad_dim
=
1
)
pred_atom_mask
=
pad_features
(
pred_atom_mask
,
nres_pad
,
pad_dim
=
1
)
pred_atom_mask
=
pad_features
(
pred_atom_mask
,
nres_pad
,
pad_dim
=
1
)
out
=
{
out
=
{
'final_atom_positions'
:
pred_atom_position
,
'final_atom_positions'
:
pred_atom_position
,
'final_atom_mask'
:
pred_atom_mask
'final_atom_mask'
:
pred_atom_mask
}
}
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_position
=
torch
.
cat
((
chain_a1_pos
,
chain_a2_pos
,
chain_b1_pos
,
chain_b2_pos
,
chain_b3_pos
),
dim
=
1
)
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
),
device
=
'cuda'
),
true_atom_mask
=
torch
.
cat
((
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
),
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_a_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
),
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
),
device
=
'cuda'
),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
)),
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
),
device
=
'cuda'
)),
dim
=
1
)
torch
.
ones
((
1
,
self
.
chain_b_num_res
,
37
))),
dim
=
1
)
batch
[
'all_atom_positions'
]
=
pad_features
(
true_atom_position
,
nres_pad
,
pad_dim
=
1
)
batch
[
'all_atom_positions'
]
=
pad_features
(
true_atom_position
,
nres_pad
,
pad_dim
=
1
)
batch
[
'all_atom_mask'
]
=
pad_features
(
true_atom_mask
,
nres_pad
=
nres_pad
,
pad_dim
=
1
)
batch
[
'all_atom_mask'
]
=
pad_features
(
true_atom_mask
,
nres_pad
=
nres_pad
,
pad_dim
=
1
)
tensor_to_cuda
=
lambda
t
:
t
.
to
(
'cuda'
)
# tensor_to_cuda = lambda t: t.to('cuda')
batch
=
tensor_tree_map
(
tensor_to_cuda
,
batch
)
# ground_truth = tensor_tree_map(tensor_to_cuda,ground_truth)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
aligns
,
per_asym_residue_index
=
compute_permutation_alignment
(
out
,
aligns
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
batch
,
batch
,
batch
)
dim_dict
,
permutate_chains
=
True
)
print
(
f
"##### aligns is
{
aligns
}
"
)
print
(
f
"##### aligns is
{
aligns
}
"
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
labels
=
split_ground_truth_labels
(
batch
)
REQUIRED_FEATURES
=
[
i
for
i
in
batch
.
keys
()
if
i
in
dim_dict
])
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
aligns
,
labels
=
merge_labels
(
labels
,
aligns
,
original_nres
=
batch
[
'aatype'
].
shape
[
-
1
])
original_nres
=
batch
[
'aatype'
].
shape
[
-
1
])
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
expected_permutated_gt_pos
=
pad_features
(
expected_permutated_gt_pos
,
nres_pad
,
pad_dim
=
1
)
dim
=
1
)
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
expected_permutated_gt_pos
=
pad_features
(
expected_permutated_gt_pos
,
nres_pad
,
pad_dim
=
1
)
\ No newline at end of file
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
tests/test_structure_module.py
View file @
5fcd6ed2
...
@@ -46,16 +46,17 @@ if compare_utils.alphafold_is_installed():
...
@@ -46,16 +46,17 @@ if compare_utils.alphafold_is_installed():
class
TestStructureModule
(
unittest
.
TestCase
):
class
TestStructureModule
(
unittest
.
TestCase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
if
compare_utils
.
alphafold_is_installed
():
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
if
consts
.
is_multimer
:
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
else
:
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_atom
=
alphafold
.
model
.
all_atom
else
:
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_rigid
=
alphafold
.
model
.
r3
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_structure_module_shape
(
self
):
def
test_structure_module_shape
(
self
):
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
...
@@ -202,16 +203,17 @@ class TestStructureModule(unittest.TestCase):
...
@@ -202,16 +203,17 @@ class TestStructureModule(unittest.TestCase):
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
if
compare_utils
.
alphafold_is_installed
():
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
if
consts
.
is_multimer
:
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
else
:
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_atom
=
alphafold
.
model
.
all_atom
else
:
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_rigid
=
alphafold
.
model
.
r3
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
def
test_shape
(
self
):
c_m
=
13
c_m
=
13
...
...
tests/test_template.py
View file @
5fcd6ed2
...
@@ -56,16 +56,17 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
...
@@ -56,16 +56,17 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class
TestTemplatePairStack
(
unittest
.
TestCase
):
class
TestTemplatePairStack
(
unittest
.
TestCase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
if
compare_utils
.
alphafold_is_installed
():
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
if
consts
.
is_multimer
:
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
else
:
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_atom
=
alphafold
.
model
.
all_atom
else
:
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_rigid
=
alphafold
.
model
.
r3
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
def
test_shape
(
self
):
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
...
@@ -196,16 +197,17 @@ class TestTemplatePairStack(unittest.TestCase):
...
@@ -196,16 +197,17 @@ class TestTemplatePairStack(unittest.TestCase):
class
Template
(
unittest
.
TestCase
):
class
Template
(
unittest
.
TestCase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
if
consts
.
is_multimer
:
if
compare_utils
.
alphafold_is_installed
():
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
if
consts
.
is_multimer
:
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
else
:
cls
.
am_rigid
=
alphafold
.
model
.
geometry
cls
.
am_atom
=
alphafold
.
model
.
all_atom
else
:
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_rigid
=
alphafold
.
model
.
r3
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_compare
(
self
):
def
test_compare
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment