OpenDAS / ColossalAI · Commit 7bc0afc9

updated flash attention usage

Authored Mar 17, 2023 by zbian; committed by アマデウス, Mar 20, 2023
Parent: 085e7f4e

Showing 3 changed files with 307 additions and 170 deletions:

  LICENSE                                             +70   -0
  colossalai/kernel/cuda_native/flash_attention.py    +156  -51
  tests/test_utils/test_flash_attention.py            +81   -119
LICENSE (view file @ 7bc0afc9)
...
@@ -326,3 +326,73 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
---------------- LICENSE FOR Flash Attention ----------------
BSD 3-Clause License
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
---------------- LICENSE FOR Facebook xFormers ----------------
From xFormers:
Copyright (c) Facebook, Inc. and its affiliates
===
BSD 3-Clause License
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
colossalai/kernel/cuda_native/flash_attention.py (view file @ 7bc0afc9)
"""
The triton-based flash attention implementation is copied from the OpenAI/triton repository
You can find the repository in Triton https://github.com/openai/triton
You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
Reference:
1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf
A general attention module using the flash attention kernels from xformers:
https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
"""
import
math
...
...
@@ -15,6 +9,159 @@ import subprocess
import torch

try:
    from xformers.ops.fmha import memory_efficient_attention
    HAS_MEM_EFF_ATTN = True
except ImportError:
    HAS_MEM_EFF_ATTN = False
    print('please install xformers from https://github.com/facebookresearch/xformers')

if HAS_MEM_EFF_ATTN:

    from typing import Optional

    from einops import rearrange
    from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
    from xformers.ops.fmha.attn_bias import (BlockDiagonalMask, LowerTriangularMask,
                                             LowerTriangularMaskWithTensorBias)

    from .scaled_softmax import AttnMaskType

    allow_alibi = True
    for op in MemoryEfficientAttentionCutlassOp:
        allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
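    # allow_alibi ends up True only if every cutlass kernel variant accepts a
    # LowerTriangularMaskWithTensorBias, i.e. a causal mask carrying an additive bias tensor;
    # it gates the `bias` argument of ColoAttention.forward below.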
    class Unpad(torch.autograd.Function):
        """
        Adapted from
        https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
        """

        @staticmethod
        def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
            ctx.save_for_backward(indices)
            # [b, s, ...]
            assert tensor.ndim >= 3
            ctx.bsz = tensor.shape[0]
            out = rearrange(tensor, 'b s ... -> (b s) ...')
            ctx.shape = out.shape
            # [1, ntokens, ...]
            return out[indices].unsqueeze(0)

        @staticmethod
        def backward(ctx, grad_output):
            indices, = ctx.saved_tensors
            # [b*s, ...]
            grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
            grad[indices] = grad_output.squeeze(0)
            grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz)
            # [b, s, ...]
            return grad, None
    class Repad(torch.autograd.Function):
        """
        Adapted from
        https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
        """

        @staticmethod
        def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
            ctx.save_for_backward(indices)
            # [ntokens, ...]
            tensor = tensor.squeeze(0)
            out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
            # [b*s, ...]
            out[indices] = tensor
            # [b, s, ...]
            out = rearrange(out, '(b s) ... -> b s ...', b=batch_size)
            return out

        @staticmethod
        def backward(ctx, grad_output):
            indices, = ctx.saved_tensors
            # [b*s, ...]
            grad_output = rearrange(grad_output, 'b s ... -> (b s) ...')
            grad = grad_output[indices]
            # [1, ntokens, ...]
            return grad.unsqueeze(0), None, None, None
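    # A minimal sketch of how Unpad/Repad round-trip a padded batch (the names `x`, `mask`,
    # `b`, `s` are illustrative): given `x` of shape [b, s, ...] and a 0/1 validity mask of
    # shape [b, s],
    #
    #     indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    #     packed = Unpad.apply(x, indices)                  # [1, ntokens, ...]
    #     restored = Repad.apply(packed, indices, b, s)     # [b, s, ...], zeros at padded slots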
    class ColoAttention(torch.nn.Module):

        def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
            super().__init__()
            assert embed_dim % num_heads == 0, \
                f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
            self.scale = 1 / math.sqrt(embed_dim // num_heads)
            self.dropout = dropout

        @staticmethod
        def get_seq_info_from_mask(attn_mask: torch.Tensor):
            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()
            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten().tolist()
            return indices, seqlens

        @staticmethod
        def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
            return Unpad.apply(tensor, indices)

        @staticmethod
        def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
            return Repad.apply(tensor, indices, batch_size, seq_len)

        def forward(self,
                    query: torch.Tensor,
                    key: torch.Tensor,
                    value: torch.Tensor,
                    attn_mask: Optional[torch.Tensor] = None,
                    attn_mask_type: Optional[AttnMaskType] = None,
                    bias: Optional[torch.Tensor] = None):
            batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
            attn_bias = None
            if attn_mask_type == AttnMaskType.padding:    # bert style
                assert attn_mask is not None, \
                    f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
                assert attn_mask.dim() == 2, \
                    "attention mask is supposed to have shape (batch_size, seq_len), " + \
                    f"but got {attn_mask.dim()} dimensions."
                if tgt_len == src_len:
                    q_indices, q_seqlen = self.get_seq_info_from_mask(attn_mask)
                    kv_seqlen = None
                    if batch_size > 1:
                        query, key, value = self.unpad(torch.stack([query, key, value], dim=2),
                                                       q_indices).unbind(dim=2)
                else:
                    q_indices = torch.arange(batch_size * tgt_len, dtype=torch.int32, device=query.device)
                    q_seqlen = torch.tensor([tgt_len] * batch_size, dtype=torch.long, device=query.device)
                    kv_indices, kv_seqlen = self.get_seq_info_from_mask(attn_mask)
                    if batch_size > 1:
                        query = rearrange(query, "b s ... -> c (b s) ...", c=1)
                        key, value = self.unpad(torch.stack([key, value], dim=2), kv_indices).unbind(dim=2)
                attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
            elif attn_mask_type == AttnMaskType.causal:    # gpt style
                attn_bias = LowerTriangularMask()

            if bias is not None:    # alibi / relative position embedding
                assert allow_alibi, "flash attention with bias is not supported in this system."
                assert attn_mask_type == AttnMaskType.causal, \
                    "attention with bias is only supported for causal attention so far."
                attn_bias = attn_bias.add_bias(bias)

            out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale)

            if attn_mask_type == AttnMaskType.padding and batch_size > 1:
                out = self.repad(out, q_indices, batch_size, tgt_len)

            out = rearrange(out, 'b s h d -> b s (h d)')
            return out
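    # A hypothetical sketch of the bias path above (ALiBi / relative position bias): the bias
    # tensor is only accepted together with a causal mask, and (an assumption here, not stated
    # in this file) should be broadcastable to (batch, num_heads, tgt_len, src_len):
    #
    #     alibi_bias = build_alibi_bias(batch_size, num_heads, seq_len)    # hypothetical helper
    #     y = attn(q, k, v, attn_mask_type=AttnMaskType.causal, bias=alibi_bias)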
##########################################################################
# the flash attention functions below that are copied
# from the OpenAI/triton repository will be deprecated
# You can find the repository in Triton https://github.com/openai/triton
# You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
# Reference:
# 1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
# 2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf


def triton_cuda_check():
    cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
    ...
@@ -52,13 +199,6 @@ except ImportError:
    HAS_FLASH_ATTN = False
    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')

try:
    from xformers.ops.fmha import memory_efficient_attention
    HAS_MEM_EFF_ATTN = True
except ImportError:
    HAS_MEM_EFF_ATTN = False
    print('please install xformers from https://github.com/facebookresearch/xformers')
if HAS_TRITON:
    # the following functions are adapted from the OpenAI Triton tutorial
    # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
    ...
@@ -422,25 +562,6 @@ if HAS_TRITON:
if HAS_FLASH_ATTN:
    from einops import rearrange

    class MaskedFlashAttention(torch.nn.Module):

        def __init__(self, num_attention_heads: int, attention_head_size: int, attention_dropout: float) -> None:
            super().__init__()
            self.num_attention_heads = num_attention_heads
            self.attention_head_size = attention_head_size
            self.attention_func = FlashAttention(softmax_scale=math.sqrt(attention_head_size),
                                                 attention_dropout=attention_dropout)

        def forward(self, query_key_value: torch.Tensor, attention_mask: torch.Tensor, causal=False):
            if attention_mask.dtype is not torch.bool:
                attention_mask = attention_mask.bool()
            qkv = rearrange(query_key_value, 'b s (three h d) -> b s three h d', three=3, h=self.num_attention_heads)
            context, _ = self.attention_func(qkv, key_padding_mask=attention_mask, causal=causal)
            context = rearrange(context, 'b s h d -> b s (h d)')
            return context
    def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
        """
        Arguments:
        ...
@@ -511,20 +632,4 @@ if HAS_FLASH_ATTN:
                                   causal)


if HAS_MEM_EFF_ATTN:
    from einops import rearrange
    from xformers.ops.fmha import LowerTriangularMask

    class MemoryEfficientAttention(torch.nn.Module):

        def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0):
            super().__init__()
            attention_head_size = hidden_size // num_attention_heads
            self.scale = 1 / attention_head_size**0.5
            self.dropout = attention_dropout

        def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor):
            context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale)
            context = rearrange(context, 'b s h d -> b s (h d)')
            return context
##########################################################################
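The new ColoAttention module is the single entry point this commit standardizes on. Below is a minimal GPT-style usage sketch mirroring test_attention_gpt in the test file that follows; the fused qkv projection and the tensor shapes are taken from that test, not prescribed by the module itself.

import torch
from einops import rearrange

from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention

B, S, H, D_HEAD = 6, 8, 4, 16                 # batch, sequence length, heads, head dim
D = H * D_HEAD                                # hidden size

c_attn = torch.nn.Linear(D, 3 * D, dtype=torch.float16, device="cuda")
attn = ColoAttention(embed_dim=D, num_heads=H, dropout=0.1)

x = torch.randn((B, S, D), dtype=torch.float16, device="cuda")
# split the fused projection into q, k, v, each of shape [B, S, H, D_HEAD]
q, k, v = rearrange(c_attn(x), 'b s (n h d) -> n b s h d', n=3, h=H)

y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)    # causal (GPT-style); returns [B, S, D]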
tests/test_utils/test_flash_attention.py (view file @ 7bc0afc9)
import random

import pytest
import torch
from einops import rearrange

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import (
        MaskedFlashAttention,
        flash_attention_q_k_v,
        flash_attention_q_kv,
        flash_attention_qkv,
    )

if HAS_TRITON:
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention

from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN

if HAS_MEM_EFF_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention
    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention


def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
    ...
@@ -30,117 +21,88 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
    return ref_out


@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)

    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    # triton implementation
    tri_out = triton_flash_attention(q, k, v, sm_scale)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None

    # compare
    assert torch.allclose(ref_out, tri_out, atol=1e-3)
    assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
    assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
    assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)

    # reference implementation
    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    # flash implementation
    q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
    dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
    for i in range(3):
        if i == 0:
            tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True)
        elif i == 1:
            kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)
            tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True)
        else:
            qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1)
            tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True)

        tri_out.backward(dout, retain_graph=True)

        if i == 0:
            tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
                                                  (tri_out, tri_dq, tri_dk, tri_dv))
        elif i == 1:
            tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
            tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
        else:
            tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
            tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))

        # compare
        assert torch.allclose(ref_out, tri_out, atol=1e-3)
        assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
        assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
        assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1)

    qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    attention_mask = torch.randint(2, (Z, H)).cuda().bool()

    out = attn(qkv, attention_mask)

    dout = torch.rand_like(out)
    out.backward(dout)
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)

    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    assert list(y.shape) == [B, S, D]

    dy = torch.rand_like(y)
    y.backward(dy)
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)])
def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1)


@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    # attention mask of shape [B, S] with zero padding to max length S
    mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)

    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)

    assert list(y.shape) == [B, S, D]

    dy = torch.rand_like(y)
    y.backward(dy)
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)

    y = attn(q, k, v)

    assert list(y.shape) == [B, S, D]

    dy = torch.rand_like(y)
    y.backward(dy)
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
    kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda")

    attn = ColoAttention(D, H, dropout=0.1)

    src = torch.randn((B, S, D), dtype=dtype, device="cuda")
    tgt = torch.randn((B, T, D), dtype=dtype, device="cuda")

    q = q_attn(tgt)
    kv = kv_attn(src)
    q = rearrange(q, 'b s (h d) -> b s h d', h=H)
    k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2)

    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    assert list(y.shape) == [B, T, D]

    dy = torch.rand_like(y)
    y.backward(dy)


if __name__ == '__main__':
    test_flash_attention(3, 4, 2, 16)
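Most of the tests above are guarded by a skipif on their backend flag, so the module can be invoked directly; a sketch of one way to run it (the path is the one shown in this commit):

import pytest

# run only this module; tests whose backend (triton / flash_attn / xformers) is missing are skipped
pytest.main(["-q", "tests/test_utils/test_flash_attention.py"])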