gaoqiong / flash-attention
Commit 5a61cb77, authored Jun 01, 2022 by Tri Dao

Rename src -> flash_attn

Parent: c41479d6
Showing 10 changed files with 75 additions and 41 deletions (+75 -41)
README.md                                        +0  -1
benchmarks/benchmark_flash_attention.py          +2  -2
benchmarks/utils.py                              +26 -12
flash_attn/bert_padding.py                       +0  -0
flash_attn/flash_attention.py                    +3  -3
flash_attn/flash_attn_interface.py               +0  -0
flash_attn/flash_blocksparse_attention.py        +3  -3
flash_attn/flash_blocksparse_attn_interface.py   +0  -0
flash_attn/rotary.py                             +0  -0
setup.py                                         +41 -20
README.md  (+0 -1)
@@ -11,7 +11,6 @@ Paper: https://arxiv.org/abs/2205.14135
 To compile (requiring CUDA 11, NVCC, and an Ampere GPU):
 ```
-cd csrc/flash_attn
 python setup.py install
 ```
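Dropping `cd csrc/flash_attn` means the extension is now built from the repository root, so `python setup.py install` is run at the top level (see the relocated setup.py below). A minimal post-install sanity check, assuming the build succeeded on a CUDA-capable machine:

```
# Post-install sanity check (illustrative): both the renamed Python package and
# the compiled CUDA extension should be importable after `python setup.py install`.
import torch

import flash_attn_cuda  # compiled extension declared in setup.py
from flash_attn.flash_attn_interface import flash_attn_func  # was `src.flash_attn_interface`

print(torch.__version__)
print(flash_attn_func.__module__)  # expected: flash_attn.flash_attn_interface
```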
benchmarks/benchmark_flash_attention.py  (+2 -2)
@@ -7,8 +7,8 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from benchmarks.utils import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
-from src.bert_padding import unpad_input, pad_input
-from src.flash_attn_interface import flash_attn_func
+from flash_attn.bert_padding import unpad_input, pad_input
+from flash_attn.flash_attn_interface import flash_attn_func

 def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
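For downstream code the visible effect of this commit is the package rename: everything that used to be imported from `src` is now imported from `flash_attn`, which also works once the package is installed rather than only from a source checkout. A brief sketch of the migration:

```
# Before this commit (only worked from a source checkout with `src/` importable):
#   from src.bert_padding import unpad_input, pad_input
#   from src.flash_attn_interface import flash_attn_func

# After this commit (the installed package is named `flash_attn`):
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attn_interface import flash_attn_func
```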
benchmarks/utils.py  (+26 -12)
@@ -86,20 +86,34 @@ def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **k
     )

-def pytorch_profiler(fn, *inputs, repeats=10):
+def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False, verbose=True):
     """ Wrap benchmark functions in Pytorch profiler to see CUDA information. """
+    if backward:
+        g = torch.randn_like(fn(*inputs))
+    for _ in range(10): # Warm up
+        with torch.autocast(device_type='cuda', enabled=amp):
+            if backward:
+                for x in inputs:
+                    if isinstance(x, torch.Tensor):
+                        x.grad = None
+            fn(*inputs) if not backward else fn(*inputs).backward(g)
     with torch.profiler.profile(
-        activities=[
-            torch.profiler.ProfilerActivity.CPU,
-            torch.profiler.ProfilerActivity.CUDA,
-        ],
+        activities=[torch.profiler.ProfilerActivity.CUDA,],
+        # activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,],
         record_shapes=True,
-        profile_memory=True,
+        # profile_memory=True,
         with_stack=True,
-    ) as p:
-        # benchmark_forward(repeats, fn, *inputs)
-        print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
+    ) as prof:
+        with torch.autocast(device_type='cuda', enabled=amp):
+            if backward:
+                for x in inputs:
+                    if isinstance(x, torch.Tensor):
+                        x.grad = None
+            fn(*inputs) if not backward else fn(*inputs).backward(g)
+    if verbose:
+        print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
+    if trace_filename is not None:
+        prof.export_chrome_trace(trace_filename)

 def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
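Beyond the rename, pytorch_profiler picks up trace_filename, backward, amp, and verbose options. A minimal usage sketch of the new signature (the matmul workload, tensor sizes, and trace filename are illustrative, and a CUDA device is assumed):

```
import torch
from benchmarks.utils import pytorch_profiler

a = torch.randn(1024, 1024, device='cuda', requires_grad=True)
b = torch.randn(1024, 1024, device='cuda', requires_grad=True)

# Profiles forward + backward under autocast and writes a Chrome trace
# (viewable at chrome://tracing) in addition to printing the CUDA time table.
pytorch_profiler(torch.matmul, a, b, backward=True, amp=True,
                 trace_filename='matmul_trace.json', verbose=True)
```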
src/bert_padding.py → flash_attn/bert_padding.py  (file moved)
src/flash_attention.py → flash_attn/flash_attention.py  (+3 -3)
@@ -4,9 +4,9 @@ import torch.nn as nn
 from einops import rearrange

-from src.rotary import RotaryEmbedding, RotaryEmbedding2D
-from src.flash_attn_interface import flash_attn_func
-from src.bert_padding import unpad_input, pad_input, index_first_axis
+from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D
+from flash_attn.flash_attn_interface import flash_attn_func
+from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis

 class FlashAttention(nn.Module):
src/flash_attn_interface.py → flash_attn/flash_attn_interface.py  (file moved)
src/flash_blocksparse_attention.py → flash_attn/flash_blocksparse_attention.py  (+3 -3)
@@ -6,9 +6,9 @@ from einops import rearrange
 import hydra

-from src.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
-from src.flash_blocksparse_attn_interface import convert_blockmask
-from src.bert_padding import unpad_input, pad_input, index_first_axis
+from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
+from flash_attn.flash_blocksparse_attn_interface import convert_blockmask
+from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis

 class FlashBlocksparseAttention(nn.Module):
src/flash_blocksparse_attn_interface.py → flash_attn/flash_blocksparse_attn_interface.py  (file moved)
src/rotary.py → flash_attn/rotary.py  (file moved)
csrc/flash_attn/setup.py → setup.py  (+41 -20)
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
-import torch
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
-from setuptools import setup, find_packages
-import subprocess
 import sys
 import warnings
 import os
+from pathlib import Path
+from setuptools import setup, find_packages
+import subprocess
+
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+
+with open("README.md", "r", encoding="utf-8") as fh:
+    long_description = fh.read()

 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
@@ -66,8 +73,8 @@ if not torch.cuda.is_available():
     print(
         "\nWarning: Torch did not find available GPUs on this system.\n",
         "If your intention is to cross-compile, this is not an error.\n"
-        "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
-        "Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
+        "By default, We cross-compile for Volta (compute capability 7.0), "
+        "Turing (compute capability 7.5),\n"
         "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
@@ -75,11 +82,11 @@ if not torch.cuda.is_available():
     if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
         _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
         if int(bare_metal_major) == 11:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5;8.0"
             if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+                os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5;8.0;8.6"
         else:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5"

 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
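The default TORCH_CUDA_ARCH_LIST drops Pascal (6.0/6.1/6.2) and keeps Volta, Turing, and, for CUDA 11, Ampere. A small standalone sketch of that defaulting logic, with a hypothetical default_arch_list helper standing in for the get_cuda_bare_metal_version call, to show which value ends up in the environment:

```
import os

def default_arch_list(bare_metal_major: int, bare_metal_minor: int) -> str:
    # Mirrors the post-commit defaults: Pascal (6.x) is gone, and sm_86 is only
    # added for CUDA 11.1+. `default_arch_list` itself is a hypothetical helper.
    if bare_metal_major == 11:
        if bare_metal_minor > 0:
            return "7.0;7.5;8.0;8.6"
        return "7.0;7.5;8.0"
    return "7.0;7.5"

# setup.py only applies the default when the user has not set the variable.
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
    os.environ["TORCH_CUDA_ARCH_LIST"] = default_arch_list(11, 3)

print(os.environ["TORCH_CUDA_ARCH_LIST"])  # e.g. 7.0;7.5;8.0;8.6 for CUDA 11.3
```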
@@ -95,7 +102,7 @@ torch_dir = torch.__path__[0]
 if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
     generator_flag = ["-DOLD_GENERATOR_PATH"]

-raise_if_cuda_home_none("--flashattn")
+raise_if_cuda_home_none("flash_attn")

 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
 _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
@@ -108,11 +115,11 @@ ext_modules.append(
     CUDAExtension(
         name="flash_attn_cuda",
         sources=[
-            "fmha_api.cpp",
-            "src/fmha_fprop_fp16_kernel.sm80.cu",
-            "src/fmha_dgrad_fp16_kernel_loop.sm80.cu",
-            "src/fmha_block_fprop_fp16_kernel.sm80.cu",
-            "src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
+            "csrc/flash_attn/fmha_api.cpp",
+            "csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu",
+            "csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu",
+            "csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
+            "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
         ],
         extra_compile_args={
             "cxx": ["-O3"] + generator_flag,
@@ -132,16 +139,30 @@ ext_modules.append(
             ),
         },
         include_dirs=[
-            this_dir,
-            os.path.join(this_dir, "src"),
+            Path(this_dir) / 'csrc' / 'flash_attn',
+            Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
         ],
     )
 )

 setup(
-    name="flash_attn_cuda",
+    name="flash_attn",
     version="0.1",
-    description="Flash Attention",
+    packages=find_packages(
+        exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
+    ),
+    author="Tri Dao",
+    author_email="trid@stanford.edu",
+    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    url="https://github.com/HazyResearch/flash-attention",
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: Apache 2.0",
+        "Operating System :: Linux",
+    ],
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension} if ext_modules else {},
+    python_requires=">=3.7",
 )
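Because setup.py now lives at the repository root and the package directory is named flash_attn, find_packages() can pick it up directly, while the exclude tuple keeps csrc, tests, benchmarks, and similar directories out of the distribution. A quick, illustrative check from a source checkout (the expected output is noted in a comment, not captured from a run):

```
# Run from a repository checkout root; illustrative only.
from setuptools import find_packages

discovered = find_packages(
    exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks",
             "flash_attn.egg-info",)
)
print(discovered)  # expected to include 'flash_attn' (the renamed `src` package)
```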