Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
ab659c39
Commit
ab659c39
authored
Jun 22, 2023
by
Geoffrey Yu
Browse files
update test codes
parent
4a66504c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
tests/test_permutation.py
tests/test_permutation.py
+4
-4
No files found.
tests/test_permutation.py
View file @
ab659c39
...
@@ -45,7 +45,7 @@ class TestPermutation(unittest.TestCase):
...
@@ -45,7 +45,7 @@ class TestPermutation(unittest.TestCase):
def
test_dry_run
(
self
):
def
test_dry_run
(
self
):
n_seq
=
consts
.
n_seq
n_seq
=
consts
.
n_seq
n_templ
=
consts
.
n_templ
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
+
13
n_res
=
consts
.
n_res
+
9
n_extra_seq
=
consts
.
n_extra
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
consts
.
model
,
train
=
True
)
c
=
model_config
(
consts
.
model
,
train
=
True
)
...
@@ -83,12 +83,12 @@ class TestPermutation(unittest.TestCase):
...
@@ -83,12 +83,12 @@ class TestPermutation(unittest.TestCase):
# Modify asym_id, entity_id and sym_id so that it encodes
# Modify asym_id, entity_id and sym_id so that it encodes
# 2 chains
# 2 chains
# #
# #
asym_id
=
[
1
]
*
9
+
[
2
]
*
13
+
[
3
]
*
13
asym_id
=
[
1
]
*
9
+
[
2
]
*
9
+
[
3
]
*
13
batch
[
"asym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"asym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch
[
'entity_id'
]
=
torch
.
tensor
([
1
]
*
9
+
[
2
]
*
26
,
dtype
=
torch
.
float64
)
batch
[
'entity_id'
]
=
torch
.
tensor
([
1
]
*
18
+
[
2
]
*
13
,
dtype
=
torch
.
float64
)
batch
[
"sym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"sym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"num_sym"
]
=
torch
.
tensor
([
1
]
*
9
+
[
2
]
*
26
,
dtype
=
torch
.
int64
)
# currently there are just 2 chains
batch
[
"num_sym"
]
=
torch
.
tensor
([
1
]
*
18
+
[
2
]
*
13
,
dtype
=
torch
.
int64
)
# currently there are just 2 chains
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
add_recycling_dims
=
lambda
t
:
(
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment