Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
5621ac05
Commit
5621ac05
authored
Jun 26, 2023
by
Geoffrey Yu
Browse files
update the test input to be A2B3
parent
4666e15e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
17 deletions
+12
-17
tests/test_permutation.py
tests/test_permutation.py
+12
-17
No files found.
tests/test_permutation.py
View file @
5621ac05
...
@@ -21,9 +21,9 @@ import unittest
...
@@ -21,9 +21,9 @@ import unittest
from
openfold.config
import
model_config
from
openfold.config
import
model_config
from
openfold.data
import
data_transforms
from
openfold.data
import
data_transforms
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.utils.loss
import
AlphaFoldMultimerLoss
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
tests.config
import
consts
from
tests.config
import
consts
from
.unifold_permutation
import
multi_chain_perm_align
import
logging
import
logging
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
import
os
import
os
...
@@ -40,12 +40,12 @@ class TestPermutation(unittest.TestCase):
...
@@ -40,12 +40,12 @@ class TestPermutation(unittest.TestCase):
In the test case, use PDB ID 1e4k as the label
In the test case, use PDB ID 1e4k as the label
"""
"""
self
.
test_data_dir
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data"
)
self
.
test_data_dir
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data"
)
self
.
label_ids
=
[
'label_1'
,
'label_2'
,
'label_2'
]
self
.
label_ids
=
[
'label_1'
,
'label_
1'
,
'label_2'
,
'label_
2'
,
'label_2'
]
self
.
asym_id
=
[
1
]
*
9
+
[
2
]
*
9
+
[
3
]
*
13
+
[
4
]
*
13
+
[
5
]
*
13
def
test_dry_run
(
self
):
def
test_dry_run
(
self
):
n_seq
=
consts
.
n_seq
n_seq
=
consts
.
n_seq
n_templ
=
consts
.
n_templ
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
+
9
n_res
=
len
(
self
.
asym_id
)
n_extra_seq
=
consts
.
n_extra
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
consts
.
model
,
train
=
True
)
c
=
model_config
(
consts
.
model
,
train
=
True
)
...
@@ -54,6 +54,7 @@ class TestPermutation(unittest.TestCase):
...
@@ -54,6 +54,7 @@ class TestPermutation(unittest.TestCase):
# deepspeed for this test
# deepspeed for this test
model
=
AlphaFold
(
c
)
model
=
AlphaFold
(
c
)
multimer_loss
=
AlphaFoldMultimerLoss
(
c
)
example_label
=
[
pickle
.
load
(
open
(
os
.
path
.
join
(
self
.
test_data_dir
,
f
"
{
i
}
.pkl"
),
'rb'
))
example_label
=
[
pickle
.
load
(
open
(
os
.
path
.
join
(
self
.
test_data_dir
,
f
"
{
i
}
.pkl"
),
'rb'
))
for
i
in
self
.
label_ids
]
for
i
in
self
.
label_ids
]
batch
=
{}
batch
=
{}
...
@@ -62,8 +63,6 @@ class TestPermutation(unittest.TestCase):
...
@@ -62,8 +63,6 @@ class TestPermutation(unittest.TestCase):
tf
,
c
.
model
.
input_embedder
.
tf_dim
tf
,
c
.
model
.
input_embedder
.
tf_dim
).
float
()
).
float
()
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
print
(
f
"target_feat shape is
{
batch
[
'target_feat'
].
size
()
}
"
)
print
(
f
"batch_dim is
{
batch
[
'target_feat'
].
shape
[:
-
2
]
}
"
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
...
@@ -83,23 +82,19 @@ class TestPermutation(unittest.TestCase):
...
@@ -83,23 +82,19 @@ class TestPermutation(unittest.TestCase):
# Modify asym_id, entity_id and sym_id so that it encodes
# Modify asym_id, entity_id and sym_id so that it encodes
# 2 chains
# 2 chains
# #
# #
asym_id
=
[
1
]
*
9
+
[
2
]
*
9
+
[
3
]
*
13
asym_id
=
self
.
asym_id
batch
[
"asym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"asym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch
[
'entity_id'
]
=
torch
.
tensor
([
1
]
*
18
+
[
2
]
*
1
3
,
dtype
=
torch
.
float64
)
batch
[
'entity_id'
]
=
torch
.
tensor
([
1
]
*
18
+
[
2
]
*
3
9
,
dtype
=
torch
.
float64
)
batch
[
"sym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"sym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"num_sym"
]
=
torch
.
tensor
([
1
]
*
18
+
[
2
]
*
13
,
dtype
=
torch
.
int64
)
# currently there are just 2 chains
#
batch["num_sym"] = torch.tensor([1]*18+[2]*13,dtype=torch.int64) # currently there are just 2 chains
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
add_recycling_dims
=
lambda
t
:
(
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
)
)
print
(
f
"max_recycling_iters is
{
c
.
data
.
common
.
max_recycling_iters
}
"
)
batch
=
tensor_tree_map
(
add_recycling_dims
,
batch
)
input_batch
=
tensor_tree_map
(
add_recycling_dims
,
batch
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
out
=
model
(
input_batch
)
out
=
model
(
batch
)
print
(
"finished running multimer forward"
)
permutated_labels
=
multimer_loss
(
out
,(
batch
,
example_label
))
print
(
f
"out is
{
type
(
out
)
}
and has keys
{
out
.
keys
()
}
"
)
print
(
f
"permuated_labels is
{
type
(
permutated_labels
)
}
and keys are:
\n
{
permutated_labels
.
keys
()
}
"
)
print
(
f
"final_atom_positions is
{
out
[
'final_atom_positions'
].
shape
}
"
)
\ No newline at end of file
print
(
f
"out itpm score is
{
out
[
'iptm_score'
]
}
"
)
multi_chain_perm_align
(
out
,
batch
,
example_label
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment