Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
5621ac05
Commit
5621ac05
authored
Jun 26, 2023
by
Geoffrey Yu
Browse files
update the test input to be A2B3
parent
4666e15e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
17 deletions
+12
-17
tests/test_permutation.py
tests/test_permutation.py
+12
-17
No files found.
tests/test_permutation.py
View file @
5621ac05
...
...
@@ -21,9 +21,9 @@ import unittest
from
openfold.config
import
model_config
from
openfold.data
import
data_transforms
from
openfold.model.model
import
AlphaFold
from
openfold.utils.loss
import
AlphaFoldMultimerLoss
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
tests.config
import
consts
from
.unifold_permutation
import
multi_chain_perm_align
import
logging
logger
=
logging
.
getLogger
(
__name__
)
import
os
...
...
@@ -40,12 +40,12 @@ class TestPermutation(unittest.TestCase):
In the test case, use PDB ID 1e4k as the label
"""
self
.
test_data_dir
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data"
)
self
.
label_ids
=
[
'label_1'
,
'label_2'
,
'label_2'
]
self
.
label_ids
=
[
'label_1'
,
'label_
1'
,
'label_2'
,
'label_
2'
,
'label_2'
]
self
.
asym_id
=
[
1
]
*
9
+
[
2
]
*
9
+
[
3
]
*
13
+
[
4
]
*
13
+
[
5
]
*
13
def
test_dry_run
(
self
):
n_seq
=
consts
.
n_seq
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
+
9
n_res
=
len
(
self
.
asym_id
)
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
consts
.
model
,
train
=
True
)
...
...
@@ -54,6 +54,7 @@ class TestPermutation(unittest.TestCase):
# deepspeed for this test
model
=
AlphaFold
(
c
)
multimer_loss
=
AlphaFoldMultimerLoss
(
c
)
example_label
=
[
pickle
.
load
(
open
(
os
.
path
.
join
(
self
.
test_data_dir
,
f
"
{
i
}
.pkl"
),
'rb'
))
for
i
in
self
.
label_ids
]
batch
=
{}
...
...
@@ -62,8 +63,6 @@ class TestPermutation(unittest.TestCase):
tf
,
c
.
model
.
input_embedder
.
tf_dim
).
float
()
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
print
(
f
"target_feat shape is
{
batch
[
'target_feat'
].
size
()
}
"
)
print
(
f
"batch_dim is
{
batch
[
'target_feat'
].
shape
[:
-
2
]
}
"
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
...
...
@@ -83,23 +82,19 @@ class TestPermutation(unittest.TestCase):
# Modify asym_id, entity_id and sym_id so that it encodes
# 2 chains
# #
asym_id
=
[
1
]
*
9
+
[
2
]
*
9
+
[
3
]
*
13
asym_id
=
self
.
asym_id
batch
[
"asym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch
[
'entity_id'
]
=
torch
.
tensor
([
1
]
*
18
+
[
2
]
*
1
3
,
dtype
=
torch
.
float64
)
batch
[
'entity_id'
]
=
torch
.
tensor
([
1
]
*
18
+
[
2
]
*
3
9
,
dtype
=
torch
.
float64
)
batch
[
"sym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"num_sym"
]
=
torch
.
tensor
([
1
]
*
18
+
[
2
]
*
13
,
dtype
=
torch
.
int64
)
# currently there are just 2 chains
#
batch["num_sym"] = torch.tensor([1]*18+[2]*13,dtype=torch.int64) # currently there are just 2 chains
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
)
print
(
f
"max_recycling_iters is
{
c
.
data
.
common
.
max_recycling_iters
}
"
)
input_batch
=
tensor_tree_map
(
add_recycling_dims
,
batch
)
batch
=
tensor_tree_map
(
add_recycling_dims
,
batch
)
with
torch
.
no_grad
():
out
=
model
(
input_batch
)
print
(
"finished running multimer forward"
)
print
(
f
"out is
{
type
(
out
)
}
and has keys
{
out
.
keys
()
}
"
)
print
(
f
"final_atom_positions is
{
out
[
'final_atom_positions'
].
shape
}
"
)
print
(
f
"out itpm score is
{
out
[
'iptm_score'
]
}
"
)
multi_chain_perm_align
(
out
,
batch
,
example_label
)
\ No newline at end of file
out
=
model
(
batch
)
permutated_labels
=
multimer_loss
(
out
,(
batch
,
example_label
))
print
(
f
"permuated_labels is
{
type
(
permutated_labels
)
}
and keys are:
\n
{
permutated_labels
.
keys
()
}
"
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment