OpenDAS / ColossalAI · Commits

Commit 501a9e9c (Unverified)
Authored Nov 07, 2022 by oahzxl; committed via GitHub on Nov 07, 2022

[hotfix] polish flash attention (#1802)

Parent: 218c75fd
Showing 2 changed files with 24 additions and 21 deletions:

  colossalai/kernel/cuda_native/flash_attention.py   +20 -17
  tests/test_utils/test_flash_attention.py            +4 -4
colossalai/kernel/cuda_native/flash_attention.py (view file @ 501a9e9c)
@@ -10,20 +10,6 @@ import subprocess
 import torch
 
-try:
-    import triton
-    import triton.language as tl
-    HAS_TRITON = True
-except ImportError:
-    print('please install triton from https://github.com/openai/triton')
-    HAS_TRITON = False
-
-try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-    HAS_FLASH_ATTN = True
-except ImportError:
-    HAS_FLASH_ATTN = False
-    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
-
 
 def triton_check():
     cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")

@@ -38,9 +24,26 @@ def triton_check():
     return False
 
-TRITON_AVALIABLE = triton_check()
+
+try:
+    import triton
+    import triton.language as tl
+    if triton_check():
+        HAS_TRITON = True
+    else:
+        print("triton requires cuda >= 11.4")
+        HAS_TRITON = False
+except ImportError:
+    print('please install triton from https://github.com/openai/triton')
+    HAS_TRITON = False
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    HAS_FLASH_ATTN = True
+except ImportError:
+    HAS_FLASH_ATTN = False
+    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
 
-if TRITON_AVALIABLE:
+if HAS_TRITON:
 
     @triton.jit
     def _fwd_kernel(

@@ -394,7 +397,7 @@ if TRITON_AVALIABLE:
     Return:
         out: (batch, nheads, seq, headdim)
     """
-    if TRITON_AVALIABLE:
+    if HAS_TRITON:
        return _TritonFlashAttention.apply(q, k, v, sm_scale)
     else:
        raise RuntimeError("Triton kernel requires CUDA 11.4+!")
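This change drops the module-level TRITON_AVALIABLE flag and folds the CUDA-version check into HAS_TRITON, so downstream code only has to consult one flag per backend. A minimal, hypothetical sketch of that guarded-import pattern (a standalone illustration, not the ColossalAI module itself; cuda_ok() is an assumed stand-in for the real triton_check()):

    # Hypothetical sketch of the guarded-import pattern adopted above; not the real module.
    import os

    def cuda_ok() -> bool:
        # Stand-in for triton_check(): should return True only when CUDA >= 11.4 is present.
        return os.path.isdir(os.getenv("CUDA_HOME", default="/usr/local/cuda"))

    try:
        import triton  # importing alone is not enough; the CUDA toolkit must also be new enough
        HAS_TRITON = cuda_ok()
        if not HAS_TRITON:
            print("triton requires cuda >= 11.4")
    except ImportError:
        print("please install triton from https://github.com/openai/triton")
        HAS_TRITON = False

    # Callers then branch on the single flag:
    if HAS_TRITON:
        print("triton flash-attention kernels can be compiled on this machine")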
tests/test_utils/test_flash_attention.py (view file @ 501a9e9c)
@@ -2,7 +2,7 @@ import pytest
 import torch
 from einops import rearrange
 
-from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON, TRITON_AVALIABLE
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
 
 if HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import flash_attention

@@ -22,7 +22,7 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     return ref_out
 
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)

@@ -39,7 +39,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     ref_dq, q.grad = q.grad.clone(), None
 
     # triton implementation
-    if TRITON_AVALIABLE:
+    if HAS_TRITON:
        tri_out = triton_flash_attention(q, k, v, sm_scale)
        tri_out.backward(dout)
        tri_dv, v.grad = v.grad.clone(), None

@@ -59,7 +59,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
        raise TypeError("Error type not match!")
 
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
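With TRITON_AVALIABLE gone, the tests import only the two remaining flags. A small, hypothetical smoke check (assuming ColossalAI is installed; it uses only the names confirmed by the import line above):

    # Hypothetical smoke check: report which flash-attention backends this environment exposes.
    from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON

    print("triton backend available:", HAS_TRITON)
    print("flash_attn backend available:", HAS_FLASH_ATTN)

The tests themselves can then be run with, for example, pytest tests/test_utils/test_flash_attention.py.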