Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
0aa69474
Commit
0aa69474
authored
Feb 07, 2024
by
Jennifer
Browse files
fix deepspeed_evo_attention to work in both monomer and multimer settings.
parent
204ed191
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
4 deletions
+2
-4
tests/test_deepspeed_evo_attention.py
tests/test_deepspeed_evo_attention.py
+2
-4
No files found.
tests/test_deepspeed_evo_attention.py
View file @
0aa69474
...
@@ -244,9 +244,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -244,9 +244,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
inds
=
np
.
random
.
randint
(
0
,
21
,
(
n_res
,))
batch
[
"target_feat"
]
=
np
.
eye
(
22
)[
inds
]
batch
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
batch
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
template_feats
=
{
template_feats
=
{
k
:
v
for
k
,
v
in
batch
.
items
()
if
k
.
startswith
(
"template_"
)
k
:
v
for
k
,
v
in
batch
.
items
()
if
k
.
startswith
(
"template_"
)
...
@@ -309,7 +306,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -309,7 +306,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
batch
[
"residx_atom37_to_atom14"
]
=
batch
[
batch
[
"residx_atom37_to_atom14"
]
=
batch
[
"residx_atom37_to_atom14"
"residx_atom37_to_atom14"
].
long
()
].
long
()
batch
[
"target_feat"
]
=
torch
.
nn
.
functional
.
one_hot
(
batch
[
"aatype"
],
21
).
to
(
torch
.
float32
)
# print(batch["target_feat"].shape)
batch
[
"target_feat"
]
=
torch
.
nn
.
functional
.
one_hot
(
batch
[
"aatype"
],
consts
.
msa_logits
-
1
).
to
(
torch
.
float32
)
batch
[
"template_all_atom_mask"
]
=
batch
[
"template_all_atom_masks"
]
batch
[
"template_all_atom_mask"
]
=
batch
[
"template_all_atom_masks"
]
batch
.
update
(
batch
.
update
(
data_transforms
.
atom37_to_torsion_angles
(
"template_"
)(
batch
)
data_transforms
.
atom37_to_torsion_angles
(
"template_"
)(
batch
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment