OpenDAS / ColossalAI · Commits

Commit 9639ea88 (unverified): [kernel] more flexible flashatt interface (#1804)
Authored Nov 07, 2022 by oahzxl; committed via GitHub on Nov 07, 2022
Parent: 20e255d4

Showing 2 changed files, with 121 additions and 49 deletions:

    colossalai/kernel/cuda_native/flash_attention.py    +71  -17
    tests/test_utils/test_flash_attention.py             +50  -32

colossalai/kernel/cuda_native/flash_attention.py

@@ -11,7 +11,7 @@ import subprocess
 import torch
 
 
-def triton_check():
+def triton_cuda_check():
     cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
     cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip()
     cuda_version = cuda_version.split('release ')[1]
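
A reading aid, not part of the diff: after the split, cuda_version holds a string such as "11.4, V11.4.48". The rest of triton_cuda_check is elided here, but given the "triton requires cuda >= 11.4" message in the next hunk, it presumably parses the major/minor numbers and compares them against 11.4. A hedged sketch of such a comparison (hypothetical helper, not the commit's code):

    # Hypothetical helper, not from the commit: compare the parsed nvcc
    # release string (e.g. "11.4, V11.4.48") against a minimum version.
    def cuda_at_least(cuda_version: str, required=(11, 4)) -> bool:
        major, minor = (int(x) for x in cuda_version.split(',')[0].split('.')[:2])
        return (major, minor) >= required

    assert cuda_at_least("11.7, V11.7.99")          # new enough
    assert not cuda_at_least("10.2, V10.2.89")      # too old for triton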

@@ -27,7 +27,7 @@ def triton_check():
 try:
     import triton
     import triton.language as tl
-    if triton_check():
+    if triton_cuda_check():
         HAS_TRITON = True
     else:
         print("triton requires cuda >= 11.4")

@@ -36,7 +36,11 @@ except ImportError:
     print('please install triton from https://github.com/openai/triton')
     HAS_TRITON = False
 
 try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_func,
+        flash_attn_unpadded_kvpacked_func,
+        flash_attn_unpadded_qkvpacked_func,
+    )
     HAS_FLASH_ATTN = True
 except ImportError:
     HAS_FLASH_ATTN = False
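
Because both backends are imported behind try/except, call sites are expected to branch on the HAS_TRITON / HAS_FLASH_ATTN flags rather than importing the kernels directly. A minimal sketch of that pattern (my illustration, not from the commit):

    from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN

    if HAS_FLASH_ATTN:
        # safe to import the flash-attn-backed wrappers defined below
        from colossalai.kernel.cuda_native.flash_attention import flash_attention_qkv
    else:
        raise RuntimeError("flash_attn is not installed")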

@@ -405,12 +409,63 @@ if HAS_TRITON:
 
 if HAS_FLASH_ATTN:
 
-    def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal=True):
+    def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
         """
         Arguments:
-            q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
-            k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
-            v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
+            qkv: (batch * seqlen, 3, nheads, headdim)
+            batch_size: int.
+            seq_len: int.
             sm_scale: float. The scaling of QK^T before applying softmax.
                 Default to 1 / sqrt(headdim).
             dropout_p: float.
             causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         Return:
             out: (total, nheads, headdim).
         """
+        max_s = seq_len
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device)
+        out = flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_s, dropout_p, softmax_scale=sm_scale,
+                                                 causal=causal)
+        return out
+
+    def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
+        """
+        Arguments:
+            q: (batch * q_seqlen, nheads, headdim)
+            kv: (batch * kv_seqlen, 2, nheads, headdim)
+            batch_size: int.
+            seq_len: int.
+            sm_scale: float. The scaling of QK^T before applying softmax.
+                Default to 1 / sqrt(headdim).
+            dropout_p: float.
+            causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        Return:
+            out: (total, nheads, headdim).
+        """
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
+        cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32,
+                                    device=kv.device)
+        out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
+                                                sm_scale, causal)
+        return out
+
+    def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
         """
         Arguments:
+            q: (batch * q_seqlen, nheads, headdim)
+            k: (batch * kv_seqlen, nheads, headdim)
+            v: (batch * kv_seqlen, nheads, headdim)
+            batch_size: int.
+            seq_len: int.
             dropout_p: float. Dropout probability.

@@ -420,16 +475,15 @@ if HAS_FLASH_ATTN:
         Return:
             out: (total, nheads, headdim).
         """
-        lengths = torch.full((batch_size,), fill_value=seq_len, device=q.device)
-        cu_seqlens = torch.zeros((batch_size + 1,), device=q.device, dtype=torch.int32)
-        cu_seqlens[1:] = lengths.cumsum(0)
-        return flash_attn_unpadded_func(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
-                                        max_seqlen_q=seq_len, max_seqlen_k=seq_len, dropout_p=dropout_p,
-                                        softmax_scale=sm_scale, causal=causal)
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
+        cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32,
+                                     device=k.device)
+        return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p,
+                                        sm_scale, causal)
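
Taken together, the commit replaces the single flash_attention wrapper with three entry points, one per packing layout. All of them assume fixed-length batches flattened to (batch * seqlen, ..., nheads, headdim), and build the cumulative-sequence-length tensor that the unpadded flash-attn kernels expect via torch.arange (e.g. batch_size=3, seq_len=4 gives cu_seqlens = [0, 4, 8, 12]). A usage sketch under those assumptions (my example, not part of the commit):

    import torch
    from colossalai.kernel.cuda_native.flash_attention import (flash_attention_q_k_v, flash_attention_q_kv,
                                                                flash_attention_qkv)

    batch_size, seq_len, nheads, headdim = 4, 128, 8, 64
    sm_scale = headdim ** -0.5    # the 1 / sqrt(headdim) default mentioned in the docstrings
    shape = (batch_size * seq_len, nheads, headdim)
    q, k, v = (torch.randn(shape, dtype=torch.float16, device='cuda') for _ in range(3))

    # separate q, k, v tensors
    out = flash_attention_q_k_v(q, k, v, sm_scale, batch_size, seq_len, seq_len, causal=True)

    # kv packed along a new dim=1: (batch * seq_len, 2, nheads, headdim)
    kv = torch.stack((k, v), dim=1)
    out = flash_attention_q_kv(q, kv, sm_scale, batch_size, seq_len, seq_len, causal=True)

    # qkv packed along dim=1: (batch * seq_len, 3, nheads, headdim)
    qkv = torch.stack((q, k, v), dim=1)
    out = flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, causal=True)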

tests/test_utils/test_flash_attention.py

@@ -5,7 +5,8 @@ from einops import rearrange
 from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
 
 if HAS_FLASH_ATTN:
-    from colossalai.kernel.cuda_native.flash_attention import flash_attention
+    from colossalai.kernel.cuda_native.flash_attention import (flash_attention_q_k_v, flash_attention_q_kv,
+                                                               flash_attention_qkv)
 
 if HAS_TRITON:
     from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention

@@ -22,8 +23,8 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     return ref_out
 
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
+@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available")
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
 def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
     q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()

@@ -39,28 +40,20 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     ref_dq, q.grad = q.grad.clone(), None
 
     # triton implementation
-    if HAS_TRITON:
-        tri_out = triton_flash_attention(q, k, v, sm_scale)
-        tri_out.backward(dout)
-        tri_dv, v.grad = v.grad.clone(), None
-        tri_dk, k.grad = k.grad.clone(), None
-        tri_dq, q.grad = q.grad.clone(), None
-        # compare
-        assert torch.allclose(ref_out, tri_out, atol=1e-3)
-        assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
-        assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
-        assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
-    else:
-        try:
-            tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
-        except RuntimeError:
-            pass
-        else:
-            raise TypeError("Error type not match!")
+    tri_out = triton_flash_attention(q, k, v, sm_scale)
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
+    # compare
+    assert torch.allclose(ref_out, tri_out, atol=1e-3)
+    assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
+    assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
+    assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
 
 
 @pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
 def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
     q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()

@@ -78,15 +71,40 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
 
     # flash implementation
     q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
-    tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
     dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
-    tri_out.backward(dout, retain_graph=True)
-    tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
-    tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
-                                          (tri_out, tri_dq, tri_dk, tri_dv))
+    for i in range(3):
+        if i == 0:
+            tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True)
+        elif i == 1:
+            kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)
+            tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True)
+        else:
+            qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1)
+            tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True)
 
-    # compare
-    assert torch.allclose(ref_out, tri_out, atol=1e-3)
-    assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
-    assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
-    assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
+        tri_out.backward(dout, retain_graph=True)
+
+        if i == 0:
+            tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
+            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
+                                                  (tri_out, tri_dq, tri_dk, tri_dv))
+        elif i == 1:
+            tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
+            tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
+            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
+                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
+        else:
+            tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
+            tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
+            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
+                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
+
+        # compare
+        assert torch.allclose(ref_out, tri_out, atol=1e-3)
+        assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
+        assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
+        assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
 
 
 if __name__ == '__main__':
     test_flash_attention(3, 4, 2, 16)
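
One detail worth noting in the reworked test: gradients of the packed inputs come back packed, so the test splits them with torch.chunk along dim=1 and squeezes the singleton dimension before comparing against the baseline. A standalone round-trip of that packing (my example, independent of the test):

    import torch

    Z, H, N_CTX, D_HEAD = 3, 4, 2, 16       # same sizes as the parametrization above
    k = torch.randn(Z * N_CTX, H, D_HEAD)
    v = torch.randn(Z * N_CTX, H, D_HEAD)

    # pack: (batch * seqlen, 2, nheads, headdim), as flash_attention_q_kv expects
    kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)

    # unpack: chunk yields two (batch * seqlen, 1, nheads, headdim) views,
    # so squeeze(1) restores the original per-tensor shape
    k2, v2 = (t.squeeze(1) for t in torch.chunk(kv, 2, dim=1))
    assert torch.equal(k, k2) and torch.equal(v, v2)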