OpenDAS / OpenFold / Commits / f0a320e0
"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "e685140975ca203242e826e13b088654509d6620"
Commit f0a320e0, authored Sep 07, 2023 by Christina Floristean

    Integrated deepspeed attention kernel and added initial tests.

Parent: 2134cc09
Showing 9 changed files with 254 additions and 37 deletions (+254, -37)
deepspeed_config.json                     +10    -1
openfold/config.py                         +1    -0
openfold/model/evoformer.py               +25    -3
openfold/model/model.py                    +4    -0
openfold/model/msa.py                     +10    -3
openfold/model/primitives.py              +57   -29
openfold/model/triangular_attention.py     +5    -0
tests/config.py                            +1    -1
tests/test_deepspeed_evo_attention.py    +141    -0
deepspeed_config.json · View file @ f0a320e0

```diff
@@ -10,9 +10,18 @@
     "bfloat16": {
         "enabled": true
     },
+    "optimizer": {
+        "type": "Adam",
+        "params": {
+            "lr": 1e-3,
+            "eps": 1e-5
+        }
+    },
     "zero_optimization": {
         "stage": 2,
-        "cpu_offload": true,
+        "offload_optimizer": {
+            "device": "cpu"
+        },
         "contiguous_gradients": true
     },
     "activation_checkpointing": {
```
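For orientation, here is a minimal sketch of how a JSON config like the one above is typically consumed. It assumes a recent DeepSpeed release (whose `deepspeed.initialize` accepts a `config` path) and a placeholder model; none of these names come from the commit itself.

```python
# Hedged sketch: wiring deepspeed_config.json into a DeepSpeed engine.
# The nn.Sequential stands in for the real OpenFold model.
import deepspeed
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

# The engine reads the bfloat16, Adam optimizer, and ZeRO stage-2 offload
# settings from the JSON, so no torch.optim optimizer is built by hand here.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="deepspeed_config.json",
)
```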
openfold/config.py · View file @ f0a320e0

```diff
@@ -367,6 +367,7 @@ config = mlc.ConfigDict(
         "globals": {
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
+            "use_deepspeed_evo_attention": False,
             # Use Staats & Rabe's low-memory attention algorithm. Mutually
             # exclusive with use_flash.
             "use_lma": False,
```
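Since the flag defaults to `False`, callers opt in explicitly. A small sketch of what that might look like, assuming OpenFold's existing `model_config` helper (the preset name here is illustrative, not part of the commit):

```python
# Hedged sketch: enabling the new kernel through the config tree.
from openfold.config import model_config

config = model_config("model_1")  # illustrative preset name
config.globals.use_deepspeed_evo_attention = True

# use_lma and use_flash are alternative attention paths; at most one of the
# use_<...> flags should be set at a time (enforced in primitives.py below).
```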
openfold/model/evoformer.py · View file @ f0a320e0

```diff
@@ -181,6 +181,7 @@ class EvoformerBlockCore(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,
@@ -260,6 +261,7 @@ class EvoformerBlockCore(nn.Module):
             mask=pair_mask,
             chunk_size=_attn_chunk_size,
             use_memory_efficient_kernel=False,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
         )
@@ -279,6 +281,7 @@ class EvoformerBlockCore(nn.Module):
             mask=pair_mask.transpose(-1, -2),
             chunk_size=_attn_chunk_size,
             use_memory_efficient_kernel=False,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
         )
@@ -365,6 +368,7 @@ class EvoformerBlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
         inplace_safe: bool = False,
@@ -392,6 +396,7 @@ class EvoformerBlock(nn.Module):
                 mask=msa_mask,
                 chunk_size=_attn_chunk_size,
                 use_memory_efficient_kernel=False,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma,
             )
         ),
@@ -403,6 +408,7 @@ class EvoformerBlock(nn.Module):
                 m,
                 mask=msa_mask,
                 chunk_size=chunk_size,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma,
                 use_flash=use_flash,
             ),
@@ -418,7 +424,8 @@ class EvoformerBlock(nn.Module):
             input_tensors,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
-            chunk_size=chunk_size,
+            chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
             _mask_trans=_mask_trans,
@@ -494,6 +501,7 @@ class ExtraMSABlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,
@@ -520,7 +528,8 @@ class ExtraMSABlock(nn.Module):
             mask=msa_mask,
             chunk_size=_attn_chunk_size,
             use_lma=use_lma,
-            use_memory_efficient_kernel=not use_lma,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
+            use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention),
             _checkpoint_chunks=self.ckpt if torch.is_grad_enabled() else False,
         )
@@ -554,6 +563,7 @@ class ExtraMSABlock(nn.Module):
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
             _mask_trans=_mask_trans,
@@ -674,6 +684,7 @@ class EvoformerStack(nn.Module):
         m: torch.Tensor,
         z: torch.Tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool,
         use_lma: bool,
         use_flash: bool,
         msa_mask: Optional[torch.Tensor],
@@ -687,6 +698,7 @@ class EvoformerStack(nn.Module):
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             inplace_safe=inplace_safe,
@@ -726,6 +738,7 @@ class EvoformerStack(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
         _mask_trans: bool = True,
@@ -737,6 +750,7 @@ class EvoformerStack(nn.Module):
             m=input_tensors[0],
             z=input_tensors[1],
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             msa_mask=msa_mask,
@@ -768,6 +782,7 @@ class EvoformerStack(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
         inplace_safe: bool = False,
@@ -802,6 +817,7 @@ class EvoformerStack(nn.Module):
             m=m,
             z=z,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             msa_mask=msa_mask,
@@ -882,6 +898,7 @@ class ExtraMSAStack(nn.Module):
         m: torch.Tensor,
         z: torch.Tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool,
         use_lma: bool,
         msa_mask: Optional[torch.Tensor],
         pair_mask: Optional[torch.Tensor],
@@ -893,7 +910,8 @@ class ExtraMSAStack(nn.Module):
             b,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
-            chunk_size=chunk_size,
+            chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
             _mask_trans=_mask_trans,
@@ -930,6 +948,7 @@ class ExtraMSAStack(nn.Module):
     def _forward_offload(
         self,
         input_tensors: Sequence[torch.Tensor],
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         msa_mask: Optional[torch.Tensor] = None,
         pair_mask: Optional[torch.Tensor] = None,
@@ -942,6 +961,7 @@ class ExtraMSAStack(nn.Module):
             m=input_tensors[0],
             z=input_tensors[1],
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
@@ -968,6 +988,7 @@ class ExtraMSAStack(nn.Module):
         msa_mask: Optional[torch.Tensor],
         pair_mask: Optional[torch.Tensor],
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,
@@ -992,6 +1013,7 @@ class ExtraMSAStack(nn.Module):
             m=m,
             z=z,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
```
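Every hunk in this file is the same mechanical change: each `forward` signature gains the new boolean with a `False` default, and each call site forwards it unchanged to the next module down until it reaches the attention primitive. The pattern in miniature (placeholder classes, not OpenFold code):

```python
# Hedged sketch of the flag-threading pattern used throughout evoformer.py.
import torch
import torch.nn as nn

class InnerAttention(nn.Module):
    def forward(self, x: torch.Tensor, use_fast_kernel: bool = False) -> torch.Tensor:
        # The innermost module is the only place the flag actually branches.
        return x * 2.0 if use_fast_kernel else x + x

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = InnerAttention()

    def forward(self, x: torch.Tensor, use_fast_kernel: bool = False) -> torch.Tensor:
        # Pure pass-through, like EvoformerBlock -> EvoformerBlockCore -> Attention.
        return self.attn(x, use_fast_kernel=use_fast_kernel)
```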
openfold/model/model.py · View file @ f0a320e0

```diff
@@ -355,6 +355,7 @@ class AlphaFold(nn.Module):
                 input_tensors,
                 msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
                 use_lma=self.globals.use_lma,
                 pair_mask=pair_mask.to(dtype=m.dtype),
                 _mask_trans=self.config._mask_trans,
@@ -367,6 +368,7 @@ class AlphaFold(nn.Module):
                 a,
                 z,
                 msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
                 use_lma=self.globals.use_lma,
                 pair_mask=pair_mask.to(dtype=m.dtype),
                 inplace_safe=inplace_safe,
@@ -385,6 +387,7 @@ class AlphaFold(nn.Module):
                 msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
                 pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
                 chunk_size=self.globals.chunk_size,
+                use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
                 use_lma=self.globals.use_lma,
                 _mask_trans=self.config._mask_trans,
             )
@@ -397,6 +400,7 @@ class AlphaFold(nn.Module):
                 msa_mask=msa_mask.to(dtype=m.dtype),
                 pair_mask=pair_mask.to(dtype=z.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
                 use_lma=self.globals.use_lma,
                 use_flash=self.globals.use_flash,
                 inplace_safe=inplace_safe,
```
openfold/model/msa.py · View file @ f0a320e0

```diff
@@ -91,7 +91,8 @@ class MSAAttention(nn.Module):
         m: torch.Tensor,
         biases: Optional[List[torch.Tensor]],
         chunk_size: int,
-        use_memory_efficient_kernel: bool,
+        use_memory_efficient_kernel: bool,
+        use_deepspeed_evo_attention: bool,
         use_lma: bool,
         use_flash: bool,
         flash_mask: Optional[torch.Tensor],
@@ -103,6 +104,7 @@ class MSAAttention(nn.Module):
             kv_x=m,
             biases=biases,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             flash_mask=flash_mask,
@@ -221,6 +223,7 @@ class MSAAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
         inplace_safe: bool = False,
@@ -267,7 +270,8 @@ class MSAAttention(nn.Module):
             m,
             biases,
             chunk_size,
-            use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             flash_mask=mask,
@@ -279,6 +283,7 @@ class MSAAttention(nn.Module):
             kv_x=m,
             biases=biases,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             flash_mask=mask,
@@ -356,6 +361,7 @@ class MSAColumnAttention(nn.Module):
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
     ) -> torch.Tensor:
@@ -378,7 +384,8 @@ class MSAColumnAttention(nn.Module):
         m = self._msa_att(
             m,
             mask=mask,
-            chunk_size=chunk_size,
+            chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
         )
```
openfold/model/primitives.py · View file @ f0a320e0

```diff
@@ -12,20 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
 import importlib
 import math
 from typing import Optional, Callable, List, Tuple, Sequence
 import numpy as np

 deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
-if(deepspeed_is_installed):
+if deepspeed_is_installed:
     import deepspeed
+    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
-if(fa_is_installed):
-    from flash_attn.bert_padding import unpad_input, pad_input
-    from flash_attn.flash_attention import FlashAttention
+if fa_is_installed:
+    from flash_attn.bert_padding import unpad_input
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func

 import torch
@@ -33,7 +32,6 @@ import torch.nn as nn
 from scipy.stats import truncnorm

 from openfold.utils.checkpointing import get_checkpoint_fn
 from openfold.utils.chunk_utils import _chunk_slice
 from openfold.utils.kernel.attention_core import attention_core
 from openfold.utils.precision_utils import is_fp16_enabled
 from openfold.utils.tensor_utils import (
@@ -42,8 +40,8 @@ from openfold.utils.tensor_utils import (
 )

-DEFAULT_LMA_Q_CHUNK_SIZE = 1024
-DEFAULT_LMA_KV_CHUNK_SIZE = 4096
+DEFAULT_LMA_Q_CHUNK_SIZE = 1024
+DEFAULT_LMA_KV_CHUNK_SIZE = 4096

 def _prod(nums):
@@ -196,9 +194,9 @@ class LayerNorm(nn.Module):
         d = x.dtype
         deepspeed_is_initialized = (
             deepspeed_is_installed and
-            deepspeed.utils.is_initialized()
+            deepspeed.comm.comm.is_initialized()
         )
-        if(d is torch.bfloat16 and not deepspeed_is_initialized):
+        if d is torch.bfloat16 and not deepspeed_is_initialized:
             with torch.cuda.amp.autocast(enabled=False):
                 out = nn.functional.layer_norm(
                     x,
@@ -228,9 +226,9 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
     d = t.dtype
     deepspeed_is_initialized = (
         deepspeed_is_installed and
-        deepspeed.utils.is_initialized()
+        deepspeed.comm.comm.is_initialized()
     )
-    if(d is torch.bfloat16 and not deepspeed_is_initialized):
+    if d is torch.bfloat16 and not deepspeed_is_initialized:
         with torch.cuda.amp.autocast(enabled=False):
             s = torch.nn.functional.softmax(t, dim=dim)
     else:
@@ -262,7 +260,7 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias
 def _attention_chunked_trainable(
     query, key, value, biases, chunk_size, chunk_dim, checkpoint,
 ):
-    if(checkpoint and len(biases) > 2):
+    if checkpoint and len(biases) > 2:
         raise ValueError(
             "Checkpointed version permits only permits two bias terms"
         )
@@ -290,7 +288,7 @@ def _attention_chunked_trainable(
         )
         return b[tuple(idx)]

-    if(checkpoint):
+    if checkpoint:
         bias_1_chunk, bias_2_chunk = [
             _slice_bias(b) if b is not None else None
             for b in (biases + [None, None])[:2]
         ]
@@ -404,7 +402,7 @@ class Attention(nn.Module):
         o: torch.Tensor,
         q_x: torch.Tensor
     ) -> torch.Tensor:
-        if(self.linear_g is not None):
+        if self.linear_g is not None:
             g = self.sigmoid(self.linear_g(q_x))

         # [*, Q, H, C_hidden]
@@ -425,11 +423,12 @@ class Attention(nn.Module):
         kv_x: torch.Tensor,
         biases: Optional[List[torch.Tensor]] = None,
         use_memory_efficient_kernel: bool = False,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
         lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
         use_flash: bool = False,
-        flash_mask: Optional[torch.Tensor] = None,
+        flash_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         """
         Args:
@@ -444,6 +443,10 @@ class Attention(nn.Module):
                 This should be the default choice for most. If none of the
                 "use_<...>" flags are True, a stock PyTorch implementation
                 is used instead
+            use_deepspeed_evo_attention:
+                Whether to use DeepSpeed memory-efficient attention kernel.
+                If none of the "use_<...>" flags are True, a stock PyTorch
+                implementation is used instead
             use_lma:
                 Whether to use low-memory attention (Staats & Rabe 2021). If
                 none of the "use_<...>" flags are True, a stock PyTorch
@@ -455,25 +458,25 @@ class Attention(nn.Module):
         Returns
             [*, Q, C_q] attention update
         """
-        if(use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None)):
+        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
             raise ValueError(
                 "If use_lma is specified, lma_q_chunk_size and "
                 "lma_kv_chunk_size must be provided"
             )

-        if(use_flash and biases is not None):
+        if use_flash and biases is not None:
             raise ValueError(
                 "use_flash is incompatible with the bias option. For masking, "
                 "use flash_mask instead"
             )

-        attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
-        if(sum(attn_options) > 1):
+        attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, use_flash]
+        if sum(attn_options) > 1:
             raise ValueError(
                 "Choose at most one alternative attention algorithm"
             )

-        if(biases is None):
+        if biases is None:
             biases = []

         # [*, H, Q/K, C_hidden]
@@ -483,22 +486,47 @@ class Attention(nn.Module):
         if is_fp16_enabled():
             use_memory_efficient_kernel = False

-        if(use_memory_efficient_kernel):
-            if(len(biases) > 2):
+        if use_memory_efficient_kernel:
+            if len(biases) > 2:
                 raise ValueError(
                     "If use_memory_efficient_kernel is True, you may only "
                     "provide up to two bias terms"
                 )
             o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
             o = o.transpose(-2, -3)
-        elif(use_lma):
+        elif use_deepspeed_evo_attention:
+            q = q.transpose(-2, -3)
+            k = k.transpose(-2, -3)
+            v = v.transpose(-2, -3)
+            add_batch_dim = len(q.shape) < 5
+            if add_batch_dim:
+                q = q.unsqueeze(0)
+                k = k.unsqueeze(0)
+                v = v.unsqueeze(0)
+                biases = [b.unsqueeze(0) for b in biases]
+            orig_dtype = q.dtype
+            if orig_dtype not in [torch.bfloat16, torch.float16]:
+                o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
+                                              k.to(dtype=torch.bfloat16),
+                                              v.to(dtype=torch.bfloat16),
+                                              [b.to(dtype=torch.bfloat16) for b in biases])
+                o = o.to(dtype=orig_dtype)
+            else:
+                o = DS4Sci_EvoformerAttention(q, k, v, biases)
+            if add_batch_dim:
+                o = o.squeeze(0)
+        elif use_lma:
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
                 for b in biases
             ]
             o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
             o = o.transpose(-2, -3)
-        elif(use_flash):
+        elif use_flash:
             o = _flash_attn(q, k, v, flash_mask)
         else:
             o = _attention(q, k, v, biases)
@@ -556,7 +584,7 @@ class GlobalAttention(nn.Module):
         v = self.linear_v(m)

         bias = (self.inf * (mask - 1))[..., :, None, :]
-        if(not use_lma):
+        if not use_lma:
             # [*, N_res, H, N_seq]
             a = torch.matmul(
                 q,
@@ -662,7 +690,7 @@ def _lma(
 @torch.jit.ignore
 def _flash_attn(q, k, v, kv_mask):
-    if(not fa_is_installed):
+    if not fa_is_installed:
         raise ValueError(
             "_flash_attn requires that FlashAttention be installed"
         )
@@ -714,8 +742,8 @@ def _flash_attn(q, k, v, kv_mask):
         kv_cu_seqlens,
         q_max_s,
         kv_max_s,
-        dropout_p = 0.,
-        softmax_scale = 1., # q has been scaled already
+        dropout_p=0.,
+        softmax_scale=1.,  # q has been scaled already
     )

     # [*, B, N, H, C]
```
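The last large hunk above is the core of the commit: before the kernel is invoked, `q`/`k`/`v` are transposed from `[*, H, Q, C]` to `[*, Q, H, C]`, padded to five dimensions if needed, and cast to bfloat16 unless they already arrive in a half precision. A standalone sketch of that contract, with illustrative sizes and the assumption of a CUDA-enabled DeepSpeed build that ships the deepspeed4science ops:

```python
# Hedged sketch: calling DS4Sci_EvoformerAttention directly, mirroring the
# shapes used by the commit's dispatch path and its tests.
import torch
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

batch, n_seq, n_res, heads, c_hidden = 1, 8, 64, 4, 32
dt = torch.bfloat16  # the kernel expects bf16 or fp16 inputs

q = torch.rand(batch, n_seq, n_res, heads, c_hidden, device="cuda", dtype=dt)
k = torch.rand_like(q)
v = torch.rand_like(q)
biases = [
    # mask bias, broadcast over heads and query positions
    torch.rand(batch, n_seq, 1, 1, n_res, device="cuda", dtype=dt),
    # pair bias, broadcast over the sequence dimension
    torch.rand(batch, 1, heads, n_res, n_res, device="cuda", dtype=dt),
]

o = DS4Sci_EvoformerAttention(q, k, v, biases)
assert o.shape == q.shape  # output keeps the [B, N_seq, N_res, H, C] layout
```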
openfold/model/triangular_attention.py · View file @ f0a320e0

```diff
@@ -63,6 +63,7 @@ class TriangleAttention(nn.Module):
         biases: List[torch.Tensor],
         chunk_size: int,
         use_memory_efficient_kernel: bool = False,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
     ) -> torch.Tensor:
@@ -77,6 +78,7 @@ class TriangleAttention(nn.Module):
         partial(
             self.mha,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma
         ),
         mha_inputs,
@@ -90,6 +92,7 @@ class TriangleAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
     ) -> torch.Tensor:
@@ -130,6 +133,7 @@ class TriangleAttention(nn.Module):
             biases,
             chunk_size,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
         )
@@ -139,6 +143,7 @@ class TriangleAttention(nn.Module):
             kv_x=x,
             biases=biases,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma
         )
```
tests/config.py · View file @ f0a320e0

```diff
@@ -3,7 +3,7 @@ import ml_collections as mlc
 consts = mlc.ConfigDict(
     {
         "batch_size": 2,
-        "n_res": 11,
+        "n_res": 20,
         "n_seq": 13,
         "n_templ": 3,
         "n_extra": 17,
```
tests/test_deepspeed_evo_attention.py · new file (0 → 100644) · View file @ f0a320e0

```python
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import unittest
import numpy as np
import pickle

from openfold.model.primitives import (
    Attention,
)
from tests.config import consts
import tests.compare_utils as compare_utils
from openfold.data import data_transforms
from openfold.utils.tensor_utils import tensor_tree_map


class TestDeepSpeedKernel(unittest.TestCase):
    def test_ds_kernel_vs_attention(self):
        batch_size = consts.batch_size
        c_hidden = 32
        n = 2 ** 12
        n_seq = 12
        no_heads = 4

        q = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
        kv = torch.rand(batch_size, n_seq, n, c_hidden).cuda()

        bias = [torch.rand(batch_size, n_seq, 1, 1, n),
                torch.rand(batch_size, 1, no_heads, n, n)]
        bias = [b.cuda() for b in bias]

        a = Attention(
            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
        ).cuda()

        with torch.no_grad():
            l = a(q, kv, biases=bias, use_deepspeed_evo_attention=True)
            real = a(q, kv, biases=bias)

        self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)

    def compare_evoformer(self, dtype):
        n_res = consts.n_res
        n_seq = consts.n_seq

        activations = {
            "msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
            "pair": torch.rand(n_res, n_res, consts.c_z, device='cuda', dtype=dtype)
        }

        masks = {
            "msa": torch.randint(0, 2, (n_seq, n_res), device='cuda', dtype=dtype),
            "pair": torch.randint(0, 2, (n_res, n_res), device='cuda', dtype=dtype),
        }

        with torch.cuda.amp.autocast(dtype=dtype):
            model = compare_utils.get_global_pretrained_openfold()
            out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
                activations["msa"],
                activations["pair"],
                masks["msa"],
                masks["pair"],
                use_deepspeed_evo_attention=False,
                chunk_size=4,
                _mask_trans=False,
                inplace_safe=False,
            )
            out_repro_msa = out_repro_msa.cpu()
            out_repro_pair = out_repro_pair.cpu()

            out_repro_msa_ds, out_repro_pair_ds = model.evoformer.blocks[0](
                activations["msa"],
                activations["pair"],
                masks["msa"],
                masks["pair"],
                use_deepspeed_evo_attention=True,
                chunk_size=4,
                _mask_trans=False,
                inplace_safe=False,
            )
            out_repro_msa_ds = out_repro_msa_ds.cpu()
            out_repro_pair_ds = out_repro_pair_ds.cpu()

            self.assertTrue(torch.allclose(
                torch.abs(out_repro_msa), torch.abs(out_repro_msa_ds), atol=consts.eps
            ))
            self.assertTrue(torch.allclose(
                torch.abs(out_repro_pair), torch.abs(out_repro_pair_ds), atol=consts.eps
            ))

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare_evoformer_bf16(self):
        self.compare_evoformer(torch.bfloat16)

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare_evoformer_fp32(self):
        self.compare_evoformer(torch.float32)

    @compare_utils.skip_unless_alphafold_installed()
    def test_dry_run(self):
        with open("tests/test_data/sample_feats.pickle", "rb") as fp:
            batch = pickle.load(fp)

        # atom37_to_atom14 doesn't like batches
        batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
        batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]

        batch["no_recycling_iters"] = np.array([3., 3., 3., 3.,])
        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}

        batch["aatype"] = batch["aatype"].long()
        batch["template_aatype"] = batch["template_aatype"].long()
        batch["extra_msa"] = batch["extra_msa"].long()
        batch["residx_atom37_to_atom14"] = batch["residx_atom37_to_atom14"].long()
        batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
        batch.update(
            data_transforms.atom37_to_torsion_angles("template_")(batch)
        )

        # Move the recycling dimension to the end
        move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
        batch = tensor_tree_map(move_dim, batch)

        with torch.no_grad():
            model = compare_utils.get_global_pretrained_openfold()
            model.globals.use_deepspeed_evo_attention = True
            out_repro = model(batch)


if __name__ == "__main__":
    unittest.main()
```
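These tests assume a CUDA device, a DeepSpeed build that provides the DS4Sci kernel, and, for the comparison tests, the pretrained-model machinery in `tests/compare_utils`. One way to run just this module from the standard library (a sketch; the usual `python -m unittest` CLI works too):

```python
# Hedged sketch: load and run only the new DeepSpeed kernel tests.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.test_deepspeed_evo_attention.TestDeepSpeedKernel"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```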