Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9a07b7f9
Commit
9a07b7f9
authored
Jan 17, 2024
by
Christina Floristean
Browse files
Fix deepspeed test for multimer
parent
9e057b7a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
5 deletions
+17
-5
openfold/model/primitives.py
openfold/model/primitives.py
+4
-2
tests/test_deepspeed_evo_attention.py
tests/test_deepspeed_evo_attention.py
+13
-3
No files found.
openfold/model/primitives.py
View file @
9a07b7f9
...
...
@@ -193,13 +193,15 @@ class Linear(nn.Linear):
)
if
self
.
precision
is
not
None
:
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
bias
=
self
.
bias
.
to
(
dtype
=
self
.
precision
)
if
self
.
bias
is
not
None
else
None
return
nn
.
functional
.
linear
(
input
.
to
(
dtype
=
self
.
precision
),
self
.
weight
.
to
(
dtype
=
self
.
precision
),
self
.
bias
.
to
(
dtype
=
self
.
precision
)
).
to
(
dtype
=
d
)
bias
).
to
(
dtype
=
d
)
if
d
is
torch
.
bfloat16
and
not
deepspeed_is_initialized
:
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
return
nn
.
functional
.
linear
(
input
,
self
.
weight
.
to
(
dtype
=
d
),
self
.
bias
.
to
(
dtype
=
d
))
bias
=
self
.
bias
.
to
(
dtype
=
d
)
if
self
.
bias
is
not
None
else
None
return
nn
.
functional
.
linear
(
input
,
self
.
weight
.
to
(
dtype
=
d
),
bias
)
return
nn
.
functional
.
linear
(
input
,
self
.
weight
,
self
.
bias
)
...
...
tests/test_deepspeed_evo_attention.py
View file @
9a07b7f9
...
...
@@ -236,19 +236,28 @@ class TestDeepSpeedKernel(unittest.TestCase):
n_res
=
20
eps
=
2e-2
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
batch
=
random_template_feats
(
n_templ
,
n_res
)
batch
[
"template_all_atom_masks"
]
=
batch
[
"template_all_atom_mask"
]
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
batch
[
'asym_id'
][
0
]
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
inds
=
np
.
random
.
randint
(
0
,
21
,
(
n_res
,))
batch
[
"target_feat"
]
=
np
.
eye
(
22
)[
inds
]
batch
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
template_feats
=
{
k
:
v
for
k
,
v
in
batch
.
items
()
if
k
.
startswith
(
"template_"
)
}
with
torch
.
no_grad
():
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
.
globals
.
use_deepspeed_evo_attention
=
False
out_repro
=
model
.
embed_templates
(
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()},
template_feats
,
batch
,
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
...
...
@@ -258,7 +267,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
model
.
globals
.
use_deepspeed_evo_attention
=
True
out_repro_ds
=
model
.
embed_templates
(
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()},
template_feats
,
batch
,
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment