Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
60d0b15a
Unverified
Commit
60d0b15a
authored
Sep 21, 2023
by
jnwei
Committed by
GitHub
Sep 21, 2023
Browse files
Merge pull request #350 from aqlaboratory/fix-msastack-test-error
Fixes cuda/float wrapper error in unit tests
parents
2134cc09
73ff40b6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
14 deletions
+14
-14
scripts/install_third_party_dependencies.sh
scripts/install_third_party_dependencies.sh
+1
-1
tests/test_evoformer.py
tests/test_evoformer.py
+2
-2
tests/test_model.py
tests/test_model.py
+11
-11
No files found.
scripts/install_third_party_dependencies.sh
View file @
60d0b15a
...
@@ -46,4 +46,4 @@ echo "Downloading AlphaFold parameters..."
...
@@ -46,4 +46,4 @@ echo "Downloading AlphaFold parameters..."
bash scripts/download_alphafold_params.sh openfold/resources
bash scripts/download_alphafold_params.sh openfold/resources
# Decompress test data
# Decompress test data
gunzip
tests/test_data/sample_feats.pickle.gz
gunzip
-c
tests/test_data/sample_feats.pickle.gz
>
tests/test_data/sample_feats.pickle
tests/test_evoformer.py
View file @
60d0b15a
...
@@ -206,7 +206,7 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -206,7 +206,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res
,
n_res
,
),
),
device
=
"cuda"
,
device
=
"cuda"
,
)
)
.
float
()
pair_mask
=
torch
.
randint
(
pair_mask
=
torch
.
randint
(
0
,
0
,
2
,
2
,
...
@@ -216,7 +216,7 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -216,7 +216,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res
,
n_res
,
),
),
device
=
"cuda"
,
device
=
"cuda"
,
)
)
.
float
()
shape_z_before
=
z
.
shape
shape_z_before
=
z
.
shape
...
...
tests/test_model.py
View file @
60d0b15a
...
@@ -47,27 +47,27 @@ class TestModel(unittest.TestCase):
...
@@ -47,27 +47,27 @@ class TestModel(unittest.TestCase):
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
# deepspeed for this test
# deepspeed for this test
model
=
AlphaFold
(
c
)
model
=
AlphaFold
(
c
)
.
cuda
()
model
.
eval
()
model
.
eval
()
batch
=
{}
batch
=
{}
tf
=
torch
.
randint
(
c
.
model
.
input_embedder
.
tf_dim
-
1
,
size
=
(
n_res
,))
tf
=
torch
.
randint
(
c
.
model
.
input_embedder
.
tf_dim
-
1
,
size
=
(
n_res
,))
.
cuda
()
batch
[
"target_feat"
]
=
nn
.
functional
.
one_hot
(
batch
[
"target_feat"
]
=
nn
.
functional
.
one_hot
(
tf
,
c
.
model
.
input_embedder
.
tf_dim
tf
,
c
.
model
.
input_embedder
.
tf_dim
).
float
()
).
float
()
.
cuda
()
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
.
cuda
()
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
.
cuda
()
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
.
cuda
()
t_feats
=
random_template_feats
(
n_templ
,
n_res
)
t_feats
=
random_template_feats
(
n_templ
,
n_res
)
batch
.
update
({
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
t_feats
.
items
()})
batch
.
update
({
k
:
torch
.
tensor
(
v
)
.
cuda
()
for
k
,
v
in
t_feats
.
items
()})
extra_feats
=
random_extra_msa_feats
(
n_extra_seq
,
n_res
)
extra_feats
=
random_extra_msa_feats
(
n_extra_seq
,
n_res
)
batch
.
update
({
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
extra_feats
.
items
()})
batch
.
update
({
k
:
torch
.
tensor
(
v
)
.
cuda
()
for
k
,
v
in
extra_feats
.
items
()})
batch
[
"msa_mask"
]
=
torch
.
randint
(
batch
[
"msa_mask"
]
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
n_seq
,
n_res
)
low
=
0
,
high
=
2
,
size
=
(
n_seq
,
n_res
)
).
float
()
).
float
()
.
cuda
()
batch
[
"seq_mask"
]
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
n_res
,)).
float
()
batch
[
"seq_mask"
]
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
n_res
,)).
float
()
.
cuda
()
batch
.
update
(
data_transforms
.
make_atom14_masks
(
batch
))
batch
.
update
(
data_transforms
.
make_atom14_masks
(
batch
))
batch
[
"no_recycling_iters"
]
=
torch
.
tensor
(
2.
)
batch
[
"no_recycling_iters"
]
=
torch
.
tensor
(
2.
)
.
cuda
()
add_recycling_dims
=
lambda
t
:
(
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment