Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
dfdd722c
Commit
dfdd722c
authored
Dec 30, 2021
by
Gustaf Ahdritz
Browse files
Fix bugs in data unit tests
parent
56849437
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
39 deletions
+17
-39
tests/test_data_transforms.py
tests/test_data_transforms.py
+17
-39
No files found.
tests/test_data_transforms.py
View file @
dfdd722c
...
...
@@ -9,7 +9,7 @@ import numpy as np
import
torch
import
unittest
from
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
,
make_all_atom_aatype
,
fix_templates_aatype
,
\
from
openfold.
data.data_transforms
import
make_seq_mask
,
add_distillation_flag
,
make_all_atom_aatype
,
fix_templates_aatype
,
\
correct_msa_restypes
,
squeeze_features
,
randomly_replace_msa_with_unknown
,
MSA_FEATURE_NAMES
,
sample_msa
,
\
crop_extra_msa
,
delete_extra_msa
,
nearest_neighbor_clusters
,
make_msa_mask
,
make_hhblits_profile
,
make_masked_msa
,
\
make_msa_feat
,
crop_templates
,
make_atom14_masks
...
...
@@ -21,17 +21,15 @@ class TestDataTransforms(unittest.TestCase):
seq
=
torch
.
tensor
([
range
(
20
)],
dtype
=
torch
.
int64
).
transpose
(
0
,
1
)
seq_one_hot
=
torch
.
FloatTensor
(
seq
.
shape
[
0
],
20
).
zero_
()
seq_one_hot
.
scatter_
(
1
,
seq
,
1
)
protein_aatype
=
torch
.
tensor
(
seq_one_hot
)
protein_aatype
=
seq_one_hot
.
clone
().
detach
(
)
protein
=
{
'aatype'
:
protein_aatype
}
protein
=
make_seq_mask
(
protein
)
print
(
protein
)
assert
'seq_mask'
in
protein
assert
protein
[
'seq_mask'
].
shape
==
torch
.
Size
((
seq
.
shape
[
0
],
20
))
def
test_add_distillation_flag
(
self
):
protein
=
{}
protein
=
add_distillation_flag
.
__wrapped__
(
protein
,
True
)
print
(
protein
)
assert
'is_distillation'
in
protein
assert
protein
[
'is_distillation'
]
is
True
...
...
@@ -39,10 +37,9 @@ class TestDataTransforms(unittest.TestCase):
seq
=
torch
.
tensor
([
range
(
20
)],
dtype
=
torch
.
int64
).
transpose
(
0
,
1
)
seq_one_hot
=
torch
.
FloatTensor
(
seq
.
shape
[
0
],
20
).
zero_
()
seq_one_hot
.
scatter_
(
1
,
seq
,
1
)
protein_aatype
=
torch
.
tensor
(
seq_one_hot
)
protein_aatype
=
seq_one_hot
.
clone
().
detach
(
)
protein
=
{
'aatype'
:
protein_aatype
}
protein
=
make_all_atom_aatype
(
protein
)
print
(
protein
)
assert
'all_atom_aatype'
in
protein
assert
protein
[
'all_atom_aatype'
].
shape
==
protein
[
'aatype'
].
shape
...
...
@@ -51,26 +48,23 @@ class TestDataTransforms(unittest.TestCase):
template_seq
=
template_seq
.
unsqueeze
(
0
).
transpose
(
0
,
1
)
template_seq_one_hot
=
torch
.
FloatTensor
(
template_seq
.
shape
[
0
],
20
).
zero_
()
template_seq_one_hot
.
scatter_
(
1
,
template_seq
,
1
)
template_aatype
=
torch
.
tensor
(
template_seq_one_hot
).
unsqueeze
(
0
)
template_aatype
=
template_seq_one_hot
.
clone
().
detach
(
).
unsqueeze
(
0
)
protein
=
{
'template_aatype'
:
template_aatype
}
protein
=
fix_templates_aatype
(
protein
)
print
(
protein
)
template_seq_ours
=
torch
.
tensor
([[
0
,
4
,
3
,
6
,
13
,
7
,
8
,
9
,
11
,
10
,
12
,
2
,
14
,
5
,
1
,
15
,
16
,
19
,
17
,
18
]
*
2
])
assert
torch
.
all
(
torch
.
eq
(
protein
[
'template_aatype'
],
template_seq_ours
))
def
test_correct_msa_restypes
(
self
):
with
open
(
"
..
/test_data/features.pkl"
,
'rb'
)
as
file
:
with
open
(
"
tests
/test_data/features.pkl"
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
protein
=
{
'msa'
:
torch
.
tensor
(
features
[
'msa'
],
dtype
=
torch
.
int64
)}
protein
=
correct_msa_restypes
(
protein
)
print
(
protein
)
assert
torch
.
all
(
torch
.
eq
(
torch
.
tensor
(
features
[
'msa'
].
shape
),
torch
.
tensor
(
protein
[
'msa'
].
shape
)))
def
test_squeeze_features
(
self
):
with
open
(
"
..
/test_data/features.pkl"
,
"rb"
)
as
file
:
with
open
(
"
tests
/test_data/features.pkl"
,
"rb"
)
as
file
:
features
=
pickle
.
load
(
file
)
print
(
os
.
path
.
realpath
(
file
.
name
),
'Keys: '
,
features
.
keys
())
features_list
=
[
'domain_name'
,
'msa'
,
'num_alignments'
,
'seq_length'
,
'sequence'
,
...
...
@@ -80,7 +74,6 @@ class TestDataTransforms(unittest.TestCase):
protein
=
{
'aatype'
:
torch
.
tensor
(
features
[
'aatype'
])}
for
k
in
features_list
:
if
k
in
features
:
print
(
k
,
features
[
k
].
dtype
)
if
k
in
[
'domain_name'
,
'sequence'
]:
protein
[
k
]
=
np
.
expand_dims
(
features
[
k
],
-
1
)
else
:
...
...
@@ -88,17 +81,15 @@ class TestDataTransforms(unittest.TestCase):
for
k
in
[
'seq_length'
,
'num_alignments'
]:
if
k
in
protein
:
protein
[
k
]
=
torch
.
tensor
(
protein
[
k
]
).
unsqueeze
(
0
)
protein
[
k
]
=
protein
[
k
].
clone
().
detach
(
).
unsqueeze
(
0
)
protein_squeezed
=
squeeze_features
(
protein
)
print
(
protein
)
for
k
in
features_list
:
if
k
in
protein
:
print
(
k
,
protein_squeezed
[
k
].
shape
,
features
[
k
].
shape
)
assert
protein_squeezed
[
k
].
shape
==
features
[
k
].
shape
def
test_randomly_replace_msa_with_unknown
(
self
):
with
open
(
'
..
/test_data/features.pkl'
,
'rb'
)
as
file
:
with
open
(
'
tests
/test_data/features.pkl'
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
protein
=
{
'msa'
:
torch
.
tensor
(
features
[
'msa'
]),
...
...
@@ -108,12 +99,9 @@ class TestDataTransforms(unittest.TestCase):
protein
=
randomly_replace_msa_with_unknown
.
__wrapped__
(
protein
,
replace_proportion
)
unknown_proportion_in_msa
=
torch
.
bincount
(
protein
[
'msa'
].
flatten
())
/
torch
.
numel
(
protein
[
'msa'
])
unknown_proportion_in_seq
=
torch
.
bincount
(
protein
[
'aatype'
].
flatten
())
/
torch
.
numel
(
protein
[
'aatype'
])
print
(
protein
)
print
(
'Proportion of X in MSA: '
,
unknown_proportion_in_msa
[
x_idx
])
print
(
'Proportion of X in sequence: '
,
unknown_proportion_in_seq
[
x_idx
])
def
test_sample_msa
(
self
):
with
open
(
'
..
/test_data/features.pkl'
,
'rb'
)
as
file
:
with
open
(
'
tests
/test_data/features.pkl'
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
max_seq
=
1000
...
...
@@ -124,17 +112,14 @@ class TestDataTransforms(unittest.TestCase):
protein
[
k
]
=
torch
.
tensor
(
features
[
k
])
protein_processed
=
sample_msa
.
__wrapped__
(
protein
.
copy
(),
max_seq
,
keep_extra
)
print
(
protein
)
for
k
in
MSA_FEATURE_NAMES
:
if
k
in
protein
and
keep_extra
:
assert
protein_processed
[
k
].
shape
[
0
]
==
min
(
protein
[
k
].
shape
[
0
],
max_seq
)
assert
'extra_'
+
k
in
protein_processed
print
(
'extra_'
+
str
(
k
),
protein_processed
[
'extra_'
+
k
].
shape
)
print
(
'msa'
,
protein
[
k
].
shape
[
0
]
-
min
(
protein
[
k
].
shape
[
0
],
max_seq
))
assert
protein_processed
[
'extra_'
+
k
].
shape
[
0
]
==
protein
[
k
].
shape
[
0
]
-
min
(
protein
[
k
].
shape
[
0
],
max_seq
)
def
test_crop_extra_msa
(
self
):
with
open
(
'
..
/test_data/features.pkl'
,
'rb'
)
as
file
:
with
open
(
'
tests
/test_data/features.pkl'
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
max_extra_msa
=
10
...
...
@@ -142,7 +127,6 @@ class TestDataTransforms(unittest.TestCase):
num_seq
=
protein
[
"extra_msa"
].
shape
[
0
]
protein
=
crop_extra_msa
.
__wrapped__
(
protein
,
max_extra_msa
)
print
(
protein
)
for
k
in
MSA_FEATURE_NAMES
:
if
"extra_"
+
k
in
protein
:
assert
protein
[
"extra_"
+
k
].
shape
[
0
]
==
min
(
max_extra_msa
,
num_seq
)
...
...
@@ -153,13 +137,12 @@ class TestDataTransforms(unittest.TestCase):
extra_msa_has_deletion_shape
[
2
]
=
1
protein
[
'extra_deletion_matrix'
]
=
torch
.
rand
(
extra_msa_has_deletion_shape
)
protein
=
delete_extra_msa
(
protein
)
print
(
protein
)
for
k
in
MSA_FEATURE_NAMES
:
assert
'extra_'
+
k
not
in
protein
assert
'extra_msa'
not
in
protein
def
test_nearest_neighbor_clusters
(
self
):
with
gzip
.
open
(
'
..
/test_data/sample_feats.pickle.gz'
,
'rb'
)
as
f
:
with
gzip
.
open
(
'
tests
/test_data/sample_feats.pickle.gz'
,
'rb'
)
as
f
:
features
=
pickle
.
load
(
f
)
protein
=
{
'msa'
:
torch
.
tensor
(
features
[
'true_msa'
][
0
],
dtype
=
torch
.
int64
),
...
...
@@ -167,22 +150,20 @@ class TestDataTransforms(unittest.TestCase):
'extra_msa'
:
torch
.
tensor
(
features
[
'extra_msa'
][
0
],
dtype
=
torch
.
int64
),
'extra_msa_mask'
:
torch
.
tensor
(
features
[
'extra_msa_mask'
][
0
],
dtype
=
torch
.
int64
)}
protein
=
nearest_neighbor_clusters
.
__wrapped__
(
protein
,
0
)
print
(
protein
)
assert
'extra_cluster_assignment'
in
protein
def
test_make_msa_mask
(
self
):
with
open
(
'
..
/test_data/features.pkl'
,
'rb'
)
as
file
:
with
open
(
'
tests
/test_data/features.pkl'
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
msa_mat
=
torch
.
tensor
(
features
[
'msa'
])
protein
=
{
'msa'
:
msa_mat
}
protein
=
make_msa_mask
(
protein
)
print
(
protein
)
assert
'msa_row_mask'
in
protein
assert
protein
[
'msa_row_mask'
].
shape
[
0
]
==
msa_mat
.
shape
[
0
]
def
test_make_hhblits_profile
(
self
):
with
open
(
'
..
/test_data/features.pkl'
,
'rb'
)
as
file
:
with
open
(
'
tests
/test_data/features.pkl'
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
protein
=
{
'msa'
:
torch
.
tensor
(
features
[
'msa'
],
dtype
=
torch
.
int64
)}
...
...
@@ -191,14 +172,13 @@ class TestDataTransforms(unittest.TestCase):
assert
protein
[
'hhblits_profile'
].
shape
==
torch
.
Size
((
protein
[
'msa'
].
shape
[
1
],
22
))
def
test_make_masked_msa
(
self
):
with
open
(
'
..
/test_data/features.pkl'
,
'rb'
)
as
file
:
with
open
(
'
tests
/test_data/features.pkl'
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
protein
=
{
'msa'
:
torch
.
tensor
(
features
[
'msa'
],
dtype
=
torch
.
int64
)}
protein
=
make_hhblits_profile
(
protein
)
masked_msa_config
=
config
.
data
.
common
.
masked_msa
protein
=
make_masked_msa
.
__wrapped__
(
protein
,
masked_msa_config
,
replace_fraction
=
0.15
)
print
(
protein
)
assert
'bert_mask'
in
protein
assert
'true_msa'
in
protein
assert
'msa'
in
protein
...
...
@@ -207,7 +187,7 @@ class TestDataTransforms(unittest.TestCase):
protein
[
'true_msa'
]
*
(
1
-
protein
[
'bert_mask'
]),
protein
[
'msa'
]
*
(
1
-
protein
[
'bert_mask'
])))
def
test_make_msa_feat
(
self
):
with
open
(
'
..
/test_data/features.pkl'
,
'rb'
)
as
file
:
with
open
(
'
tests
/test_data/features.pkl'
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
protein
=
{
'between_segment_residues'
:
torch
.
tensor
(
features
[
'between_segment_residues'
]),
...
...
@@ -221,7 +201,7 @@ class TestDataTransforms(unittest.TestCase):
assert
protein
[
'msa_feat'
].
shape
==
torch
.
Size
((
*
protein
[
'msa'
].
shape
,
25
))
def
test_crop_templates
(
self
):
with
gzip
.
open
(
'
..
/test_data/sample_feats.pickle.gz'
,
'rb'
)
as
f
:
with
gzip
.
open
(
'
tests
/test_data/sample_feats.pickle.gz'
,
'rb'
)
as
f
:
features
=
pickle
.
load
(
f
)
protein
=
{
'template_aatype'
:
torch
.
tensor
(
features
[
'true_msa'
][
0
]),
...
...
@@ -232,12 +212,11 @@ class TestDataTransforms(unittest.TestCase):
assert
protein
[
'template_all_atom_masks'
].
shape
[
0
]
==
max_templates
def
test_make_atom14_masks
(
self
):
with
gzip
.
open
(
'
..
/test_data/sample_feats.pickle.gz'
,
'rb'
)
as
file
:
with
gzip
.
open
(
'
tests
/test_data/sample_feats.pickle.gz'
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
protein
=
{
'aatype'
:
torch
.
tensor
(
features
[
'aatype'
][
0
])}
protein
=
make_atom14_masks
(
protein
)
print
(
protein
)
assert
'atom14_atom_exists'
in
protein
assert
'residx_atom14_to_atom37'
in
protein
assert
'residx_atom37_to_atom14'
in
protein
...
...
@@ -246,4 +225,3 @@ class TestDataTransforms(unittest.TestCase):
if
__name__
==
'__main__'
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment