Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
f545323c
Commit
f545323c
authored
Oct 06, 2023
by
Christina Floristean
Browse files
Added test for backward pass
parent
a3de9cb9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
95 additions
and
21 deletions
+95
-21
tests/data_utils.py
tests/data_utils.py
+12
-7
tests/test_deepspeed_evo_attention.py
tests/test_deepspeed_evo_attention.py
+83
-14
No files found.
tests/data_utils.py
View file @
f545323c
...
@@ -98,12 +98,17 @@ def random_affines_4x4(dim):
...
@@ -98,12 +98,17 @@ def random_affines_4x4(dim):
return
affines
.
reshape
(
*
dim
,
4
,
4
)
return
affines
.
reshape
(
*
dim
,
4
,
4
)
def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9,
                            dtype=torch.float32, requires_grad=False):
    """Generate random q/kv/mask/bias tensors for the attention-kernel tests.

    Args:
        batch_size: Leading batch dimension of every returned tensor.
        n_seq: Number of sequences (second dimension).
        n: Sequence length (number of key/query positions).
        no_heads: Number of attention heads (sizes the pair bias).
        c_hidden: Hidden channel dimension of q and kv.
        inf: Large constant used to turn the 0/1 mask into an additive bias.
        dtype: dtype of the generated tensors.
        requires_grad: When True, tensors are created with gradient tracking
            so backward passes can be compared as well.

    Returns:
        Tuple ``(q, kv, mask, biases)`` with ``biases = [mask_bias, z_bias]``,
        all moved to the default CUDA device.
    """
    # NOTE(review): requires_grad is set before .cuda(), so the returned CUDA
    # tensors are non-leaf; callers appear to clone()+retain_grad() before
    # reading .grad — confirm against the backward-pass test.
    q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype,
                   requires_grad=requires_grad).cuda()
    kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype,
                    requires_grad=requires_grad).cuda()

    # 0/1 keep-mask over key positions, broadcastable across heads/queries.
    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype,
                         requires_grad=requires_grad).cuda()

    # Random pair ("z") bias, shared across sequences.
    z_bias = torch.rand(batch_size, 1, no_heads, n, n, dtype=dtype,
                        requires_grad=requires_grad).cuda()

    # Additive mask bias: 0 where kept, -inf where masked out.
    mask_bias = inf * (mask - 1)
    if requires_grad:
        # mask_bias is derived from mask and therefore not a leaf; detach and
        # re-enable grad so its .grad can be inspected directly in tests.
        mask_bias = mask_bias.detach().clone().requires_grad_()
    biases = [mask_bias, z_bias]

    return q, kv, mask, biases
tests/test_deepspeed_evo_attention.py
View file @
f545323c
...
@@ -17,15 +17,16 @@ Unit tests to compare components of OpenFold run with the DeepSpeed memory-effic
...
@@ -17,15 +17,16 @@ Unit tests to compare components of OpenFold run with the DeepSpeed memory-effic
attention kernel, DS4Sci_EvoformerAttention vs. a stock PyTorch attention implementation.
attention kernel, DS4Sci_EvoformerAttention vs. a stock PyTorch attention implementation.
"""
"""
import
torch
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
import
pickle
import
pickle
import
torch
from
torch.nn
import
functional
as
F
from
openfold.data
import
data_transforms
from
openfold.data
import
data_transforms
from
openfold.model.primitives
import
(
from
openfold.model.primitives
import
(
lecun_normal_init_
,
lecun_normal_init_
,
Attention
,
Attention
)
)
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.tensor_utils
import
tensor_tree_map
...
@@ -39,15 +40,15 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -39,15 +40,15 @@ class TestDeepSpeedKernel(unittest.TestCase):
def
compare_attention_types
(
self
,
use_flash
=
False
):
def
compare_attention_types
(
self
,
use_flash
=
False
):
"""Compare attention with and without using DeepSpeed Evoformer kernel."""
"""Compare attention with and without using DeepSpeed Evoformer kernel."""
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
n_seq
=
consts
.
n_seq
n_seq
=
18
n
=
2
**
12
n
_res
=
20
c_hidden
=
32
c_hidden
=
32
no_heads
=
4
no_heads
=
4
eps
=
2e-2
eps
=
2e-2
q
,
kv
,
mask
,
biases
=
random_attention_inputs
(
batch_size
=
batch_size
,
q
,
kv
,
mask
,
biases
=
random_attention_inputs
(
batch_size
=
batch_size
,
n_seq
=
n_seq
,
n_seq
=
n_seq
,
n
=
n
,
n
=
n
_res
,
no_heads
=
no_heads
,
no_heads
=
no_heads
,
c_hidden
=
c_hidden
)
c_hidden
=
c_hidden
)
...
@@ -61,7 +62,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -61,7 +62,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
if
use_flash
:
if
use_flash
:
biases
=
[
biases
[
0
]]
biases
=
[
biases
[
0
]]
flash_mask
=
mask
.
reshape
(
batch_size
*
n_seq
,
n
)
flash_mask
=
mask
.
reshape
(
batch_size
*
n_seq
,
n
_res
)
real_out
=
a
(
q
,
kv
,
use_flash
=
True
,
flash_mask
=
flash_mask
).
cpu
()
real_out
=
a
(
q
,
kv
,
use_flash
=
True
,
flash_mask
=
flash_mask
).
cpu
()
else
:
else
:
real_out
=
a
(
q
,
kv
,
biases
=
biases
).
cpu
()
real_out
=
a
(
q
,
kv
,
biases
=
biases
).
cpu
()
...
@@ -71,15 +72,79 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -71,15 +72,79 @@ class TestDeepSpeedKernel(unittest.TestCase):
err
=
torch
.
max
(
torch
.
abs
(
ds_out
-
real_out
))
err
=
torch
.
max
(
torch
.
abs
(
ds_out
-
real_out
))
self
.
assertTrue
(
err
<
eps
,
f
'Error:
{
err
}
'
)
self
.
assertTrue
(
err
<
eps
,
f
'Error:
{
err
}
'
)
def
test_ds_kernel_vs_attention
(
self
):
def
test_ds_kernel_vs_attention
_forward
(
self
):
"""Compare regular attention vs. DeepSpeed Evoformer kernel."""
"""Compare regular attention vs. DeepSpeed Evoformer kernel."""
self
.
compare_attention_types
(
use_flash
=
False
)
self
.
compare_attention_types
(
use_flash
=
False
)
@
compare_utils
.
skip_unless_flash_attn_installed
()
@
compare_utils
.
skip_unless_flash_attn_installed
()
def
test_ds_kernel_vs_flash_att
ention
(
self
):
def
test_ds_kernel_vs_flash_att
n_forward
(
self
):
"""Compare Flash Attention vs. DeepSpeed Evoformer kernel."""
"""Compare Flash Attention vs. DeepSpeed Evoformer kernel."""
self
.
compare_attention_types
(
use_flash
=
True
)
self
.
compare_attention_types
(
use_flash
=
True
)
def
test_ds_kernel_vs_attention_backward
(
self
):
"""Compare backward pass for regular attention vs. DeepSpeed Evoformer kernel."""
batch_size
=
consts
.
batch_size
n_seq
=
18
n_res
=
20
c_hidden
=
32
no_heads
=
4
eps
=
consts
.
eps
q
,
kv
,
mask
,
biases
=
random_attention_inputs
(
batch_size
=
batch_size
,
n_seq
=
n_seq
,
n
=
n_res
,
no_heads
=
no_heads
,
c_hidden
=
c_hidden
,
requires_grad
=
True
)
attn
=
Attention
(
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
).
cuda
()
with
torch
.
no_grad
():
lecun_normal_init_
(
attn
.
linear_g
.
weight
)
lecun_normal_init_
(
attn
.
linear_o
.
weight
)
def
clone
(
t
):
t
=
t
.
clone
()
if
t
.
requires_grad
:
t
.
retain_grad
()
return
t
def
init_attn
():
a_clone
=
Attention
(
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
).
cuda
()
a_clone
.
load_state_dict
(
attn
.
state_dict
())
return
a_clone
q_repro
=
clone
(
q
)
kv_repro
=
clone
(
kv
)
biases_repro
=
[
clone
(
b
)
for
b
in
biases
]
a
=
init_attn
()
out_repro
=
a
(
q_repro
,
kv_repro
,
biases
=
biases_repro
,
use_deepspeed_evo_attention
=
True
)
loss_repro
=
torch
.
mean
(
out_repro
)
loss_repro
.
backward
()
q_gt
=
clone
(
q
)
kv_gt
=
clone
(
kv
)
biases_gt
=
[
clone
(
b
)
for
b
in
biases
]
a
=
init_attn
()
out_gt
=
a
(
q_gt
,
kv_gt
,
biases
=
biases_gt
)
loss_gt
=
torch
.
mean
(
out_gt
)
loss_gt
.
backward
()
pairs
=
zip
([
q_repro
,
kv_repro
,
biases_repro
[
0
],
biases_repro
[
1
]],
[
q_gt
,
kv_gt
,
biases_gt
[
0
],
biases_gt
[
1
]])
for
i
,
item
in
enumerate
(
pairs
):
t_repro
,
t_gt
=
item
err
=
torch
.
max
(
torch
.
abs
(
t_repro
.
grad
.
cpu
()
-
t_gt
.
grad
.
cpu
()))
self
.
assertTrue
(
err
<
eps
,
f
'Error item #
{
i
}
:
{
err
}
'
)
def
compare_evoformer
(
self
,
dtype
):
def
compare_evoformer
(
self
,
dtype
):
"""
"""
Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
...
@@ -88,7 +153,9 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -88,7 +153,9 @@ class TestDeepSpeedKernel(unittest.TestCase):
"""
"""
n_res
=
20
n_res
=
20
n_seq
=
18
n_seq
=
18
eps
=
0.5
c_m_shape
=
(
consts
.
c_m
,)
c_z_shape
=
(
consts
.
c_z
,)
eps
=
2e-2
activations
=
{
activations
=
{
"msa"
:
torch
.
rand
(
n_seq
,
n_res
,
consts
.
c_m
,
device
=
'cuda'
,
dtype
=
dtype
),
"msa"
:
torch
.
rand
(
n_seq
,
n_res
,
consts
.
c_m
,
device
=
'cuda'
,
dtype
=
dtype
),
...
@@ -113,8 +180,10 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -113,8 +180,10 @@ class TestDeepSpeedKernel(unittest.TestCase):
inplace_safe
=
False
,
inplace_safe
=
False
,
)
)
out_repro_msa
=
out_repro_msa
.
cpu
()
# In practice, layer norms applied later in the network make any
out_repro_pair
=
out_repro_pair
.
cpu
()
# kernel rounding errors negligible
out_repro_msa
=
F
.
layer_norm
(
out_repro_msa
,
c_m_shape
).
cpu
()
out_repro_pair
=
F
.
layer_norm
(
out_repro_pair
,
c_z_shape
).
cpu
()
out_repro_msa_ds
,
out_repro_pair_ds
=
model
.
evoformer
.
blocks
[
0
](
out_repro_msa_ds
,
out_repro_pair_ds
=
model
.
evoformer
.
blocks
[
0
](
activations
[
"msa"
],
activations
[
"msa"
],
...
@@ -126,8 +195,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -126,8 +195,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
_mask_trans
=
False
,
_mask_trans
=
False
,
inplace_safe
=
False
,
inplace_safe
=
False
,
)
)
out_repro_msa_ds
=
out_repro_msa_ds
.
cpu
()
out_repro_msa_ds
=
F
.
layer_norm
(
out_repro_msa_ds
,
c_m_shape
)
.
cpu
()
out_repro_pair_ds
=
out_repro_pair_ds
.
cpu
()
out_repro_pair_ds
=
F
.
layer_norm
(
out_repro_pair_ds
,
c_z_shape
)
.
cpu
()
err
=
torch
.
mean
(
torch
.
abs
(
out_repro_msa
-
out_repro_msa_ds
))
err
=
torch
.
mean
(
torch
.
abs
(
out_repro_msa
-
out_repro_msa_ds
))
self
.
assertTrue
(
err
<
eps
,
f
'MSA Error:
{
err
}
'
)
self
.
assertTrue
(
err
<
eps
,
f
'MSA Error:
{
err
}
'
)
...
@@ -188,7 +257,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -188,7 +257,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
def
test_compare_model
(
self
):
def
test_compare_model
(
self
):
"""
"""
Run full model with and without using DeepSpeed Evoformer attention kernel
Run full model with and without using DeepSpeed Evoformer attention kernel
and compare output coordinates
and compare output coordinates
.
"""
"""
eps
=
0.5
eps
=
0.5
with
open
(
"tests/test_data/sample_feats.pickle"
,
"rb"
)
as
fp
:
with
open
(
"tests/test_data/sample_feats.pickle"
,
"rb"
)
as
fp
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment