Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
8baae516
Commit
8baae516
authored
Sep 20, 2023
by
Jennifer Wei
Browse files
Adds cuda wrapper to pytorch vectors to fix TestModel.test_dry_run
parent
48668ca3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
11 deletions
+11
-11
tests/test_model.py
tests/test_model.py
+11
-11
No files found.
tests/test_model.py
View file @
8baae516
Diff of tests/test_model.py (reconstructed from the garbled side-by-side view;
11 additions, 11 deletions — each changed line gains a `.cuda()` call so the
dry-run test's inputs live on the same device as the model):

@@ -47,27 +47,27 @@ class TestModel(unittest.TestCase):
         c.model.evoformer_stack.blocks_per_ckpt = None  # don't want to set up
         # deepspeed for this test
-        model = AlphaFold(c)
+        model = AlphaFold(c).cuda()
         model.eval()
         batch = {}
-        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
+        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,)).cuda()
         batch["target_feat"] = nn.functional.one_hot(
             tf, c.model.input_embedder.tf_dim
-        ).float()
+        ).float().cuda()
-        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
+        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1).cuda()
-        batch["residue_index"] = torch.arange(n_res)
+        batch["residue_index"] = torch.arange(n_res).cuda()
-        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
+        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)).cuda()
         t_feats = random_template_feats(n_templ, n_res)
-        batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
+        batch.update({k: torch.tensor(v).cuda() for k, v in t_feats.items()})
         extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
-        batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
+        batch.update({k: torch.tensor(v).cuda() for k, v in extra_feats.items()})
         batch["msa_mask"] = torch.randint(
             low=0, high=2, size=(n_seq, n_res)
-        ).float()
+        ).float().cuda()
-        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
+        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float().cuda()
         batch.update(data_transforms.make_atom14_masks(batch))
-        batch["no_recycling_iters"] = torch.tensor(2.)
+        batch["no_recycling_iters"] = torch.tensor(2.).cuda()
         add_recycling_dims = lambda t: (
             t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)

(NOTE: leading indentation of the reconstructed lines is assumed to be the
8-space method-body level implied by the `class TestModel(unittest.TestCase)`
hunk header — confirm against the original file.)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment