ColossalAI · Commit 077a66dd (unverified)
updated attention kernel (#2133)
Authored Dec 16, 2022 by アマデウス; committed by GitHub on Dec 16, 2022
Parent: 484fe622

Showing 2 changed files with 45 additions and 1 deletion (+45 −1)
colossalai/kernel/cuda_native/flash_attention.py    +26 −0
tests/test_utils/test_flash_attention.py            +19 −1
colossalai/kernel/cuda_native/flash_attention.py (view file @ 077a66dd)
@@ -48,6 +48,13 @@ except ImportError:
     HAS_FLASH_ATTN = False
     print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
 
+try:
+    from xformers.ops.fmha import memory_efficient_attention
+    HAS_MEM_EFF_ATTN = True
+except ImportError:
+    HAS_MEM_EFF_ATTN = False
+    print('please install xformers from https://github.com/facebookresearch/xformers')
+
 if HAS_TRITON:
 
     @triton.jit

@@ -497,3 +504,22 @@ if HAS_FLASH_ATTN:
                                        device=k.device)
         return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p,
                                         sm_scale, causal)
+
+
+if HAS_MEM_EFF_ATTN:
+
+    from einops import rearrange
+    from xformers.ops.fmha import LowerTriangularMask
+
+    class MemoryEfficientAttention(torch.nn.Module):
+
+        def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0):
+            super().__init__()
+            attention_head_size = hidden_size // num_attention_heads
+            self.scale = 1 / attention_head_size**0.5
+            self.dropout = attention_dropout
+
+        def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor):
+            context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale)
+            context = rearrange(context, 'b s h d -> b s (h d)')
+            return context
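Not part of the commit: a minimal sketch of how the new MemoryEfficientAttention module might be used, assuming a CUDA device with xformers installed. The shapes and values are illustrative only and follow the (batch, seq_len, num_heads, head_dim) layout implied by the rearrange pattern in forward().

import torch

from colossalai.kernel.cuda_native.flash_attention import MemoryEfficientAttention
from xformers.ops.fmha import LowerTriangularMask

# Illustrative shapes only: batch 2, sequence length 4, 8 heads of size 16,
# so hidden_size = 8 * 16 = 128.
batch, seq_len, num_heads, head_dim = 2, 4, 8, 16
attn = MemoryEfficientAttention(hidden_size=num_heads * head_dim,
                                num_attention_heads=num_heads,
                                attention_dropout=0.0)

# Inputs in (batch, seq_len, num_heads, head_dim) layout, as expected by
# xformers' memory_efficient_attention with multi-head tensors.
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal masking via xformers' LowerTriangularMask, as in the new test below.
out = attn(q, k, v, attention_mask=LowerTriangularMask())
print(out.shape)    # expected: (2, 4, 128) after 'b s h d -> b s (h d)'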
tests/test_utils/test_flash_attention.py (view file @ 077a66dd)
@@ -2,7 +2,7 @@ import pytest
 import torch
 from einops import rearrange
 
-from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON
 
 if HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import (

@@ -15,6 +15,9 @@ if HAS_FLASH_ATTN:
 if HAS_TRITON:
     from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
 
+if HAS_MEM_EFF_ATTN:
+    from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention
+
 
 def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))

@@ -124,5 +127,20 @@ def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     out.backward(dout)
 
 
+@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)])
+def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
+    attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1)
+
+    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+
+    out = attn(q, k, v, attention_mask=LowerTriangularMask())
+
+    dout = torch.rand_like(out)
+    out.backward(dout)
+
+
 if __name__ == '__main__':
     test_flash_attention(3, 4, 2, 16)
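As a side note (not part of the diff): the LowerTriangularMask passed in the new test plays the same role as the explicit torch.tril mask that baseline_attention builds above. A small sketch of the equivalent dense additive bias, with N_CTX chosen arbitrarily:

import torch

N_CTX = 4
# Dense causal mask as built in baseline_attention: 1 on/below the diagonal, 0 above.
M = torch.tril(torch.ones((N_CTX, N_CTX)))
# xformers' LowerTriangularMask expresses the same constraint as an additive bias:
# 0 where a query may attend to a key, -inf where it may not.
bias = torch.where(M.bool(), torch.zeros_like(M), torch.full_like(M, float("-inf")))
print(bias)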