Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
7fdb503e
Commit
7fdb503e
authored
Jun 24, 2023
by
Gustaf Ahdritz
Browse files
Fix evoformer test
parent
d5da89c1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
3 deletions
+5
-3
tests/test_evoformer.py
tests/test_evoformer.py
+5
-3
No files found.
tests/test_evoformer.py
View file @
7fdb503e
...
@@ -193,10 +193,10 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -193,10 +193,10 @@ class TestExtraMSAStack(unittest.TestCase):
ckpt
=
False
,
ckpt
=
False
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
).
eval
()
).
eval
()
.
cuda
()
m
=
torch
.
rand
((
batch_size
,
s_t
,
n_res
,
c_m
))
m
=
torch
.
rand
((
batch_size
,
s_t
,
n_res
,
c_m
)
,
device
=
"cuda"
)
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
))
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
)
,
device
=
"cuda"
)
msa_mask
=
torch
.
randint
(
msa_mask
=
torch
.
randint
(
0
,
0
,
2
,
2
,
...
@@ -205,6 +205,7 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -205,6 +205,7 @@ class TestExtraMSAStack(unittest.TestCase):
s_t
,
s_t
,
n_res
,
n_res
,
),
),
device
=
"cuda"
,
)
)
pair_mask
=
torch
.
randint
(
pair_mask
=
torch
.
randint
(
0
,
0
,
...
@@ -214,6 +215,7 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -214,6 +215,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res
,
n_res
,
n_res
,
n_res
,
),
),
device
=
"cuda"
,
)
)
shape_z_before
=
z
.
shape
shape_z_before
=
z
.
shape
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment