OpenDAS / OpenFold / Commits / edb7d341

Commit edb7d341, authored Apr 26, 2022 by Gustaf Ahdritz

Add memory-efficient attention kernels

parent 816c1843

Showing 12 changed files with 615 additions and 65 deletions (+615, -65):

- README.md (+10, -7)
- openfold/model/evoformer.py (+34, -22)
- openfold/model/msa.py (+24, -13)
- openfold/model/primitives.py (+40, -19)
- openfold/utils/__init__.py (+3, -1)
- openfold/utils/kernel/__init__.py (+0, -0)
- openfold/utils/kernel/attention_core.py (+85, -0)
- openfold/utils/kernel/csrc/compat.h (+11, -0)
- openfold/utils/kernel/csrc/softmax_cuda.cpp (+30, -0)
- openfold/utils/kernel/csrc/softmax_cuda_kernel.cu (+228, -0)
- setup.py (+66, -3)
- tests/test_kernels.py (+84, -0)
README.md

````diff
@@ -26,6 +26,10 @@
 OpenFold is equipped with an implementation of low-memory attention
 ([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)), which
 enables inference on extremely long chains.
+We've modified FastFold's custom CUDA kernels to support in-place attention
+during inference and training. These use 4x and 5x less GPU memory than
+equivalent FastFold and stock PyTorch implementations, respectively.
 We also make available efficient scripts for generating alignments. We've
 used them to generate millions of alignments that will be released alongside
 original OpenFold weights, trained from scratch using our code (more on that soon).
@@ -57,6 +61,12 @@ To deactivate it, run:
 source scripts/deactivate_conda_env.sh
 ```
 
+With the environment active, compile OpenFold's CUDA kernels with
+
+```bash
+python3 setup.py install
+```
+
 To install the HH-suite to `/usr/bin`, run
 
 ```bash
@@ -138,13 +148,6 @@ to `None` in the config.
 ### Training
-After activating the OpenFold environment with
-`source scripts/activate_conda_env.sh`, install OpenFold by running
-
-```bash
-python setup.py install
-```
-
 To train the model, you will first need to precompute protein alignments.
 You have two options. You can use the same procedure DeepMind used by running
````
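A note on where the memory savings described in the README come from: a stock softmax materializes a second attention-probability tensor next to the logits, while the fused kernel normalizes the logits buffer in place. A rough PyTorch sketch of the idea (an illustration only, not the CUDA kernel itself, which also fuses these steps and supplies a matching backward pass):

```python
import torch

logits = torch.rand(8, 128, 128)  # toy stand-in for [H, N_res, N_res] attention logits

# Stock path: `probs` is a full new allocation alongside `logits`.
probs = torch.softmax(logits, dim=-1)

# In-place path: the buffer holding the logits is reused for the probabilities.
logits -= logits.max(dim=-1, keepdim=True).values
logits.exp_()
logits /= logits.sum(dim=-1, keepdim=True)
```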
openfold/model/evoformer.py

```diff
@@ -368,6 +368,7 @@ class ExtraMSABlock(nn.Module):
                 z=z.clone() if torch.is_grad_enabled() else z,
                 mask=msa_mask,
                 chunk_size=chunk_size,
+                use_memory_efficient_kernel=not _chunk_logits,
                 _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
                 _checkpoint_chunks=self.ckpt if torch.is_grad_enabled() else False,
@@ -558,11 +559,14 @@ class ExtraMSAStack(nn.Module):
         eps: float,
         ckpt: bool,
         clear_cache_between_blocks: bool = False,
+        chunk_msa_attn: bool = False,
         **kwargs,
     ):
         super(ExtraMSAStack, self).__init__()
+        self.ckpt = ckpt
         self.clear_cache_between_blocks = clear_cache_between_blocks
+        self.chunk_msa_attn = chunk_msa_attn
         self.blocks = nn.ModuleList()
         for _ in range(no_blocks):
             block = ExtraMSABlock(
@@ -579,7 +583,7 @@ class ExtraMSAStack(nn.Module):
                 pair_dropout=pair_dropout,
                 inf=inf,
                 eps=eps,
-                ckpt=ckpt,
+                ckpt=ckpt if chunk_msa_attn else False,
             )
             self.blocks.append(block)
@@ -604,23 +608,31 @@ class ExtraMSAStack(nn.Module):
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        #checkpoint_fn = get_checkpoint_fn()
-        #blocks = [
-        #    partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
-        #]
-        #def dodo(b, *args):
-        #    torch.cuda.empty_cache()
-        #    return b(*args)
-        #blocks = [partial(dodo, b) for b in blocks]
-        #for b in blocks:
-        #    if(torch.is_grad_enabled()):
-        #        m, z = checkpoint_fn(b, *(m, z))
-        #    else:
-        #        m, z = b(m, z)
-        for b in self.blocks:
-            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+        if(not self.chunk_msa_attn):
+            checkpoint_fn = get_checkpoint_fn()
+            blocks = [
+                partial(b,
+                    msa_mask=msa_mask,
+                    pair_mask=pair_mask,
+                    chunk_size=chunk_size,
+                    _chunk_logits=None)
+                for b in self.blocks
+            ]
+
+            def clear_cache(b, *args):
+                torch.cuda.empty_cache()
+                return b(*args)
+
+            if(self.clear_cache_between_blocks):
+                blocks = [partial(clear_cache, b) for b in blocks]
+
+            for b in blocks:
+                if(self.ckpt and torch.is_grad_enabled()):
+                    m, z = checkpoint_fn(b, *(m, z))
+                else:
+                    m, z = b(m, z)
+        else:
+            for b in self.blocks:
+                m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
```
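The refactored ExtraMSAStack.forward above composes its blocks from three orthogonal wrappers: functools.partial to freeze per-block keyword arguments, an optional cache-clearing shim, and optional activation checkpointing. A minimal self-contained sketch of that composition, using stock torch.utils.checkpoint in place of OpenFold's get_checkpoint_fn (which may return a DeepSpeed variant):

```python
import torch
from functools import partial
from torch.utils.checkpoint import checkpoint

def run_blocks(blocks, m, z, clear_cache_between_blocks=False, ckpt=False):
    # Freeze any per-block kwargs up front so every wrapped block is a
    # uniform callable of (m, z), which is what checkpoint() expects.
    wrapped = [partial(b) for b in blocks]

    def clear_cache(b, *args):
        # Trades allocator reuse for a lower memory high-water mark.
        torch.cuda.empty_cache()
        return b(*args)

    if clear_cache_between_blocks:
        wrapped = [partial(clear_cache, b) for b in wrapped]

    for b in wrapped:
        if ckpt and torch.is_grad_enabled():
            # Recompute activations in backward instead of storing them.
            m, z = checkpoint(b, m, z)
        else:
            m, z = b(m, z)

    return m, z
```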
openfold/model/msa.py

```diff
@@ -79,20 +79,30 @@ class MSAAttention(nn.Module):
         )

         self.mha = Attention(
-            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
+            self.c_in,
+            self.c_in,
+            self.c_in,
+            self.c_hidden,
+            self.no_heads,
         )

     @torch.jit.ignore
     def _chunk(self,
         m: torch.Tensor,
         biases: List[torch.Tensor],
+        use_memory_efficient_kernel: bool,
         chunk_size: int,
     ) -> torch.Tensor:
         return chunk_layer(
             self.mha,
-            {"q_x": m, "kv_x": m, "biases": biases},
+            {
+                "q_x": m,
+                "kv_x": m,
+                "biases": biases,
+                "use_memory_efficient_kernel": use_memory_efficient_kernel,
+            },
             chunk_size=chunk_size,
-            no_batch_dims=len(m.shape[:-2]),
+            no_batch_dims=len(m.shape[:-2])
         )

     def _prep_inputs(self,
@@ -113,13 +123,6 @@ class MSAAttention(nn.Module):
         # [*, N_seq, 1, 1, N_res]
         mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

-        # This step simply returns a larger view of the bias, and does not
-        # consume additional memory.
-        # [*, N_seq, no_heads, N_res, N_res]
-        #bias = bias.expand(
-        #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
-        #)
-
         if(self.pair_bias and
             z is not None and                       # For the
             self.layer_norm_z is not None and       # benefit of
@@ -144,6 +147,11 @@ class MSAAttention(nn.Module):
         chunk_logits: int,
         checkpoint: bool,
     ) -> torch.Tensor:
+        """
+        MSA attention with training-time chunking of the softmax computation.
+        Saves memory in the extra MSA stack. Probably obviated by our fused
+        attention kernel, which is now used by default.
+        """
         MSA_DIM = -4

         def _get_qkv(m, z):
@@ -181,6 +189,7 @@ class MSAAttention(nn.Module):
         z: Optional[torch.Tensor] = None,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        use_memory_efficient_kernel: bool = False,
         _chunk_logits: Optional[int] = None,
         _checkpoint_chunks: Optional[bool] = None,
     ) -> torch.Tensor:
@@ -212,12 +221,13 @@ class MSAAttention(nn.Module):
             biases.append(z)

         if chunk_size is not None:
-            m = self._chunk(m, biases, chunk_size)
+            m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size)
         else:
             m = self.mha(
                 q_x=m,
                 kv_x=m,
-                biases=biases
+                biases=biases,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
             )

         return m
@@ -291,7 +301,8 @@ class MSAColumnAttention(nn.Module):
     def forward(self,
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None
+        chunk_size: Optional[int] = None,
+        use_memory_efficient_kernel: bool = False,
     ) -> torch.Tensor:
         """
         Args:
```
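For orientation, a hypothetical call into the updated column attention. Only the forward keywords are taken from the diff above; the constructor arguments and tensor sizes are illustrative assumptions, so consult the class definition for the real signature:

```python
import torch
from openfold.model.msa import MSAColumnAttention

# Assumed positional arguments (c_in, c_hidden, no_heads) for illustration.
attn = MSAColumnAttention(32, 16, 4).cuda()

m = torch.rand(1, 64, 128, 32).cuda()   # [*, N_seq, N_res, C_in]
mask = torch.ones(1, 64, 128).cuda()    # [*, N_seq, N_res]

# New in this commit: route the attention through the fused CUDA kernel.
out = attn(m, mask=mask, use_memory_efficient_kernel=True)
```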
openfold/model/primitives.py

```diff
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
 import math
 from typing import Optional, Callable, List, Tuple, Sequence
@@ -24,6 +23,7 @@ import torch.nn as nn
 from scipy.stats import truncnorm

 from openfold.utils.checkpointing import get_checkpoint_fn
+from openfold.utils.kernel.attention_core import attention_core
 from openfold.utils.tensor_utils import (
     permute_final_dims,
     flatten_final_dims,
@@ -199,8 +199,9 @@ class LayerNorm(nn.Module):
         return out


 @torch.jit.ignore
-def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
+def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
     """
     Softmax, but without automatic casting to fp32 when the input is of
     type bfloat16
@@ -217,14 +218,8 @@ def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
 #@torch.jit.script
 def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
-    # [*, H, Q, C_hidden]
-    query = permute_final_dims(query, (1, 0, 2))
-
     # [*, H, C_hidden, K]
-    key = permute_final_dims(key, (1, 2, 0))
-
-    # [*, H, V, C_hidden]
-    value = permute_final_dims(value, (1, 0, 2))
+    key = permute_final_dims(key, (1, 0))

     # [*, H, Q, K]
     a = torch.matmul(query, key)
@@ -232,14 +227,11 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias
     for b in biases:
         a += b

-    a = softmax(a, -1)
+    a = softmax_no_cast(a, -1)

     # [*, H, Q, C_hidden]
     a = torch.matmul(a, value)

-    # [*, Q, H, C_hidden]
-    a = a.transpose(-2, -3)
-
     return a
@@ -254,7 +246,8 @@ def _attention_chunked_trainable(
     def _checkpointable_attention(q, k, v, b1, b2):
         bs = [b for b in [b1, b2] if b is not None]
-        return _attention(q, k, v, bs)
+        a = _attention(q, k, v, bs)
+        return a

     o_chunks = []
     checkpoint_fn = get_checkpoint_fn()
@@ -290,6 +283,7 @@ def _attention_chunked_trainable(
         o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
+        o_chunk = o_chunk.transpose(-2, -3)
         o_chunks.append(o_chunk)

     o = torch.cat(o_chunks, dim=chunk_dim)
@@ -374,6 +368,11 @@ class Attention(nn.Module):
         k = k.view(k.shape[:-1] + (self.no_heads, -1))
         v = v.view(v.shape[:-1] + (self.no_heads, -1))

+        # [*, H, Q/K, C_hidden]
+        q = q.transpose(-2, -3)
+        k = k.transpose(-2, -3)
+        v = v.transpose(-2, -3)
+
         q /= math.sqrt(self.c_hidden)

         return q, k, v
@@ -402,6 +401,7 @@ class Attention(nn.Module):
         q_x: torch.Tensor,
         kv_x: torch.Tensor,
         biases: Optional[List[torch.Tensor]] = None,
+        use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
         q_chunk_size: Optional[int] = None,
         kv_chunk_size: Optional[int] = None,
@@ -414,8 +414,15 @@ class Attention(nn.Module):
                 [*, K, C_k] key data
             biases:
                 List of biases that broadcast to [*, H, Q, K]
+            use_memory_efficient_kernel:
+                Whether to use a custom memory-efficient attention kernel.
+                This should be the default choice for most. If none of the
+                "use_<...>" flags are True, a stock PyTorch implementation
+                is used instead
             use_lma:
-                Whether to use low-memory attention
+                Whether to use low-memory attention (Staats & Rabe 2021). If
+                none of the "use_<...>" flags are True, a stock PyTorch
+                implementation is used instead
             q_chunk_size:
                 Query chunk size (for LMA)
             kv_chunk_size:
@@ -430,18 +437,32 @@ class Attention(nn.Module):
                 "If use_lma is specified, q_chunk_size and kv_chunk_size must "
                 "be provided"
             )

+        if(use_memory_efficient_kernel and use_lma):
+            raise ValueError(
+                "Choose one of use_memory_efficient_kernel and use_lma"
+            )
+
+        # [*, H, Q/K, C_hidden]
         q, k, v = self._prep_qkv(q_x, kv_x)

-        if(use_lma):
-            # [*, Q, H, C_hidden]
+        if(use_memory_efficient_kernel):
+            if(len(biases) > 2):
+                raise ValueError(
+                    "If use_memory_efficient_kernel is True, you may only "
+                    "provide up to two bias terms"
+                )
+            o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
+            o = o.transpose(-2, -3)
+        elif(use_lma):
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
                 for b in biases
             ]
             o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
         else:
             o = _attention(q, k, v, biases)
+            o = o.transpose(-2, -3)

         o = self._wrap_up(o, q_x)
@@ -497,7 +518,7 @@ class GlobalAttention(nn.Module):
         )

         bias = (self.inf * (mask - 1))[..., :, None, :]
         a += bias

-        a = softmax(a)
+        a = softmax_no_cast(a)

         # [*, N_res, H, C_hidden]
         o = torch.matmul(
```
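To make the new dispatch concrete, here is a sketch of calling Attention.forward down each of its three paths. The constructor arguments are assumptions for illustration (see the class for the real signature); the kernel path requires CUDA tensors and at most two biases, as enforced above:

```python
import torch
from openfold.model.primitives import Attention

# Illustrative sizes: c_q = c_k = c_v = 32, c_hidden = 16, 4 heads
# (assumed positional constructor arguments).
mha = Attention(32, 32, 32, 16, 4).cuda()

x = torch.rand(1, 128, 32).cuda()            # [*, Q/K, C]
bias = torch.zeros(1, 1, 128, 128).cuda()    # broadcasts to [*, H, Q, K]

o_kernel = mha(q_x=x, kv_x=x, biases=[bias], use_memory_efficient_kernel=True)
o_lma = mha(
    q_x=x, kv_x=x, biases=[bias],
    use_lma=True, q_chunk_size=32, kv_chunk_size=32,
)
o_stock = mha(q_x=x, kv_x=x, biases=[bias])  # stock PyTorch _attention path

# Setting both flags at once raises the ValueError added in this commit.
```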
openfold/utils/__init__.py

```diff
@@ -2,12 +2,14 @@ import os
 import glob
 import importlib as importlib

+from . import kernel
+
 _files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
 __all__ = [
     os.path.basename(f)[:-3]
     for f in _files
     if os.path.isfile(f) and not f.endswith("__init__.py")
-]
+] + ["kernel"]
 _modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
 for _m in _modules:
     globals()[_m[0]] = _m[1]
```
openfold/utils/kernel/__init__.py (new file, 0 → 100644, empty)
openfold/utils/kernel/attention_core.py (new file, 0 → 100644)

```python
import importlib
from functools import reduce
from operator import mul

import torch

attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")


class AttentionCoreFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, bias_1=None, bias_2=None):
        if(bias_1 is None and bias_2 is not None):
            raise ValueError("bias_1 must be specified before bias_2")

        q = q.contiguous()
        k = k.contiguous()

        # [*, H, Q, K]
        attention_logits = torch.matmul(
            q, k.transpose(-1, -2),
        )

        if(bias_1 is not None):
            attention_logits += bias_1
        if(bias_2 is not None):
            attention_logits += bias_2

        attn_core_inplace_cuda.forward_(
            attention_logits,
            reduce(mul, attention_logits.shape[:-1]),
            attention_logits.shape[-1],
        )

        o = torch.matmul(attention_logits, v)

        ctx.bias_1_shape = bias_1.shape if bias_1 is not None else None
        ctx.bias_2_shape = bias_2.shape if bias_2 is not None else None
        ctx.save_for_backward(q, k, v, attention_logits)

        return o

    @staticmethod
    def backward(ctx, grad_output):
        q, k, v, attention_logits = ctx.saved_tensors
        grad_q = grad_k = grad_v = grad_bias_1 = grad_bias_2 = None

        grad_v = torch.matmul(
            attention_logits.transpose(-1, -2),
            grad_output
        )

        attn_core_inplace_cuda.backward_(
            attention_logits,
            grad_output.contiguous(),
            v.contiguous(),  # v is implicitly transposed in the kernel
            reduce(mul, attention_logits.shape[:-1]),
            attention_logits.shape[-1],
            grad_output.shape[-1],
        )

        if(ctx.bias_1_shape is not None):
            grad_bias_1 = torch.sum(
                attention_logits,
                dim=tuple(i for i, d in enumerate(ctx.bias_1_shape) if d == 1),
                keepdim=True,
            )

        if(ctx.bias_2_shape is not None):
            grad_bias_2 = torch.sum(
                attention_logits,
                dim=tuple(i for i, d in enumerate(ctx.bias_2_shape) if d == 1),
                keepdim=True,
            )

        grad_q = torch.matmul(
            attention_logits, k
        )
        grad_k = torch.matmul(
            q.transpose(-1, -2), attention_logits,
        ).transpose(-1, -2)

        return grad_q, grad_k, grad_v, grad_bias_1, grad_bias_2

attention_core = AttentionCoreFunction.apply
```
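A minimal direct call into the wrapper, mirroring the shapes documented above and exercised in the new tests. The concrete sizes are arbitrary; a CUDA device and a built attn_core_inplace_cuda extension are assumed:

```python
import torch
from openfold.utils.kernel.attention_core import attention_core

# [*, H, Q/K, C_hidden] inputs; biases must broadcast to [*, H, Q, K]
q = torch.rand(2, 4, 64, 16).cuda()
k = torch.rand(2, 4, 64, 16).cuda()
v = torch.rand(2, 4, 64, 16).cuda()
bias = torch.zeros(2, 1, 64, 64).cuda()

o = attention_core(q, k, v, bias, None)  # at most two biases; bias_2 unused here
print(o.shape)  # torch.Size([2, 4, 64, 16])
```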
openfold/utils/kernel/csrc/compat.h (new file, 0 → 100644)

```cpp
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
```
openfold/utils/kernel/csrc/softmax_cuda.cpp (new file, 0 → 100644)

```cpp
#include <torch/extension.h>

// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp

void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
);
void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward_",
        &attn_softmax_inplace_forward_,
        "Softmax forward (CUDA)"
    );
    m.def(
        "backward_",
        &attn_softmax_inplace_backward_,
        "Softmax backward (CUDA)"
    );
}
```
openfold/utils/kernel/csrc/softmax_cuda_kernel.cu (new file, 0 → 100644)

```cuda
#include <math_constants.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <iostream>

#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"

// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

__inline__ __device__ float WarpAllReduceMax(float val) {
    for (int mask = 1; mask < 32; mask *= 2) {
        val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}

__inline__ __device__ float WarpAllReduceSum(float val) {
    for (int mask = 1; mask < 32; mask *= 2) {
        val += __shfl_xor_sync(0xffffffff, val, mask);
    }
    return val;
}

template<typename T>
__global__ void attn_softmax_inplace_(
    T *input,
    long long rows, int cols
) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x);
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;

    int last_y = (cols / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    }
    else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_input = input + row_offset * cols;
        T *row_output = row_input;

        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            int idx = lane_id * cols_per_thread + i;
            buf[i] = static_cast<float>(row_input[idx]);
        }

        float thread_max = -1 * CUDART_INF_F;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_max = max(thread_max, buf[i]);
        }

        float warp_max = WarpAllReduceMax(thread_max);

        float thread_sum = 0.f;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            buf[i] = __expf(buf[i] - warp_max);
            thread_sum += buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] =
                static_cast<T>(__fdividef(buf[i], warp_sum));
        }
    }
}

void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
) {
    CHECK_INPUT(input);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (input.dtype() == torch::kFloat32) {
        attn_softmax_inplace_<float><<<grid, block>>>(
            (float *)input.data_ptr(),
            rows, cols
        );
    }
    else {
        attn_softmax_inplace_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)input.data_ptr(),
            rows, cols
        );
    }
}

template<typename T>
__global__ void attn_softmax_inplace_grad_(
//__global__ void attn_softmax_inplace_grad_bf16_(
    T *output,
    T *d_ov,
    T *values,
    long long rows,
    int cols_output,
    int cols_values
) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x);
    int cols_per_thread = (cols_output + 31) / 32;
    int cols_this_thread = cols_per_thread;
    int rows_values = cols_output;
    // values are set to the beginning of the current
    // rows_values x cols_values leaf matrix
    long long value_row_offset = row_offset - row_offset % rows_values;
    int last_y = (cols_output / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols_output - cols_per_thread * last_y;
    }
    else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float y_buf[32];
    float dy_buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_output = output + row_offset * cols_output;
        T *row_d_ov = d_ov + row_offset * cols_values;
        T *row_values = values + value_row_offset * cols_values;

        float thread_max = -1 * CUDART_INF_F;

        // Compute a chunk of the output gradient on the fly
        int value_row_idx = 0;
        int value_idx = 0;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            T sum = 0.;
            #pragma unroll
            for (int j = 0; j < cols_values; j++) {
                value_row_idx = ((lane_id * cols_per_thread) + i);
                value_idx = value_row_idx * cols_values + j;
                sum += row_d_ov[j] * row_values[value_idx];
            }
            dy_buf[i] = static_cast<float>(sum);
        }

        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
        }

        float thread_sum = 0.;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_sum += y_buf[i] * dy_buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] = static_cast<T>(
                (dy_buf[i] - warp_sum) * y_buf[i]
            );
        }
    }
}

void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
) {
    CHECK_INPUT(output);
    CHECK_INPUT(d_ov);
    CHECK_INPUT(values);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));

    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (output.dtype() == torch::kFloat32) {
        attn_softmax_inplace_grad_<float><<<grid, block>>>(
            (float *)output.data_ptr(),
            (float *)d_ov.data_ptr(),
            (float *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    }
    else {
        attn_softmax_inplace_grad_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)output.data_ptr(),
            (at::BFloat16 *)d_ov.data_ptr(),
            (at::BFloat16 *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    }
}
```
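For reference, the semantics of the two kernels above in plain PyTorch. forward_ replaces the logits buffer with softmax(logits); backward_ replaces the saved softmax output y with dL/dlogits given dL/d(out) and v, where out = y @ v. This is the standard softmax vector-Jacobian product, fused and applied in place (the sketch below returns new tensors instead of overwriting):

```python
import torch

def forward_ref(logits):
    # What attn_softmax_inplace_ computes per row (max-shifted softmax).
    return torch.softmax(logits, dim=-1)

def backward_ref(y, d_out, v):
    # dy_buf[i] in the kernel: dL/dy, recomputed on the fly as d_out @ v^T.
    dy = torch.matmul(d_out, v.transpose(-1, -2))
    # (dy - warp_sum(y * dy)) * y in the kernel: dL/dlogits.
    return (dy - (dy * y).sum(dim=-1, keepdim=True)) * y
```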
setup.py

```diff
@@ -12,8 +12,46 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from setuptools import find_packages
-from setuptools import setup
+import os
+from setuptools import setup, Extension, find_packages
+import subprocess
+
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+
+version_dependent_macros = [
+    '-DVERSION_GE_1_1',
+    '-DVERSION_GE_1_3',
+    '-DVERSION_GE_1_5',
+]
+
+extra_cuda_flags = [
+    '-std=c++14',
+    '-maxrregcount=50',
+    '-U__CUDA_NO_HALF_OPERATORS__',
+    '-U__CUDA_NO_HALF_CONVERSIONS__',
+    '--expt-relaxed-constexpr',
+    '--expt-extended-lambda'
+]
+
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output(
+        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
+    )
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
+
+_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
+if int(bare_metal_major) >= 11:
+    cc_flag.append('-gencode')
+    cc_flag.append('arch=compute_80,code=sm_80')
+
+extra_cuda_flags += cc_flag
+
 setup(
     name='openfold',
@@ -25,7 +63,32 @@ setup(
     url='https://github.com/aqlaboratory/openfold',
     packages=find_packages(exclude=["tests", "scripts"]),
     include_package_data=True,
-    package_data={"": ["resources/stereo_chemical_props.txt"]},
+    package_data={
+        "openfold": ['utils/kernel/csrc/*'],
+        "": ["resources/stereo_chemical_props.txt"]
+    },
+    ext_modules=[CUDAExtension(
+        name="attn_core_inplace_cuda",
+        sources=[
+            "openfold/utils/kernel/csrc/softmax_cuda.cpp",
+            "openfold/utils/kernel/csrc/softmax_cuda_kernel.cu",
+        ],
+        include_dirs=[
+            os.path.join(
+                os.path.dirname(os.path.abspath(__file__)),
+                'openfold/utils/kernel/csrc/'
+            )
+        ],
+        extra_compile_args={
+            'cxx': ['-O3'] + version_dependent_macros,
+            'nvcc': (
+                ['-O3', '--use_fast_math'] +
+                version_dependent_macros +
+                extra_cuda_flags
+            ),
+        }
+    )],
+    cmdclass={'build_ext': BuildExtension},
     install_requires=[
         'torch',
         'deepspeed',
```
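After `python3 setup.py install`, a quick way to confirm the extension actually built is to load it the same way attention_core.py does. This is a sanity-check suggestion, not part of the commit:

```python
import importlib

# Loads the compiled CUDA extension registered by setup.py above.
attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")

# The pybind11 module exposes the two entry points bound in softmax_cuda.cpp.
assert hasattr(attn_core_inplace_cuda, "forward_")
assert hasattr(attn_core_inplace_cuda, "backward_")
```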
tests/test_kernels.py (new file, 0 → 100644)

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import unittest

from openfold.model.primitives import _attention
from openfold.utils.kernel.attention_core import attention_core
from tests.config import consts


class TestAttentionCore(unittest.TestCase):
    def test_attention_core_forward(self):
        n_res = consts.n_res
        h = consts.n_heads_extra_msa
        n_seq = consts.n_extra
        c = consts.c_e
        dtype = torch.float32

        q = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
        k = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
        v = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
        mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
        mask_bias = (1e9 * (mask - 1))[..., None, None, :].to(dtype)

        out_repro = attention_core(q, k, v, mask_bias, None)
        out_gt = _attention(q, k, v, [mask_bias])

        self.assertTrue(torch.max(torch.abs(out_repro - out_gt)) < consts.eps)

    def test_attention_core_backward(self):
        n_res = consts.n_res
        h = consts.n_heads_extra_msa
        n_seq = consts.n_extra
        c = consts.c_e
        dtype = torch.float32

        q = torch.rand(
            [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
        ).cuda()
        k = torch.rand(
            [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
        ).cuda()
        v = torch.rand(
            [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
        ).cuda()
        mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
        mask_bias = (1e9 * (mask - 1))[..., None, None, :].to(dtype)

        def clone(t):
            t = t.clone()
            if(t.requires_grad):
                t.retain_grad()
            return t

        q_repro = clone(q)
        k_repro = clone(k)
        v_repro = clone(v)
        out_repro = attention_core(
            q_repro, k_repro, v_repro, mask_bias, None
        )

        loss_repro = torch.mean(out_repro)
        loss_repro.backward()

        q_gt = clone(q)
        k_gt = clone(k)
        v_gt = clone(v)
        out_gt = _attention(
            q_gt, k_gt, v_gt, [mask_bias]
        )

        loss_gt = torch.mean(out_gt)
        loss_gt.backward()

        pairs = zip([q_repro, k_repro, v_repro], [q_gt, k_gt, v_gt])
        for t_repro, t_gt in pairs:
            self.assertTrue(
                torch.max(torch.abs(t_repro.grad - t_gt.grad)) < consts.eps
            )


if __name__ == '__main__':
    unittest.main()
```