Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
a3de9cb9
Commit
a3de9cb9
authored
Oct 05, 2023
by
Christina Floristean
Browse files
Added kernel to template pair stack and updated tests
parent
a0985761
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
142 additions
and
52 deletions
+142
-52
openfold/model/model.py
openfold/model/model.py
+1
-0
openfold/model/template.py
openfold/model/template.py
+9
-3
tests/compare_utils.py
tests/compare_utils.py
+12
-1
tests/data_utils.py
tests/data_utils.py
+12
-0
tests/test_deepspeed_evo_attention.py
tests/test_deepspeed_evo_attention.py
+90
-27
tests/test_primitives.py
tests/test_primitives.py
+18
-21
No files found.
openfold/model/model.py
View file @
a3de9cb9
...
...
@@ -169,6 +169,7 @@ class AlphaFold(nn.Module):
t_pair
,
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_deepspeed_evo_attention
=
self
.
globals
.
use_deepspeed_evo_attention
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
...
...
openfold/model/template.py
View file @
a3de9cb9
...
...
@@ -20,7 +20,7 @@ from typing import Optional, List
import
torch
import
torch.nn
as
nn
from
openfold.model.primitives
import
Linear
,
LayerNorm
,
Attention
from
openfold.model.primitives
import
LayerNorm
,
Attention
from
openfold.model.dropout
import
(
DropoutRowwise
,
DropoutColumnwise
,
...
...
@@ -46,7 +46,6 @@ from openfold.utils.feats import (
from
openfold.utils.tensor_utils
import
(
add
,
permute_final_dims
,
flatten_final_dims
,
tensor_tree_map
,
)
...
...
@@ -201,6 +200,7 @@ class TemplatePairStackBlock(nn.Module):
z
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_deepspeed_evo_attention
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
...
...
@@ -226,6 +226,7 @@ class TemplatePairStackBlock(nn.Module):
single
,
chunk_size
=
_attn_chunk_size
,
mask
=
single_mask
,
use_deepspeed_evo_attention
=
use_deepspeed_evo_attention
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
...
...
@@ -239,6 +240,7 @@ class TemplatePairStackBlock(nn.Module):
single
,
chunk_size
=
_attn_chunk_size
,
mask
=
single_mask
,
use_deepspeed_evo_attention
=
use_deepspeed_evo_attention
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
...
...
@@ -355,6 +357,7 @@ class TemplatePairStack(nn.Module):
t
:
torch
.
tensor
,
mask
:
torch
.
tensor
,
chunk_size
:
int
,
use_deepspeed_evo_attention
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
...
...
@@ -378,6 +381,7 @@ class TemplatePairStack(nn.Module):
b
,
mask
=
mask
,
chunk_size
=
chunk_size
,
use_deepspeed_evo_attention
=
use_deepspeed_evo_attention
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
...
...
@@ -468,6 +472,7 @@ def embed_templates_offload(
t
.
unsqueeze
(
templ_dim
),
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
model
.
globals
.
chunk_size
,
use_deepspeed_evo_attention
=
model
.
globals
.
use_deepspeed_evo_attention
,
use_lma
=
model
.
globals
.
use_lma
,
_mask_trans
=
model
.
config
.
_mask_trans
,
)
...
...
@@ -585,6 +590,7 @@ def embed_templates_average(
t
,
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
model
.
globals
.
chunk_size
,
use_deepspeed_evo_attention
=
model
.
globals
.
use_deepspeed_evo_attention
,
use_lma
=
model
.
globals
.
use_lma
,
_mask_trans
=
model
.
config
.
_mask_trans
,
)
...
...
tests/compare_utils.py
View file @
a3de9cb9
...
...
@@ -10,7 +10,6 @@ import numpy as np
from
openfold.config
import
model_config
from
openfold.model.model
import
AlphaFold
from
openfold.utils.import_weights
import
import_jax_weights_
from
tests.config
import
consts
# Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also
...
...
@@ -19,6 +18,18 @@ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os
.
environ
[
"JAX_PLATFORM_NAME"
]
=
"gpu"
def skip_unless_ds4s_installed():
    """Build a ``unittest`` skip decorator gated on the DeepSpeed4Science ops.

    Returns:
        The decorator produced by ``unittest.skipUnless``: the decorated test
        runs only when both the ``deepspeed`` package and its
        ``deepspeed.ops.deepspeed4science`` submodule can be located.
    """
    ds4s_is_installed = False
    # Only probe the submodule once the top-level package is known to exist.
    if importlib.util.find_spec("deepspeed") is not None:
        try:
            # Looking up a dotted name imports its parent packages, so an
            # installation that cannot be imported raises here instead of
            # returning None.
            ds4s_is_installed = (
                importlib.util.find_spec("deepspeed.ops.deepspeed4science")
                is not None
            )
        except ImportError:
            # Treat an unimportable DeepSpeed as "not installed" so the tests
            # are skipped rather than failing while the decorator is built.
            ds4s_is_installed = False
    return unittest.skipUnless(ds4s_is_installed, "Requires DeepSpeed with version ≥ 0.10.4")
def skip_unless_flash_attn_installed():
    """Build a ``unittest`` skip decorator gated on the ``flash_attn`` package.

    Returns:
        The decorator produced by ``unittest.skipUnless``: the decorated test
        runs only when a module spec for ``flash_attn`` can be found.
    """
    fa_is_installed = importlib.util.find_spec("flash_attn") is not None
    return unittest.skipUnless(fa_is_installed, "Requires Flash Attention")
def alphafold_is_installed():
    """Return ``True`` when a module spec for ``alphafold`` can be found."""
    return importlib.util.find_spec("alphafold") is not None
...
...
tests/data_utils.py
View file @
a3de9cb9
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
numpy
as
np
from
scipy.spatial.transform
import
Rotation
...
...
@@ -95,3 +96,14 @@ def random_affines_4x4(dim):
affines
[:,
3
,
3
]
=
1
return
affines
.
reshape
(
*
dim
,
4
,
4
)
def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9, dtype=torch.float32, device="cuda"):
    """Generate random query/key-value tensors, a mask and attention biases.

    Args:
        batch_size: Size of the leading batch dimension.
        n_seq: Size of the second dimension of ``q``, ``kv`` and ``mask``.
        n: Size of the dimension that the mask and biases cover.
        no_heads: Size of the head dimension of the second bias.
        c_hidden: Size of the last dimension of ``q`` and ``kv``.
        inf: Magnitude of the negative bias applied where the mask is 0.
        dtype: dtype of every returned tensor.
        device: Device the returned tensors are moved to. Defaults to
            ``"cuda"``, which matches the previous hard-coded behaviour.

    Returns:
        A tuple ``(q, kv, mask, biases)`` where ``q`` and ``kv`` have shape
        ``[batch_size, n_seq, n, c_hidden]``, ``mask`` has shape
        ``[batch_size, n_seq, 1, 1, n]`` with entries in {0, 1}, and
        ``biases`` is a two-element list: ``inf * (mask - 1)`` and a random
        tensor of shape ``[batch_size, 1, no_heads, n, n]``.
    """
    # Tensors are sampled first and moved afterwards, in the same order as
    # before, so a fixed seed yields the same draws.
    q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype).to(device)
    kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype).to(device)

    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype).to(device)

    # The mask bias is 0 where mask == 1 and -inf where mask == 0.
    biases = [inf * (mask - 1), torch.rand(batch_size, 1, no_heads, n, n)]
    biases = [b.to(dtype=dtype, device=device) for b in biases]

    return q, kv, mask, biases
tests/test_deepspeed_evo_attention.py
View file @
a3de9cb9
...
...
@@ -22,45 +22,63 @@ import unittest
import
numpy
as
np
import
pickle
from
openfold.data
import
data_transforms
from
openfold.model.primitives
import
(
_attention
,
_deepspeed_evo_attn
lecun_normal_init_
,
Attention
,
)
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
tests.config
import
consts
import
tests.compare_utils
as
compare_utils
from
openfold.data
import
data_transforms
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
tests.data_utils
import
random_template_feats
,
random_attention_inputs
@
compare_utils
.
skip_unless_ds4s_installed
()
class
TestDeepSpeedKernel
(
unittest
.
TestCase
):
def
test_ds_kernel_vs_attention
(
self
):
def
compare_attention_types
(
self
,
use_flash
=
False
):
"""Compare attention with and without using DeepSpeed Evoformer kernel."""
batch_size
=
consts
.
batch_size
c_hidden
=
32
n_seq
=
consts
.
n_seq
n
=
2
**
12
n_seq
=
1
2
c_hidden
=
3
2
no_heads
=
4
dtype
=
torch
.
bfloat16
eps
=
2e-2
q
=
torch
.
rand
(
batch_size
,
n_seq
,
n
,
no_heads
,
c_hidden
,
dtype
=
dtype
).
cuda
()
k
=
torch
.
rand
(
batch_size
,
n_seq
,
n
,
no_heads
,
c_hidden
,
dtype
=
dtype
).
cuda
()
v
=
torch
.
rand
(
batch_size
,
n_seq
,
n
,
no_heads
,
c_hidden
,
dtype
=
dtype
).
cuda
()
q
,
kv
,
mask
,
biases
=
random_attention_inputs
(
batch_size
=
batch_size
,
n_seq
=
n_seq
,
n
=
n
,
no_heads
=
no_heads
,
c_hidden
=
c_hidden
)
bias
=
[
torch
.
rand
(
batch_size
,
n_seq
,
1
,
1
,
n
),
torch
.
rand
(
batch_size
,
1
,
no_heads
,
n
,
n
)]
bias
=
[
b
.
to
(
dtype
=
dtype
).
cuda
()
for
b
in
bias
]
a
=
Attention
(
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
).
cuda
()
with
torch
.
no_grad
():
l
=
_deepspeed_evo_attn
(
q
,
k
,
v
,
biases
=
bias
).
cpu
()
lecun_normal_init_
(
a
.
linear_g
.
weight
)
lecun_normal_init_
(
a
.
linear_o
.
weight
)
if
use_flash
:
biases
=
[
biases
[
0
]]
flash_mask
=
mask
.
reshape
(
batch_size
*
n_seq
,
n
)
real_out
=
a
(
q
,
kv
,
use_flash
=
True
,
flash_mask
=
flash_mask
).
cpu
()
else
:
real_out
=
a
(
q
,
kv
,
biases
=
biases
).
cpu
()
q
=
q
.
transpose
(
-
2
,
-
3
)
k
=
k
.
transpose
(
-
2
,
-
3
)
v
=
v
.
transpose
(
-
2
,
-
3
)
real
=
_attention
(
q
,
k
,
v
,
biases
=
bias
)
real
=
real
.
transpose
(
-
2
,
-
3
).
cpu
()
ds_out
=
a
(
q
,
kv
,
biases
=
biases
,
use_deepspeed_evo_attention
=
True
).
cpu
()
err
=
torch
.
max
(
torch
.
abs
(
l
-
real
))
self
.
assertTrue
(
err
<
consts
.
eps
,
f
'Error:
{
err
}
'
)
err
=
torch
.
max
(
torch
.
abs
(
ds_out
-
real_out
))
self
.
assertTrue
(
err
<
eps
,
f
'Error:
{
err
}
'
)
def test_ds_kernel_vs_attention(self):
    """Compare regular attention vs. DeepSpeed Evoformer kernel."""
    # Delegates to the shared comparison routine with Flash Attention off.
    self.compare_attention_types(use_flash=False)
@compare_utils.skip_unless_flash_attn_installed()
def test_ds_kernel_vs_flash_attention(self):
    """Compare Flash Attention vs. DeepSpeed Evoformer kernel."""
    # Delegates to the shared comparison routine with Flash Attention on.
    self.compare_attention_types(use_flash=True)
def
compare_evoformer
(
self
,
dtype
):
"""
...
...
@@ -70,7 +88,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
"""
n_res
=
20
n_seq
=
18
eps
=
2e-2
eps
=
0.5
activations
=
{
"msa"
:
torch
.
rand
(
n_seq
,
n_res
,
consts
.
c_m
,
device
=
'cuda'
,
dtype
=
dtype
),
...
...
@@ -111,8 +129,11 @@ class TestDeepSpeedKernel(unittest.TestCase):
out_repro_msa_ds
=
out_repro_msa_ds
.
cpu
()
out_repro_pair_ds
=
out_repro_pair_ds
.
cpu
()
self
.
assertTrue
(
torch
.
allclose
(
torch
.
abs
(
out_repro_msa
),
torch
.
abs
(
out_repro_msa_ds
),
atol
=
eps
))
self
.
assertTrue
(
torch
.
allclose
(
torch
.
abs
(
out_repro_pair
),
torch
.
abs
(
out_repro_pair_ds
),
atol
=
eps
))
err
=
torch
.
mean
(
torch
.
abs
(
out_repro_msa
-
out_repro_msa_ds
))
self
.
assertTrue
(
err
<
eps
,
f
'MSA Error:
{
err
}
'
)
err
=
torch
.
mean
(
torch
.
abs
(
out_repro_pair
-
out_repro_pair_ds
))
self
.
assertTrue
(
err
<
eps
,
f
'Pair Error
{
err
}
'
)
def
test_compare_evoformer_bf16
(
self
):
"""Run evoformer comparison test with BF16 precision."""
...
...
@@ -122,12 +143,54 @@ class TestDeepSpeedKernel(unittest.TestCase):
"""Run evoformer comparison test with FP32 precision."""
self
.
compare_evoformer
(
torch
.
float32
)
def test_compare_template_stack(self):
    """
    Compare Template Stack output with and without using DeepSpeed Evoformer attention kernel.
    Kernel can be used for Triangle Attention in the Template Pair Stack.
    """
    n_templ = consts.n_templ
    n_res = 20
    eps = 2e-2

    # Random pair activations, template features and a random 0/1 pair mask.
    pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
    batch = random_template_feats(n_templ, n_res)
    batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
    pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)

    # One-hot target features: index range [0, 21) selected from a 22x22 identity.
    inds = np.random.randint(0, 21, (n_res,))
    batch["target_feat"] = np.eye(22)[inds]

    def embed(model, use_kernel):
        """Run ``embed_templates`` with the kernel flag set to ``use_kernel``."""
        model.globals.use_deepspeed_evo_attention = use_kernel
        out = model.embed_templates(
            {k: torch.as_tensor(v).cuda() for k, v in batch.items()},
            torch.as_tensor(pair_act).cuda(),
            torch.as_tensor(pair_mask).cuda(),
            templ_dim=0,
            inplace_safe=False,
        )
        return out["template_pair_embedding"].cpu()

    with torch.no_grad():
        model = compare_utils.get_global_pretrained_openfold()
        # NOTE(review): the helper name suggests this model is shared between
        # tests - confirm. The flag is restored afterwards so this test does
        # not leave the kernel switched on for whoever uses the model next.
        previous_flag = model.globals.use_deepspeed_evo_attention
        try:
            out_repro = embed(model, False)
            out_repro_ds = embed(model, True)
        finally:
            model.globals.use_deepspeed_evo_attention = previous_flag

    err = torch.max(torch.abs(out_repro - out_repro_ds))
    self.assertTrue(err < eps, f'Error {err}')
def
test_compare_model
(
self
):
"""
Run full model with and without using DeepSpeed Evoformer attention kernel
and compare output coordinates
"""
eps
=
2e-2
eps
=
0.5
with
open
(
"tests/test_data/sample_feats.pickle"
,
"rb"
)
as
fp
:
batch
=
pickle
.
load
(
fp
)
...
...
tests/test_primitives.py
View file @
a3de9cb9
...
...
@@ -13,40 +13,37 @@
# limitations under the License.
import
torch
import
numpy
as
np
import
unittest
from
openfold.model.primitives
import
(
_lma
,
_attention
,
DEFAULT_LMA_Q_CHUNK_SIZE
,
DEFAULT_LMA_KV_CHUNK_SIZE
lecun_normal_init_
,
Attention
,
)
from
tests.config
import
consts
from
tests.data_utils
import
random_attention_inputs
class
TestLMA
(
unittest
.
TestCase
):
def
test_lma_vs_attention
(
self
):
batch_size
=
consts
.
batch_size
c_hidden
=
32
n
=
2
**
12
n_seq
=
12
no_heads
=
4
q
=
torch
.
rand
(
batch_size
,
n_seq
,
no_heads
,
n
,
c_hidden
).
cuda
()
k
=
torch
.
rand
(
batch_size
,
n_seq
,
no_heads
,
n
,
c_hidden
).
cuda
()
v
=
torch
.
rand
(
batch_size
,
n_seq
,
no_heads
,
n
,
c_hidden
).
cuda
()
q
,
kv
,
_
,
biases
=
random_attention_inputs
(
batch_size
=
consts
.
batch_size
,
n_seq
=
consts
.
n_seq
,
n
=
2
**
12
,
no_heads
=
no_heads
,
c_hidden
=
c_hidden
)
bias
=
[
torch
.
rand
(
batch_size
,
n_seq
,
1
,
1
,
n
),
torch
.
rand
(
batch_size
,
1
,
no_heads
,
n
,
n
)]
biases
=
[
b
.
cuda
()
for
b
in
bias
]
a
=
Attention
(
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
).
cuda
()
with
torch
.
no_grad
():
lma_biases
=
[
b
.
expand
(
b
.
shape
[:
-
2
]
+
(
q
.
shape
[
-
2
],)
+
(
k
.
shape
[
-
2
],))
for
b
in
biases
]
l
=
_lma
(
q
,
k
,
v
,
lma_biases
,
DEFAULT_LMA_Q_CHUNK_SIZE
,
DEFAULT_LMA_KV_CHUNK_SIZE
).
cpu
()
real
=
_attention
(
q
,
k
,
v
,
biases
).
cpu
()
lecun_normal_init_
(
a
.
linear_g
.
weight
)
lecun_normal_init_
(
a
.
linear_o
.
weight
)
l
=
a
(
q
,
kv
,
biases
=
biases
,
use_lma
=
True
).
cpu
()
real
=
a
(
q
,
kv
,
biases
=
biases
).
cpu
()
err
=
torch
.
max
(
torch
.
abs
(
l
-
real
))
self
.
assertTrue
(
err
<
consts
.
eps
,
f
'Error:
{
err
}
'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment