Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
f1563999
Commit
f1563999
authored
Nov 13, 2023
by
Christina Floristean
Browse files
Minor test fix
parent
7fb12cf5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
4 deletions
+3
-4
tests/test_deepspeed_evo_attention.py
tests/test_deepspeed_evo_attention.py
+3
-4
No files found.
tests/test_deepspeed_evo_attention.py
View file @
f1563999
...
@@ -145,7 +145,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -145,7 +145,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
err
=
torch
.
max
(
torch
.
abs
(
t_repro
.
grad
.
cpu
()
-
t_gt
.
grad
.
cpu
()))
err
=
torch
.
max
(
torch
.
abs
(
t_repro
.
grad
.
cpu
()
-
t_gt
.
grad
.
cpu
()))
self
.
assertTrue
(
err
<
eps
,
f
'Error item #
{
i
}
:
{
err
}
'
)
self
.
assertTrue
(
err
<
eps
,
f
'Error item #
{
i
}
:
{
err
}
'
)
def
compare_evoformer
(
self
,
dtype
):
def
compare_evoformer
(
self
,
dtype
,
eps
):
"""
"""
Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
Set dtype to confirm the kernel can be used during both training (BF16) and inference (FP32),
Set dtype to confirm the kernel can be used during both training (BF16) and inference (FP32),
...
@@ -155,7 +155,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -155,7 +155,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
n_seq
=
18
n_seq
=
18
c_m_shape
=
(
consts
.
c_m
,)
c_m_shape
=
(
consts
.
c_m
,)
c_z_shape
=
(
consts
.
c_z
,)
c_z_shape
=
(
consts
.
c_z
,)
eps
=
5e-2
activations
=
{
activations
=
{
"msa"
:
torch
.
rand
(
n_seq
,
n_res
,
consts
.
c_m
,
device
=
'cuda'
,
dtype
=
dtype
),
"msa"
:
torch
.
rand
(
n_seq
,
n_res
,
consts
.
c_m
,
device
=
'cuda'
,
dtype
=
dtype
),
...
@@ -206,11 +205,11 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -206,11 +205,11 @@ class TestDeepSpeedKernel(unittest.TestCase):
def
test_compare_evoformer_bf16
(
self
):
def
test_compare_evoformer_bf16
(
self
):
"""Run evoformer comparison test with BF16 precision."""
"""Run evoformer comparison test with BF16 precision."""
self
.
compare_evoformer
(
torch
.
bfloat16
)
self
.
compare_evoformer
(
dtype
=
torch
.
bfloat16
,
eps
=
4e-2
)
def
test_compare_evoformer_fp32
(
self
):
def
test_compare_evoformer_fp32
(
self
):
"""Run evoformer comparison test with FP32 precision."""
"""Run evoformer comparison test with FP32 precision."""
self
.
compare_evoformer
(
torch
.
float32
)
self
.
compare_evoformer
(
dtype
=
torch
.
float32
,
eps
=
2e-2
)
def
test_compare_template_stack
(
self
):
def
test_compare_template_stack
(
self
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment