Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
5621ac05
"examples/multimodal/utils/chat_processor.py" did not exist on "cab65e1a721d09b7977c2929d0be8e6b02c26ee4"
Commit
5621ac05
authored
Jun 26, 2023
by
Geoffrey Yu
Browse files
update the test input to be A2B3
parent
4666e15e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
17 deletions
+12
-17
tests/test_permutation.py
tests/test_permutation.py
+12
-17
No files found.
tests/test_permutation.py
View file @
5621ac05
...
@@ -21,9 +21,9 @@ import unittest
...
@@ -21,9 +21,9 @@ import unittest
from
openfold.config
import
model_config
from
openfold.config
import
model_config
from
openfold.data
import
data_transforms
from
openfold.data
import
data_transforms
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.utils.loss
import
AlphaFoldMultimerLoss
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
tests.config
import
consts
from
tests.config
import
consts
from
.unifold_permutation
import
multi_chain_perm_align
import
logging
import
logging
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
import
os
import
os
...
@@ -40,12 +40,12 @@ class TestPermutation(unittest.TestCase):
...
@@ -40,12 +40,12 @@ class TestPermutation(unittest.TestCase):
In the test case, use PDB ID 1e4k as the label
In the test case, use PDB ID 1e4k as the label
"""
"""
self
.
test_data_dir
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data"
)
self
.
test_data_dir
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data"
)
self
.
label_ids
=
[
'label_1'
,
'label_2'
,
'label_2'
]
self
.
label_ids
=
[
'label_1'
,
'label_
1'
,
'label_2'
,
'label_
2'
,
'label_2'
]
self
.
asym_id
=
[
1
]
*
9
+
[
2
]
*
9
+
[
3
]
*
13
+
[
4
]
*
13
+
[
5
]
*
13
def
test_dry_run
(
self
):
def
test_dry_run
(
self
):
n_seq
=
consts
.
n_seq
n_seq
=
consts
.
n_seq
n_templ
=
consts
.
n_templ
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
+
9
n_res
=
len
(
self
.
asym_id
)
n_extra_seq
=
consts
.
n_extra
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
consts
.
model
,
train
=
True
)
c
=
model_config
(
consts
.
model
,
train
=
True
)
...
@@ -54,6 +54,7 @@ class TestPermutation(unittest.TestCase):
...
@@ -54,6 +54,7 @@ class TestPermutation(unittest.TestCase):
# deepspeed for this test
# deepspeed for this test
model
=
AlphaFold
(
c
)
model
=
AlphaFold
(
c
)
multimer_loss
=
AlphaFoldMultimerLoss
(
c
)
example_label
=
[
pickle
.
load
(
open
(
os
.
path
.
join
(
self
.
test_data_dir
,
f
"
{
i
}
.pkl"
),
'rb'
))
example_label
=
[
pickle
.
load
(
open
(
os
.
path
.
join
(
self
.
test_data_dir
,
f
"
{
i
}
.pkl"
),
'rb'
))
for
i
in
self
.
label_ids
]
for
i
in
self
.
label_ids
]
batch
=
{}
batch
=
{}
...
@@ -62,8 +63,6 @@ class TestPermutation(unittest.TestCase):
...
@@ -62,8 +63,6 @@ class TestPermutation(unittest.TestCase):
tf
,
c
.
model
.
input_embedder
.
tf_dim
tf
,
c
.
model
.
input_embedder
.
tf_dim
).
float
()
).
float
()
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
print
(
f
"target_feat shape is
{
batch
[
'target_feat'
].
size
()
}
"
)
print
(
f
"batch_dim is
{
batch
[
'target_feat'
].
shape
[:
-
2
]
}
"
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
...
@@ -83,23 +82,19 @@ class TestPermutation(unittest.TestCase):
...
@@ -83,23 +82,19 @@ class TestPermutation(unittest.TestCase):
# Modify asym_id, entity_id and sym_id so that it encodes
# Modify asym_id, entity_id and sym_id so that it encodes
# 2 chains
# 2 chains
# #
# #
asym_id
=
[
1
]
*
9
+
[
2
]
*
9
+
[
3
]
*
13
asym_id
=
self
.
asym_id
batch
[
"asym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"asym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch
[
'entity_id'
]
=
torch
.
tensor
([
1
]
*
18
+
[
2
]
*
1
3
,
dtype
=
torch
.
float64
)
batch
[
'entity_id'
]
=
torch
.
tensor
([
1
]
*
18
+
[
2
]
*
3
9
,
dtype
=
torch
.
float64
)
batch
[
"sym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"sym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"num_sym"
]
=
torch
.
tensor
([
1
]
*
18
+
[
2
]
*
13
,
dtype
=
torch
.
int64
)
# currently there are just 2 chains
#
batch["num_sym"] = torch.tensor([1]*18+[2]*13,dtype=torch.int64) # currently there are just 2 chains
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
add_recycling_dims
=
lambda
t
:
(
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
)
)
print
(
f
"max_recycling_iters is
{
c
.
data
.
common
.
max_recycling_iters
}
"
)
batch
=
tensor_tree_map
(
add_recycling_dims
,
batch
)
input_batch
=
tensor_tree_map
(
add_recycling_dims
,
batch
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
out
=
model
(
input_batch
)
out
=
model
(
batch
)
print
(
"finished running multimer forward"
)
permutated_labels
=
multimer_loss
(
out
,(
batch
,
example_label
))
print
(
f
"out is
{
type
(
out
)
}
and has keys
{
out
.
keys
()
}
"
)
print
(
f
"permuated_labels is
{
type
(
permutated_labels
)
}
and keys are:
\n
{
permutated_labels
.
keys
()
}
"
)
print
(
f
"final_atom_positions is
{
out
[
'final_atom_positions'
].
shape
}
"
)
\ No newline at end of file
print
(
f
"out itpm score is
{
out
[
'iptm_score'
]
}
"
)
multi_chain_perm_align
(
out
,
batch
,
example_label
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment